# for_api/app.py — Disease NLP classifier (Gradio app).
# Source: Hugging Face Space by Reyall, commit c123726 (4.73 kB).
import gradio as gr
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import pickle
from collections import defaultdict
import random
import os
from safetensors.torch import load_file
# Load the model and label encoder from disk.
def load_model():
    """Load the tokenizer, BERT classifier and label encoder from best_model/.

    Returns:
        (tokenizer, model, label_encoder) on success, or
        (None, None, None) when any artifact fails to load.
    """
    try:
        # Fitted label encoder persisted next to the model weights.
        with open("best_model/label_encoder.pkl", "rb") as fh:
            encoder = pickle.load(fh)
        bert_tokenizer = BertTokenizer.from_pretrained("best_model")
        # Weights are stored in safetensors format.
        bert_model = BertForSequenceClassification.from_pretrained(
            "best_model", use_safetensors=True
        )
        bert_model.eval()
        print(f"Model uğurla yükləndi. Label sayı: {len(encoder.classes_)}")
        return bert_tokenizer, bert_model, encoder
    except Exception as e:
        print(f"Model yüklənmə xətası: {e}")
        # Dump the directory listing to help diagnose which file is missing.
        if os.path.exists("best_model"):
            files = os.listdir("best_model")
            print(f"best_model qovluğundakı fayllar: {files}")
        else:
            print("best_model qovluğu mövcud deyil")
        return None, None, None
# Load model artifacts once at import time; each is None if loading failed.
tokenizer, model, label_encoder = load_model()
# Prediction function used as the Gradio callback.
def predict_disease(text):
    """Predict the top-3 diseases for a comma-separated symptom string.

    Averages softmax probabilities over several random orderings of the
    symptoms (to reduce sensitivity to symptom order), then formats the
    three most likely labels with their confidence percentages.

    Args:
        text: User input — symptoms separated by commas, or None/empty.

    Returns:
        A human-readable result string (also used for error messages).
    """
    if tokenizer is None or model is None or label_encoder is None:
        return "❌ Model yüklənməyib! Xəta var."
    # FIX: guard against None as well as whitespace-only input — Gradio can
    # pass None when the textbox is cleared, and None.strip() raised an
    # uncaught AttributeError here (this check sits outside the try block).
    if not text or not text.strip():
        return "⚠️ Please enter some symptoms!"
    symptoms = [s.strip() for s in text.split(",") if s.strip()]
    if not symptoms:
        return "⚠️ Please enter valid symptoms separated by commas!"
    try:
        agg_probs = defaultdict(float)
        n_shuffles = 10
        for _ in range(n_shuffles):
            # Shuffle in place so every pass sees a different symptom order.
            random.shuffle(symptoms)
            shuffled_text = ", ".join(symptoms)
            inputs = tokenizer(
                shuffled_text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=128,
            )
            with torch.no_grad():
                outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze()
            for idx, p in enumerate(probs):
                agg_probs[idx] += p.item()
        # Average the accumulated probabilities over all shuffles.
        for k in agg_probs:
            agg_probs[k] /= n_shuffles
        # Take the three highest-probability class indices.
        top_3 = sorted(agg_probs.items(), key=lambda x: x[1], reverse=True)[:3]
        results = ["🏥 Top 3 Predicted Diseases:\n"]
        for idx, prob in top_3:
            label = label_encoder.classes_[idx]
            results.append(f"• **{label}** — Probability: {prob*100:.2f}%")
        return "\n".join(results)
    except Exception as e:
        return f"❌ Prediction xətası: {str(e)}"
# Gradio UI: one free-text symptom box in, formatted prediction text out.
example_inputs = [
    ["fever, cough, headache"],
    ["stomach pain, nausea, vomiting"],
    ["chest pain, shortness of breath"],
    ["dizziness, fatigue, weakness"],
    ["skin rash, itching, redness"],
]

symptom_box = gr.Textbox(
    lines=2,
    placeholder="fever, cough, headache, shortness of breath",
    label="Enter your symptoms (comma separated)",
)

iface = gr.Interface(
    fn=predict_disease,
    inputs=symptom_box,
    outputs=gr.Textbox(label="Predicted Diseases"),
    title="🏥 Disease NLP Classifier",
    description="Enter your symptoms separated by commas and get top 3 predicted diseases with confidence scores.",
    examples=example_inputs,
)
# Launch: serve the real app when the model loaded, else a debug fallback.
if __name__ == "__main__":
    if tokenizer and model and label_encoder:
        print("✅ Model hazırdır, Gradio başladılır...")
        iface.launch(
            server_name="0.0.0.0",
            server_port=int(os.environ.get("PORT", 7860)),
            share=True,  # create a public share link
        )
    else:
        print("❌ Model yüklənmədi, Gradio başladıla bilmir!")
        print("\nDebug məlumatları:")
        print(f"Hazırkı qovluq: {os.getcwd()}")
        print(f"Qovluq məzmunu: {os.listdir('.')}")

        # Minimal fallback UI so the Space still serves something useful.
        # FIX: the interface below declares one input component, so Gradio
        # calls fn with one positional argument; the original zero-argument
        # debug_info() raised TypeError on every submit. Accept (and ignore)
        # that argument.
        def debug_info(_text=""):
            return f"Debug məlumatları:\nHazırkı qovluq: {os.getcwd()}\nFayllar: {os.listdir('.')}"

        debug_iface = gr.Interface(
            fn=debug_info,
            inputs=gr.Textbox(placeholder="Debug üçün hər hansı mətn yazın"),
            outputs=gr.Textbox(),
            title="🔧 Debug Interface",
        )
        debug_iface.launch(
            server_name="0.0.0.0",
            server_port=int(os.environ.get("PORT", 7860)),
        )