"""Streamlit app that classifies comma-separated symptoms into diseases.

The model, tokenizer and label encoder are loaded once from the Hugging Face
Hub. Predictions are averaged over several random orderings of the symptoms.
"""

import os
import pickle
import random
from collections import defaultdict

import requests
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


@st.cache_resource
def load_model():
    """Load the tokenizer, classifier and label encoder from the Hub.

    Reads the optional ``HF_TOKEN`` environment variable and, when it is set,
    uses it to authenticate both the ``from_pretrained`` calls and the
    label-encoder download.

    Returns:
        A ``(tokenizer, model, label_encoder)`` tuple. ``model.eval()`` has
        already been called on the model.

    Raises:
        requests.HTTPError: If the label-encoder download returns an error
            status.
        requests.Timeout: If the label-encoder download does not respond
            within the timeout.
    """
    MODEL_REPO = "Reyall/nlp-disease-model"
    HF_TOKEN = os.getenv("HF_TOKEN")

    # Pass the token only when one is configured; an unset or empty value
    # becomes None.
    # NOTE(review): newer transformers releases prefer `token=` over
    # `use_auth_token=` - confirm against the pinned version before switching.
    auth_token = HF_TOKEN or None
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_REPO,
        use_auth_token=auth_token,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_REPO,
        use_auth_token=auth_token,
    )
    model.eval()

    # Fetch label_encoder.pkl from the same repo, adding the token to the
    # request headers when one is available.
    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    url = f"https://huggingface.co/{MODEL_REPO}/resolve/main/label_encoder.pkl"
    # A timeout keeps a stalled connection from hanging app start-up forever.
    r = requests.get(url, headers=headers, timeout=30)
    r.raise_for_status()

    # SECURITY: unpickling executes arbitrary code embedded in the payload.
    # Only keep this if MODEL_REPO is fully trusted; otherwise publish the
    # labels in a data-only format such as JSON.
    label_encoder = pickle.loads(r.content)

    return tokenizer, model, label_encoder


tokenizer, model, label_encoder = load_model()

st.title("Disease NLP Classifier")
text = st.text_area("Enter your symptoms separated by commas (e.g. fever, cough, headache):")


def predict(text_input):
    """Return the softmax class probabilities for ``text_input``.

    Args:
        text_input: The text passed to the tokenizer.

    Returns:
        The softmax of the model logits with the leading batch axis removed.
    """
    inputs = tokenizer(text_input, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # squeeze(0) drops only the batch axis. A bare squeeze() would also
    # collapse a single-class output to a 0-d tensor that cannot be iterated.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze(0)
    return probs


if st.button("Predict"):
    symptoms = [s.strip() for s in text.split(",") if s.strip()]

    if not text.strip():
        st.warning("Please enter some symptoms!")
    elif not symptoms:
        # Input such as ", ," is non-blank but contains no actual symptom.
        st.warning("Please enter valid symptoms separated by commas!")
    else:
        n_shuffles = 10
        # A single symptom has only one ordering, so one pass is enough.
        n_runs = n_shuffles if len(symptoms) > 1 else 1

        agg_probs = defaultdict(float)
        for _ in range(n_runs):
            random.shuffle(symptoms)
            shuffled_text = ", ".join(symptoms)
            # tolist() converts the tensor to Python floats in one call.
            for i, p in enumerate(predict(shuffled_text).tolist()):
                agg_probs[i] += p

        mean_probs = {i: total / n_runs for i, total in agg_probs.items()}
        top_3 = sorted(mean_probs.items(), key=lambda x: x[1], reverse=True)[:3]

        st.subheader("Top 3 Predicted Diseases (averaged over shuffled inputs):")
        for idx, prob in top_3:
            label = label_encoder.inverse_transform([idx])[0]
            st.write(f"**{label}** — Probability: `{prob * 100:.2f}%`")