import gradio as gr from transformers import BertTokenizer, BertForSequenceClassification import torch import pickle from collections import defaultdict import random import os # Model və label_encoder yüklənməsi def load_model(): try: # Label encoder with open("label_encoder.pkl", "rb") as f: label_encoder = pickle.load(f) # Tokenizer və Model - iki variant sınayırıq try: # Variant 1: best_model qovluğu tokenizer = BertTokenizer.from_pretrained("best_model") model = BertForSequenceClassification.from_pretrained("best_model") except: # Variant 2: Ana qovluq tokenizer = BertTokenizer.from_pretrained(".") model = BertForSequenceClassification.from_pretrained(".") model.eval() print(f"✅ Model yükləndi. Label sayı: {len(label_encoder.classes_)}") return tokenizer, model, label_encoder except Exception as e: print(f"❌ Model yüklənmə xətası: {e}") print(f"📁 Mövcud fayllar: {os.listdir('.')}") return None, None, None # Model yükləmə tokenizer, model, label_encoder = load_model() # Prediction funksiyası def predict_disease(text): if tokenizer is None or model is None or label_encoder is None: return "❌ Model yüklənməyib!" if not text.strip(): return "⚠️ Simptomları daxil edin!" symptoms = [s.strip() for s in text.split(",") if s.strip()] if not symptoms: return "⚠️ Düzgün simptomlar yazın (vergüllə ayırın)!" try: agg_probs = defaultdict(float) n_shuffles = 5 # Sürəti artırmaq üçün azaltdım for i in range(n_shuffles): random.shuffle(symptoms) shuffled_text = ", ".join(symptoms) inputs = tokenizer( shuffled_text, return_tensors="pt", truncation=True, padding=True, max_length=128 ) with torch.no_grad(): outputs = model(**inputs) probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze() for idx, p in enumerate(probs): agg_probs[idx] += p.item() # Ortalama hesabla for k in agg_probs: agg_probs[k] /= n_shuffles # Top 3 nəticə top_3 = sorted(agg_probs.items(), key=lambda x: x[1], reverse=True)[:3] result = "🏥 **Mümkün xəstəliklər:**\n\n" for idx, prob in top_3: label = label_encoder.classes_[idx] result += f"🔸 **{label}** — %{prob*100:.1f}\n" return result except Exception as e: return f"❌ Xəta: {str(e)}" # Gradio interface demo = gr.Interface( fn=predict_disease, inputs=gr.Textbox( lines=3, placeholder="Məsələn: fever, cough, headache", label="🩺 Simptomlarınızı yazın (vergüllə ayırın)" ), outputs=gr.Textbox( label="📋 Nəticələr", lines=10 ), title="🏥 Xəstəlik Təyin Edici AI", description="Simptomlarınızı yazın və ən mümkün xəstəlikləri görün. ⚠️ Bu sadəcə kömək vasitəsidir, həkim məsləhəti əvəz etmir!", examples=[ ["fever, cough, headache"], ["stomach pain, nausea"], ["chest pain, shortness of breath"], ["dizziness, fatigue"], ["skin rash, itching"] ], theme=gr.themes.Soft(), allow_flagging="never" ) if __name__ == "__main__": demo.launch( server_name="0.0.0.0", server_port=7860, show_api=True # API dokumentasiyası göstərir )