| | import gradio as gr |
| | from transformers import BertTokenizer, BertForSequenceClassification |
| | import torch |
| | import pickle |
| | from collections import defaultdict |
| | import random |
| | import os |
| | from safetensors.torch import load_file |
| |
|
| | |
def load_model(model_dir="best_model"):
    """Load the label encoder, tokenizer and classifier saved in *model_dir*.

    Args:
        model_dir: Directory containing ``label_encoder.pkl`` plus the
            tokenizer and model files readable by ``from_pretrained``.
            Defaults to ``"best_model"`` relative to the working directory.

    Returns:
        ``(tokenizer, model, label_encoder)`` on success, with the model put in
        eval mode. On any failure the error is printed, the directory listing
        is printed when *model_dir* is a directory, and ``(None, None, None)``
        is returned instead of raising.
    """
    try:
        # NOTE: pickle.load can execute arbitrary code stored in the file; only
        # point model_dir at artefacts you produced yourself / trust.
        with open(os.path.join(model_dir, "label_encoder.pkl"), "rb") as f:
            label_encoder = pickle.load(f)

        tokenizer = BertTokenizer.from_pretrained(model_dir)

        # Read weights from the .safetensors file rather than a pickled checkpoint.
        model = BertForSequenceClassification.from_pretrained(model_dir, use_safetensors=True)
        model.eval()  # inference mode (e.g. disables dropout)

        print(f"Model uğurla yükləndi. Label sayı: {len(label_encoder.classes_)}")
        return tokenizer, model, label_encoder

    except Exception as e:
        # Best-effort boundary: report and help diagnose missing files rather
        # than crash at import time; callers check for the None triple.
        print(f"Model yüklənmə xətası: {e}")

        # isdir, not exists: os.listdir raises if model_dir is a plain file,
        # and that exception would escape this handler.
        if os.path.isdir(model_dir):
            files = os.listdir(model_dir)
            print(f"{model_dir} qovluğundakı fayllar: {files}")
        else:
            print(f"{model_dir} qovluğu mövcud deyil")

        return None, None, None
| |
|
| | |
# Load the model artefacts once, at import time, so every request reuses the
# same objects. load_model() returns (None, None, None) on failure, so all
# three names are always bound; code using them must check for None first.
(tokenizer, model, label_encoder) = load_model()
| |
|
| | |
def _symptom_orderings(symptoms, n_samples):
    """Return the orderings of *symptoms* to score for one prediction.

    When there are at most ``n_samples`` distinct orderings
    (``len(symptoms)!``), every permutation is returned exactly once, so the
    averaged prediction is exact and identical across calls. Otherwise
    ``n_samples`` independent random shuffles are returned. *symptoms* itself
    is never modified.
    """
    import itertools
    import math

    if math.factorial(len(symptoms)) <= n_samples:
        return [list(order) for order in itertools.permutations(symptoms)]

    orderings = []
    for _ in range(n_samples):
        order = list(symptoms)
        random.shuffle(order)
        orderings.append(order)
    return orderings


def predict_disease(text):
    """Predict the most likely diseases for a comma-separated symptom list.

    The classifier sees the symptoms as one ordered string, so to reduce
    sensitivity to typing order the list is scored under several orderings
    (all of them when there are at most 10, otherwise 10 random shuffles) and
    the softmax probabilities are averaged.

    Args:
        text: Raw user input such as ``"fever, cough, headache"``; ``None`` is
            treated like an empty string.

    Returns:
        A human-readable string with the top 3 predictions and percentages, or
        a warning/error message. Exceptions raised during inference are caught
        and reported in the returned string, not propagated.
    """
    if tokenizer is None or model is None or label_encoder is None:
        return "❌ Model yüklənməyib! Xəta var."

    if not text or not text.strip():
        return "⚠️ Please enter some symptoms!"

    symptoms = [s.strip() for s in text.split(",") if s.strip()]
    if not symptoms:
        return "⚠️ Please enter valid symptoms separated by commas!"

    n_shuffles = 10
    top_k = 3

    try:
        texts = [", ".join(order) for order in _symptom_orderings(symptoms, n_shuffles)]

        # Score every ordering in one batched forward pass. Orderings of the
        # same symptoms tokenize to the same length, so padding is a no-op.
        inputs = tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
        )
        with torch.no_grad():
            logits = model(**inputs).logits  # expected (num_orderings, num_labels)
        mean_probs = torch.softmax(logits, dim=-1).mean(dim=0)

        # Clamp k so a model with fewer than top_k labels does not make topk raise.
        top_probs, top_indices = torch.topk(mean_probs, k=min(top_k, mean_probs.numel()))

        results = [f"🏥 Top {top_k} Predicted Diseases:\n"]
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
            label = label_encoder.classes_[idx]
            # Plain text: the result is displayed in a gr.Textbox, which would
            # show Markdown markers such as ** literally.
            results.append(f"• {label} — Probability: {prob*100:.2f}%")

        return "\n".join(results)

    except Exception as e:
        return f"❌ Prediction xətası: {e}"
| |
|
| | |
# Gradio UI: a two-line symptom text box feeds predict_disease, whose text
# result is shown in a read-only style box; clicking an example fills the input.
iface = gr.Interface(
    predict_disease,
    gr.Textbox(
        label="Enter your symptoms (comma separated)",
        placeholder="fever, cough, headache, shortness of breath",
        lines=2,
    ),
    gr.Textbox(label="Predicted Diseases"),
    title="🏥 Disease NLP Classifier",
    description=(
        "Enter your symptoms separated by commas and get top 3 predicted "
        "diseases with confidence scores."
    ),
    examples=[
        [sample]
        for sample in (
            "fever, cough, headache",
            "stomach pain, nausea, vomiting",
            "chest pain, shortness of breath",
            "dizziness, fatigue, weakness",
            "skin rash, itching, redness",
        )
    ],
)
| |
|
| | |
def debug_info(_text=None):
    """Return the current working directory and its file listing as text.

    Backs the fallback debug UI. Gradio calls the function with the value of
    the interface's single Textbox input, so one positional argument is
    accepted (and ignored) to avoid a TypeError on every submit.
    """
    return f"Debug məlumatları:\nHazırkı qovluq: {os.getcwd()}\nFayllar: {os.listdir('.')}"


if __name__ == "__main__":
    # Explicit None checks — load_model() signals failure with a None triple —
    # rather than relying on the truthiness of tokenizer/model objects.
    if all(part is not None for part in (tokenizer, model, label_encoder)):
        print("✅ Model hazırdır, Gradio başladılır...")
        iface.launch(
            server_name="0.0.0.0",
            server_port=int(os.environ.get("PORT", 7860)),
            share=True,
        )
    else:
        print("❌ Model yüklənmədi, Gradio başladıla bilmir!")
        print("\nDebug məlumatları:")
        print(f"Hazırkı qovluq: {os.getcwd()}")
        print(f"Qovluq məzmunu: {os.listdir('.')}")

        # Fallback: keep a server running that reports the working directory,
        # to help diagnose why the model files were not found.
        # NOTE(review): this exposes the server's file listing to anyone who
        # can reach the port — consider removing once deployment is stable.
        debug_iface = gr.Interface(
            fn=debug_info,
            inputs=gr.Textbox(placeholder="Debug üçün hər hansı mətn yazın"),
            outputs=gr.Textbox(),
            title="🔧 Debug Interface",
        )
        debug_iface.launch(
            server_name="0.0.0.0",
            server_port=int(os.environ.get("PORT", 7860)),
        )