"""Streamlit app that ranks diseases for a comma-separated list of symptoms.

The script can be started either with ``streamlit run <this file>`` or with
``python <this file>``; in the latter case it re-launches itself through the
Streamlit CLI on the port given by the ``PORT`` environment variable.
"""

import os
import pickle
import random
import sys

import requests
import streamlit as st
import torch
from streamlit import runtime
from transformers import BertForSequenceClassification, BertTokenizer

LABEL_ENCODER_URL = "https://github.com/AxundovReyal/nlp-disease/raw/main/label_encoder.pkl"
LABEL_ENCODER_PATH = "label_encoder.pkl"
DOWNLOAD_TIMEOUT_SECONDS = 30
N_SHUFFLES = 10
TOP_K = 3


def download_label_encoder():
    """Download ``label_encoder.pkl`` from GitHub into the working directory.

    If the ``GITHUB_TOKEN`` environment variable is set it is sent as a
    personal access token; otherwise the request is made anonymously.

    Raises:
        RuntimeError: if the server answers with a status other than 200.
        requests.RequestException: on connection errors or timeout.
    """
    token = os.getenv("GITHUB_TOKEN")
    # Only attach the header when a token exists: sending the literal
    # "token None" makes GitHub reject the request even for public files.
    headers = {"Authorization": f"token {token}"} if token else {}

    # A timeout keeps a stalled connection from hanging the app forever.
    response = requests.get(
        LABEL_ENCODER_URL, headers=headers, timeout=DOWNLOAD_TIMEOUT_SECONDS
    )
    if response.status_code != 200:
        raise RuntimeError(f"Fayl yüklənə bilmədi, error kodu: {response.status_code}")

    with open(LABEL_ENCODER_PATH, "wb") as f:
        f.write(response.content)
    print("label_encoder.pkl faylı uğurla yükləndi.")


@st.cache_resource
def load_model():
    """Load the tokenizer, the BERT classifier and the label encoder.

    The result is cached by Streamlit, so the download and model
    construction happen once per server process.

    Returns:
        A ``(tokenizer, model, label_encoder)`` tuple; the model is in
        evaluation mode.
    """
    # The label encoder is fetched first because the number of classes
    # determines the size of the classification head.
    download_label_encoder()

    # SECURITY: unpickling executes arbitrary code contained in the file.
    # This is only acceptable while the download URL is fully trusted.
    with open(LABEL_ENCODER_PATH, "rb") as f:
        label_encoder = pickle.load(f)

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # NOTE(review): "bert-base-uncased" ships no classification head, so the
    # head created here is randomly initialised. Unless fine-tuned weights
    # are loaded somewhere else, the predicted probabilities carry no
    # meaning - confirm where the trained checkpoint is supposed to come from.
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=len(label_encoder.classes_)
    )
    model.eval()
    return tokenizer, model, label_encoder


def predict(text_input):
    """Return the class probabilities for one piece of text.

    Args:
        text_input: the symptom text to classify.

    Returns:
        A 1-D tensor with one softmax probability per class.
    """
    tokenizer, model, _ = load_model()  # cached, so this is a cheap lookup
    inputs = tokenizer(
        text_input, return_tensors="pt", truncation=True, padding=True, max_length=128
    )
    with torch.no_grad():
        outputs = model(**inputs)
    # squeeze(0) drops only the batch dimension; a bare squeeze() would also
    # collapse the class dimension when there is a single class.
    return torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze(0)


def _average_shuffled_probabilities(symptoms, n_shuffles=N_SHUFFLES):
    """Average ``predict`` over ``n_shuffles`` random orderings of ``symptoms``.

    The list is shuffled in place. Returns a list with one averaged
    probability per class index.
    """
    totals = None
    for _ in range(n_shuffles):
        random.shuffle(symptoms)
        probs = predict(", ".join(symptoms)).tolist()
        if totals is None:
            totals = [0.0] * len(probs)
        for index, value in enumerate(probs):
            totals[index] += value
    return [total / n_shuffles for total in totals]


def render_app():
    """Draw the UI and, when "Predict" is pressed, show the top predictions."""
    _, _, label_encoder = load_model()

    st.title("Disease NLP Classifier")
    text = st.text_area("Enter your symptoms separated by commas (e.g. fever, cough, headache):")

    if not st.button("Predict"):
        return
    if not text.strip():
        st.warning("Please enter some symptoms!")
        return

    symptoms = [s.strip() for s in text.split(",") if s.strip()]
    if not symptoms:
        st.warning("Please enter valid symptoms separated by commas!")
        return

    averaged = _average_shuffled_probabilities(symptoms)
    top_3 = sorted(enumerate(averaged), key=lambda pair: pair[1], reverse=True)[:TOP_K]

    st.subheader("Top 3 Predicted Diseases (averaged over shuffled inputs):")
    for idx, prob in top_3:
        label = label_encoder.classes_[idx]
        st.write(f"**{label}** — Probability: `{prob * 100:.2f}%`")


def _launch_streamlit():
    """Re-run this file through the Streamlit CLI on the ``PORT`` env port."""
    from streamlit.web.cli import main

    port = int(os.environ.get("PORT", 8501))
    sys.argv = [
        "streamlit",
        "run",
        # Use this file's real path so the launch does not depend on the
        # file name or on the current working directory.
        os.path.abspath(__file__),
        f"--server.port={port}",
        "--server.address=0.0.0.0",
    ]
    sys.exit(main())


# Streamlit executes scripts with __name__ == "__main__", so that check alone
# cannot tell "python file.py" apart from "streamlit run file.py". Branch on
# whether a Streamlit runtime already exists to avoid starting a second
# server from inside a running app.
if runtime.exists():
    render_app()
elif __name__ == "__main__":
    _launch_streamlit()