|
|
import gradio as gr |
|
|
from transformers import BertTokenizer, BertForSequenceClassification |
|
|
import torch |
|
|
import pickle |
|
|
from collections import defaultdict |
|
|
import random |
|
|
import os |
|
|
|
|
|
|
|
|
def load_model(model_dir="best_model", fallback_dir=".", encoder_path="label_encoder.pkl"):
    """Load the label encoder, tokenizer and classifier used by the app.

    The tokenizer and model are loaded from ``model_dir``. If that fails for
    any reason, they are loaded from ``fallback_dir`` instead.

    Args:
        model_dir: Preferred directory passed to ``from_pretrained``.
        fallback_dir: Directory tried when loading from ``model_dir`` fails.
        encoder_path: Path of the pickled label encoder. Only its
            ``classes_`` attribute is used (presumably a scikit-learn
            ``LabelEncoder`` -- confirm against the training script).

    Returns:
        ``(tokenizer, model, label_encoder)`` with the model in eval mode.
        On any failure it returns ``(None, None, None)`` after printing the
        error and the current directory listing, so that the UI can still
        start and report the problem.
    """
    try:
        # SECURITY: pickle.load can execute arbitrary code; only ever point
        # this at a trusted file shipped with the app.
        with open(encoder_path, "rb") as f:
            label_encoder = pickle.load(f)

        try:
            tokenizer = BertTokenizer.from_pretrained(model_dir)
            model = BertForSequenceClassification.from_pretrained(model_dir)
        except Exception:
            # Previously a bare `except:`. Narrowed so that Ctrl-C/SystemExit
            # still propagate, while any load error still triggers the fallback.
            tokenizer = BertTokenizer.from_pretrained(fallback_dir)
            model = BertForSequenceClassification.from_pretrained(fallback_dir)

        # Inference only: disable dropout and similar training behaviour.
        model.eval()

        print(f"✅ Model yükləndi. Label sayı: {len(label_encoder.classes_)}")

        return tokenizer, model, label_encoder

    except Exception as e:
        print(f"❌ Model yüklənmə xətası: {e}")

        # Listing the working directory helps diagnose missing files on deploy.
        print(f"📁 Mövcud fayllar: {os.listdir('.')}")

        return None, None, None
|
|
|
|
|
|
|
|
# Load the artefacts once at import time so every request reuses them.
# load_model() yields (None, None, None) on failure; predict_disease checks
# for that and reports it to the user instead of crashing.
(tokenizer, model, label_encoder) = load_model()
|
|
|
|
|
|
|
|
def predict_disease(text):
    """Return a Markdown-formatted ranking of the top-3 predicted diseases.

    The input is split on commas into symptoms. The symptoms are scored under
    several orderings and the softmax probabilities are averaged, which damps
    any sensitivity of the model to the order of the symptoms.

    Args:
        text: Comma-separated symptoms, e.g. "fever, cough, headache".
            ``None`` (which API callers may send) is treated as empty.

    Returns:
        A user-facing string. This is either the ranked predictions or a
        warning/error message. Exceptions are converted to text rather than
        raised to the UI.
    """
    if tokenizer is None or model is None or label_encoder is None:
        return "❌ Model yüklənməyib!"

    # Guard against None as well as blank/whitespace-only input.
    if not text or not text.strip():
        return "⚠️ Simptomları daxil edin!"

    symptoms = [s.strip() for s in text.split(",") if s.strip()]
    if not symptoms:
        return "⚠️ Düzgün simptomlar yazın (vergüllə ayırın)!"

    n_shuffles = 5
    top_k = 3

    try:
        # Start from the sorted symptoms and seed a private RNG from them.
        # The same set of symptoms therefore always produces the same
        # orderings, so the percentages are reproducible and independent of
        # the order the user typed them in. The global random state is not
        # touched.
        base = sorted(symptoms)
        rng = random.Random(", ".join(base))
        orderings = []
        for _ in range(n_shuffles):
            perm = base[:]
            rng.shuffle(perm)
            orderings.append(", ".join(perm))

        # Score all orderings in one batched forward pass instead of one call
        # per ordering.
        inputs = tokenizer(
            orderings,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
        )

        with torch.no_grad():
            logits = model(**inputs).logits

        # Average over the batch dimension. Unlike .squeeze(), this always
        # leaves a 1-D vector, even for a single-label model.
        mean_probs = torch.nn.functional.softmax(logits, dim=-1).mean(dim=0).tolist()

        ranked = sorted(range(len(mean_probs)), key=mean_probs.__getitem__, reverse=True)[:top_k]

        lines = [
            f"🔸 **{label_encoder.classes_[idx]}** — %{mean_probs[idx]*100:.1f}\n"
            for idx in ranked
        ]
        return "🏥 **Mümkün xəstəliklər:**\n\n" + "".join(lines)

    except Exception as e:
        # UI boundary: report the failure to the user instead of raising.
        return f"❌ Xəta: {str(e)}"
|
|
|
|
|
|
|
|
# Gradio UI: one free-text symptom box in, the formatted prediction text out.
# Keyword order is kept so components are constructed in the same sequence.
demo = gr.Interface(
    fn=predict_disease,
    inputs=gr.Textbox(
        lines=3,
        placeholder="Məsələn: fever, cough, headache",
        label="🩺 Simptomlarınızı yazın (vergüllə ayırın)",
    ),
    outputs=gr.Textbox(label="📋 Nəticələr", lines=10),
    title="🏥 Xəstəlik Təyin Edici AI",
    description="Simptomlarınızı yazın və ən mümkün xəstəlikləri görün. ⚠️ Bu sadəcə kömək vasitəsidir, həkim məsləhəti əvəz etmir!",
    # The interface has one input, so each example row holds one value.
    examples=[
        [sample]
        for sample in (
            "fever, cough, headache",
            "stomach pain, nausea",
            "chest pain, shortness of breath",
            "dizziness, fatigue",
            "skin rash, itching",
        )
    ],
    theme=gr.themes.Soft(),
    allow_flagging="never",
)
|
|
|
|
|
if __name__ == "__main__":
    # Listen on all network interfaces, on port 7860 (Gradio's default),
    # with the API page enabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_api": True,
    }
    demo.launch(**launch_options)