for_api / app.py
Reyall's picture
Update app.py
daf2580 verified
import gradio as gr
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import pickle
from collections import defaultdict
import random
import os
# Model və label_encoder yüklənməsi
def load_model():
try:
# Label encoder
with open("label_encoder.pkl", "rb") as f:
label_encoder = pickle.load(f)
# Tokenizer və Model - iki variant sınayırıq
try:
# Variant 1: best_model qovluğu
tokenizer = BertTokenizer.from_pretrained("best_model")
model = BertForSequenceClassification.from_pretrained("best_model")
except:
# Variant 2: Ana qovluq
tokenizer = BertTokenizer.from_pretrained(".")
model = BertForSequenceClassification.from_pretrained(".")
model.eval()
print(f"✅ Model yükləndi. Label sayı: {len(label_encoder.classes_)}")
return tokenizer, model, label_encoder
except Exception as e:
print(f"❌ Model yüklənmə xətası: {e}")
print(f"📁 Mövcud fayllar: {os.listdir('.')}")
return None, None, None
# Model yükləmə
tokenizer, model, label_encoder = load_model()
# Prediction funksiyası
def predict_disease(text):
if tokenizer is None or model is None or label_encoder is None:
return "❌ Model yüklənməyib!"
if not text.strip():
return "⚠️ Simptomları daxil edin!"
symptoms = [s.strip() for s in text.split(",") if s.strip()]
if not symptoms:
return "⚠️ Düzgün simptomlar yazın (vergüllə ayırın)!"
try:
agg_probs = defaultdict(float)
n_shuffles = 5 # Sürəti artırmaq üçün azaltdım
for i in range(n_shuffles):
random.shuffle(symptoms)
shuffled_text = ", ".join(symptoms)
inputs = tokenizer(
shuffled_text,
return_tensors="pt",
truncation=True,
padding=True,
max_length=128
)
with torch.no_grad():
outputs = model(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze()
for idx, p in enumerate(probs):
agg_probs[idx] += p.item()
# Ortalama hesabla
for k in agg_probs:
agg_probs[k] /= n_shuffles
# Top 3 nəticə
top_3 = sorted(agg_probs.items(), key=lambda x: x[1], reverse=True)[:3]
result = "🏥 **Mümkün xəstəliklər:**\n\n"
for idx, prob in top_3:
label = label_encoder.classes_[idx]
result += f"🔸 **{label}** — %{prob*100:.1f}\n"
return result
except Exception as e:
return f"❌ Xəta: {str(e)}"
# Gradio interface
demo = gr.Interface(
fn=predict_disease,
inputs=gr.Textbox(
lines=3,
placeholder="Məsələn: fever, cough, headache",
label="🩺 Simptomlarınızı yazın (vergüllə ayırın)"
),
outputs=gr.Textbox(
label="📋 Nəticələr",
lines=10
),
title="🏥 Xəstəlik Təyin Edici AI",
description="Simptomlarınızı yazın və ən mümkün xəstəlikləri görün. ⚠️ Bu sadəcə kömək vasitəsidir, həkim məsləhəti əvəz etmir!",
examples=[
["fever, cough, headache"],
["stomach pain, nausea"],
["chest pain, shortness of breath"],
["dizziness, fatigue"],
["skin rash, itching"]
],
theme=gr.themes.Soft(),
allow_flagging="never"
)
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
show_api=True # API dokumentasiyası göstərir
)