Spaces:
Sleeping
Sleeping
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import pickle
import random
from collections import defaultdict
import requests
import os
@st.cache_resource
def load_model():
    """Load the tokenizer, model, and label encoder from the Hugging Face Hub.

    Cached with ``st.cache_resource`` so Streamlit does not re-download and
    re-instantiate the model on every script rerun.

    Returns:
        tuple: ``(tokenizer, model, label_encoder)`` ready for inference.

    Raises:
        requests.HTTPError: if downloading ``label_encoder.pkl`` fails.
    """
    MODEL_REPO = "Reyall/nlp-disease-model"
    HF_TOKEN = os.getenv("HF_TOKEN")

    # `token=` replaces the deprecated `use_auth_token=` in recent
    # transformers releases; passing None means anonymous access.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, token=HF_TOKEN)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO, token=HF_TOKEN)
    model.eval()  # inference only — disable dropout / training-mode layers

    # Fetch label_encoder.pkl from the Hub; send the auth header only when
    # a token is configured.
    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    url = f"https://huggingface.co/{MODEL_REPO}/resolve/main/label_encoder.pkl"
    # timeout keeps the app from hanging forever on a stalled connection
    r = requests.get(url, headers=headers, timeout=30)
    r.raise_for_status()
    # SECURITY NOTE: pickle.loads can execute arbitrary code from the payload.
    # Acceptable only because this repo is the author's own trusted artifact.
    label_encoder = pickle.loads(r.content)
    return tokenizer, model, label_encoder
# Load heavyweight resources once at module import; predict() and the
# button handler below read these module-level globals.
tokenizer, model, label_encoder = load_model()
st.title("Disease NLP Classifier")
# Free-text input; the handler below splits this on commas.
text = st.text_area("Enter your symptoms separated by commas (e.g. fever, cough, headache):")
def predict(text_input):
    """Classify raw text and return a 1-D tensor of class probabilities."""
    encoded = tokenizer(text_input, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**encoded).logits
    # softmax over the class dimension, squeezed to drop the batch axis
    return torch.nn.functional.softmax(logits, dim=-1).squeeze()
if st.button("Predict"):
    if not text.strip():
        st.warning("Please enter some symptoms!")
    else:
        symptoms = [part.strip() for part in text.split(",") if part.strip()]
        if not symptoms:
            st.warning("Please enter valid symptoms separated by commas!")
        else:
            # Average predictions over several random orderings of the
            # symptoms so the result does not depend on input order.
            n_shuffles = 10
            totals = defaultdict(float)
            for _ in range(n_shuffles):
                random.shuffle(symptoms)
                scores = predict(", ".join(symptoms))
                for class_idx, score in enumerate(scores):
                    totals[class_idx] += score.item()
            averaged = {idx: total / n_shuffles for idx, total in totals.items()}
            top_3 = sorted(averaged.items(), key=lambda item: item[1], reverse=True)[:3]
            st.subheader("Top 3 Predicted Diseases (averaged over shuffled inputs):")
            for class_idx, avg_prob in top_3:
                disease = label_encoder.inverse_transform([class_idx])[0]
                st.write(f"**{disease}** — Probability: `{avg_prob * 100:.2f}%`")