# nlp-disease-classification/src/streamlit_app.py
# Hugging Face upload by Reyall ("Upload 2 files"), commit a99556d (verified).
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import pickle
import random
from collections import defaultdict
import requests
import os
@st.cache_resource
def load_model():
    """Load the tokenizer, classifier and label encoder from the Hugging Face Hub.

    The function is decorated with ``st.cache_resource`` so the returned
    objects can be reused instead of being rebuilt on every call. If the
    ``HF_TOKEN`` environment variable is set to a non-empty value, it is used
    to authenticate both the ``from_pretrained`` calls and the direct download
    of ``label_encoder.pkl``.

    Returns:
        tuple: ``(tokenizer, model, label_encoder)`` where ``model`` has been
        switched to evaluation mode and ``label_encoder`` is whatever object
        was pickled into ``label_encoder.pkl`` in the model repository.

    Raises:
        requests.HTTPError: If the label-encoder download returns an error
            status.
        requests.Timeout: If the label-encoder download does not respond
            within the timeout.
    """
    MODEL_REPO = "Reyall/nlp-disease-model"
    HF_TOKEN = os.getenv("HF_TOKEN")
    # Pass the token only when one is configured; an empty string counts as
    # "not configured" and is sent as None.
    auth_token = HF_TOKEN if HF_TOKEN else None

    # NOTE(review): newer transformers releases deprecate `use_auth_token` in
    # favour of `token` -- confirm against the pinned version before renaming.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_REPO,
        use_auth_token=auth_token,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_REPO,
        use_auth_token=auth_token,
    )
    model.eval()

    # Fetch label_encoder.pkl from the Hugging Face repo, adding the bearer
    # token to the request headers when one is available.
    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    url = f"https://huggingface.co/{MODEL_REPO}/resolve/main/label_encoder.pkl"
    # Without a timeout, requests waits indefinitely on a stalled connection,
    # which would hang app start-up.
    r = requests.get(url, headers=headers, timeout=30)
    r.raise_for_status()

    # SECURITY: unpickling executes arbitrary code embedded in the payload.
    # This is only acceptable while MODEL_REPO is fully trusted; prefer a
    # data-only format (e.g. a JSON list of class names) if that changes.
    label_encoder = pickle.loads(r.content)
    return tokenizer, model, label_encoder
# Obtain the tokenizer, model and label encoder from load_model(); predict()
# and the button handler below read these module-level names.
tokenizer, model, label_encoder = load_model()
# Page title, followed by the free-text box holding the comma-separated symptoms.
st.title("Disease NLP Classifier")
text = st.text_area("Enter your symptoms separated by commas (e.g. fever, cough, headache):")
def predict(text_input):
    """Return softmax class probabilities for a symptom description.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        text_input: The text to classify (the app passes one comma-joined
            symptom string).

    Returns:
        torch.Tensor: Probabilities obtained by applying softmax over the last
        axis of the model's logits, with the leading batch axis removed when
        it has size 1.
    """
    inputs = tokenizer(text_input, return_tensors="pt", truncation=True, padding=True)
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)
    # squeeze(0) removes only the leading (batch) axis. A bare squeeze() would
    # also collapse the class axis when it has size 1, producing a 0-d tensor
    # that callers cannot iterate over.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze(0)
    return probs
if st.button("Predict"):
    # Split the raw text on commas and keep only the non-blank, trimmed
    # entries. The list is used only when both validation checks pass.
    symptoms = [piece.strip() for piece in text.split(",") if piece.strip()]

    if not text.strip():
        st.warning("Please enter some symptoms!")
    elif not symptoms:
        st.warning("Please enter valid symptoms separated by commas!")
    else:
        # Run the classifier on several random orderings of the symptoms and
        # accumulate each class's probability across the runs.
        n_shuffles = 10
        agg_probs = defaultdict(float)
        for _ in range(n_shuffles):
            random.shuffle(symptoms)
            run_probs = predict(", ".join(symptoms))
            for class_idx, class_prob in enumerate(run_probs):
                agg_probs[class_idx] += class_prob.item()

        # Turn the accumulated sums into per-class means.
        for class_idx in agg_probs:
            agg_probs[class_idx] = agg_probs[class_idx] / n_shuffles

        # Highest mean probability first; keep the three best classes.
        ranked = sorted(agg_probs.items(), key=lambda pair: pair[1], reverse=True)
        top_3 = ranked[:3]

        st.subheader("Top 3 Predicted Diseases (averaged over shuffled inputs):")
        for class_idx, mean_prob in top_3:
            # Map the class index back to its label via the loaded encoder.
            decoded = label_encoder.inverse_transform([class_idx])
            label = decoded[0]
            percent = mean_prob * 100
            st.write(f"**{label}** — Probability: `{percent:.2f}%`")