|
|
import os |
|
|
import sys |
|
|
import streamlit as st |
|
|
from transformers import BertTokenizer, BertForSequenceClassification |
|
|
import torch |
|
|
import pickle |
|
|
import random |
|
|
from collections import defaultdict |
|
|
import requests |
|
|
|
|
|
|
|
|
def download_label_encoder():
    """Download label_encoder.pkl from the project's GitHub repo into the CWD.

    The raw URL is public, so an Authorization header is attached only when
    GITHUB_TOKEN is actually set — the original unconditionally sent
    ``token None`` when the env var was missing, which GitHub rejects with 401.

    Raises:
        RuntimeError: if the server responds with a non-200 status code.
    """
    url = "https://github.com/AxundovReyal/nlp-disease/raw/main/label_encoder.pkl"

    token = os.getenv("GITHUB_TOKEN")
    # Only send the header when a token is present (see docstring).
    headers = {"Authorization": f"token {token}"} if token else {}

    # Explicit timeout so a stalled connection cannot hang app startup forever.
    response = requests.get(url, headers=headers, timeout=30)

    if response.status_code == 200:
        with open("label_encoder.pkl", "wb") as f:
            f.write(response.content)
        print("label_encoder.pkl faylı uğurla yükləndi.")
    else:
        # RuntimeError is a subclass of Exception, so existing callers that
        # catch Exception still work.
        raise RuntimeError(f"Fayl yüklənə bilmədi, error kodu: {response.status_code}")
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model():
    """Fetch the label encoder and construct the tokenizer/model pair.

    Decorated with ``st.cache_resource`` so the download and the model
    construction run only once per Streamlit server process.

    Returns:
        tuple: (tokenizer, model, label_encoder)
    """
    download_label_encoder()

    with open("label_encoder.pkl", "rb") as fh:
        encoder = pickle.load(fh)

    num_classes = len(encoder.classes_)
    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased', num_labels=num_classes
    )
    # NOTE(review): this loads the generic pretrained backbone with a freshly
    # initialised classification head — no fine-tuned checkpoint is loaded
    # here, so the head's weights are random. Confirm whether a trained
    # checkpoint should be downloaded alongside the label encoder.
    bert_model.eval()

    return bert_tokenizer, bert_model, encoder
|
|
|
|
|
# Build (or fetch from Streamlit's resource cache) the tokenizer, model,
# and label encoder once per server process.
tokenizer, model, label_encoder = load_model()

st.title("Disease NLP Classifier")

# Free-text input; the app expects a comma-separated list of symptoms.
text = st.text_area("Enter your symptoms separated by commas (e.g. fever, cough, headache):")
|
|
|
|
|
def predict(text_input):
    """Return softmax class probabilities for a single input string.

    Uses the module-level ``tokenizer`` and ``model``. The returned tensor
    is squeezed to 1-D (one probability per class).
    """
    encoded = tokenizer(
        text_input,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    return torch.nn.functional.softmax(logits, dim=-1).squeeze()
|
|
|
|
|
if st.button("Predict"):
    if not text.strip():
        st.warning("Please enter some symptoms!")
    else:
        # Split on commas and drop empty fragments (e.g. trailing commas).
        symptoms = [s.strip() for s in text.split(",") if s.strip()]
        if not symptoms:
            st.warning("Please enter valid symptoms separated by commas!")
        else:
            # Accumulate per-class probabilities over several random
            # orderings of the symptoms — presumably to reduce the model's
            # sensitivity to symptom order (BERT is order-sensitive).
            agg_probs = defaultdict(float)
            n_shuffles = 10
            for _ in range(n_shuffles):
                # In-place shuffle; each iteration permutes the current order.
                random.shuffle(symptoms)
                shuffled_text = ", ".join(symptoms)
                probs = predict(shuffled_text)
                for i, p in enumerate(probs):
                    agg_probs[i] += p.item()
            # Convert accumulated sums into averages.
            for k in agg_probs:
                agg_probs[k] /= n_shuffles
            # Top three classes by averaged probability.
            top_3 = sorted(agg_probs.items(), key=lambda x: x[1], reverse=True)[:3]

            st.subheader("Top 3 Predicted Diseases (averaged over shuffled inputs):")
            for idx, prob in top_3:
                # Map the class index back to its human-readable label.
                label = label_encoder.classes_[idx]
                st.write(f"**{label}** — Probability: `{prob * 100:.2f}%`")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point for platforms that launch this file with plain `python`
    # (e.g. a $PORT-assigning PaaS): re-launch ourselves through the
    # Streamlit CLI on the platform-assigned port.
    from streamlit import runtime

    # Streamlit executes the app script with __name__ == "__main__" too, so
    # without this guard the bootstrap would recurse and try to start a
    # second server on every script rerun.
    if not runtime.exists():
        port = int(os.environ.get("PORT", 8501))
        # Use this file's real path instead of a hard-coded script name so
        # the bootstrap keeps working if the file is renamed or moved.
        sys.argv = [
            "streamlit",
            "run",
            os.path.abspath(__file__),
            f"--server.port={port}",
            "--server.address=0.0.0.0",
        ]
        from streamlit.web.cli import main

        sys.exit(main())
|
|
|