# vvvvvv / src / streamlit_app.py
# Reyall's picture
# Upload 9 files
# 84187cf verified
import os
import sys
import streamlit as st
from transformers import BertTokenizer, BertForSequenceClassification # Burada BertTokenizer istifadə edirik
import torch
import pickle
import random
from collections import defaultdict
import requests
# Helper that downloads the label encoder file from GitHub
def download_label_encoder():
    """Download ``label_encoder.pkl`` from GitHub into the working directory.

    The file is fetched from the ``AxundovReyal/nlp-disease`` repository and
    written to ``label_encoder.pkl`` relative to the current directory.

    If the ``GITHUB_TOKEN`` environment variable holds a GitHub personal
    access token, it is sent in the ``Authorization`` header. If the variable
    is unset, no ``Authorization`` header is sent at all, so the request never
    carries the bogus credential ``token None``.

    Raises:
        Exception: If the server answers with any status code other than 200.
        requests.exceptions.RequestException: On connection failures or when
            the server does not respond within the timeout.
    """
    url = "https://github.com/AxundovReyal/nlp-disease/raw/main/label_encoder.pkl"

    # Only authenticate when a token is actually configured.
    token = os.getenv("GITHUB_TOKEN")
    headers = {"Authorization": f"token {token}"} if token else {}

    # A timeout keeps app start-up from hanging forever on a stalled connection.
    response = requests.get(url, headers=headers, timeout=30)
    if response.status_code != 200:
        raise Exception(f"Fayl yüklənə bilmədi, error kodu: {response.status_code}")

    with open("label_encoder.pkl", "wb") as f:
        f.write(response.content)
    print("label_encoder.pkl faylı uğurla yükləndi.")
# Load the model and the label encoder
@st.cache_resource
def load_model():
    """Build the tokenizer, classifier and label encoder used by the app.

    Downloads ``label_encoder.pkl`` via ``download_label_encoder()``,
    unpickles it, then loads the ``bert-base-uncased`` tokenizer and a
    ``BertForSequenceClassification`` model whose number of labels equals the
    number of classes in the label encoder. The model is switched to
    evaluation mode before being returned.

    Returns:
        A ``(tokenizer, model, label_encoder)`` tuple.
    """
    # Fetch the label encoder first; the model's output size depends on it.
    download_label_encoder()
    with open("label_encoder.pkl", "rb") as encoder_file:
        # NOTE(review): pickle.load can execute arbitrary code from the file;
        # this is only safe while the download source is fully trusted.
        encoder = pickle.load(encoder_file)

    checkpoint = 'bert-base-uncased'
    bert_tokenizer = BertTokenizer.from_pretrained(checkpoint)

    # NOTE(review): this loads the stock checkpoint, not fine-tuned weights;
    # presumably the classification head is untrained here - TODO confirm.
    class_count = len(encoder.classes_)
    classifier = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=class_count)
    classifier.eval()

    return bert_tokenizer, classifier, encoder
# Module-level resources shared by predict() and the Predict handler;
# load_model is wrapped in st.cache_resource.
tokenizer, model, label_encoder = load_model()
# Page title; called before the input widget below.
st.title("Disease NLP Classifier")
# Free-text input; the Predict handler splits this string on commas.
text = st.text_area("Enter your symptoms separated by commas (e.g. fever, cough, headache):")
def predict(text_input):
    """Return softmax class probabilities for *text_input*.

    The text is tokenized with the module-level ``tokenizer`` (truncated to at
    most 128 tokens) and passed through the module-level ``model`` with
    gradient tracking disabled.

    Args:
        text_input: Text to classify; the app passes a comma-separated
            symptom string.

    Returns:
        A tensor of probabilities over the model's labels. For a single input
        string it is 1-D with one entry per label.
    """
    inputs = tokenizer(text_input, return_tensors="pt", truncation=True, padding=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)
    # squeeze(0) drops only the batch dimension. A bare squeeze() would also
    # collapse the class dimension when the model has a single label, giving
    # a 0-d tensor that the caller cannot iterate over.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze(0)
    return probs
# Handle a click on the "Predict" button: validate the input, average the
# model's probabilities over several random symptom orderings, and show the
# three most probable labels.
if st.button("Predict"):
    if not text.strip():
        st.warning("Please enter some symptoms!")
    else:
        # Split on commas and drop empty fragments such as those in "a,,b".
        symptoms = [s.strip() for s in text.split(",") if s.strip()]
        if not symptoms:
            st.warning("Please enter valid symptoms separated by commas!")
        else:
            # Presumably the shuffling is meant to reduce sensitivity to the
            # order in which symptoms were typed. A single symptom has only
            # one ordering, so one forward pass gives the same average.
            n_shuffles = 10 if len(symptoms) > 1 else 1

            prob_sum = None
            for _ in range(n_shuffles):
                random.shuffle(symptoms)
                shuffled_text = ", ".join(symptoms)
                # reshape(-1) keeps the result 1-D even if predict() returns
                # a 0-d tensor (single-label model).
                probs = predict(shuffled_text).reshape(-1)
                # Accumulate whole tensors instead of one .item() per label.
                prob_sum = probs if prob_sum is None else prob_sum + probs

            agg_probs = (prob_sum / n_shuffles).tolist()
            top_3 = sorted(enumerate(agg_probs), key=lambda x: x[1], reverse=True)[:3]

            st.subheader("Top 3 Predicted Diseases (averaged over shuffled inputs):")
            for idx, prob in top_3:
                # Map the class index back to its label name.
                label = label_encoder.classes_[idx]
                st.write(f"**{label}** — Probability: `{prob * 100:.2f}%`")
# Port fix for Render deployments
if __name__ == "__main__":
    # Streamlit also executes app scripts with __name__ set to "__main__", so
    # this guard alone is true under `streamlit run` as well. Only bootstrap
    # the server when no Streamlit runtime is active; otherwise the CLI would
    # be re-entered from inside the already running app.
    from streamlit import runtime

    if not runtime.exists():
        from streamlit.web.cli import main

        # The hosting platform supplies the port via $PORT; fall back to 8501.
        port = int(os.environ.get("PORT", 8501))
        # Use this file's absolute path so the launch works from any working
        # directory, not only from the directory containing the script.
        script_path = os.path.abspath(__file__)
        sys.argv = ["streamlit", "run", script_path, f"--server.port={port}", "--server.address=0.0.0.0"]
        sys.exit(main())