import os
import json
import pickle
import random
from collections import defaultdict

import streamlit as st
import torch
from transformers.models.bert import BertTokenizer, BertForSequenceClassification


# Name-encoder loading helper
def load_name_encoder():
    """Load the fitted label encoder from best_model/name_encoder.pkl.

    Halts the Streamlit app with an error message if the artifact is missing.
    Returns the unpickled encoder (exposes ``classes_``).
    """
    file_path = os.path.join(os.getcwd(), "best_model", "name_encoder.pkl")
    if not os.path.exists(file_path):
        st.error(f"Name encoder faylı tapılmadı: {file_path}")
        st.stop()
    # NOTE(review): unpickling is safe only because this is a bundled
    # artifact — never pickle.load untrusted input.
    with open(file_path, "rb") as f:
        name_encoder = pickle.load(f)
    return name_encoder


# Model and tokenizer loading (cached across reruns)
@st.cache_resource
def load_model():
    """Load tokenizer, fine-tuned BERT classifier and label encoder.

    Cached with st.cache_resource so the heavy model load happens once per
    server process, not on every Streamlit rerun.
    """
    name_encoder = load_name_encoder()
    model_path = os.path.join(os.getcwd(), "best_model")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForSequenceClassification.from_pretrained(
        model_path,
        num_labels=len(name_encoder.classes_),
    )
    model.eval()  # inference mode: disable dropout etc.
    return tokenizer, model, name_encoder


# Prediction helper
def predict_disease(symptoms_text, tokenizer, model, name_encoder):
    """Predict the top-3 disease labels for a comma-separated symptom string.

    The symptom list is shuffled ``n_shuffles`` times and the softmax
    probabilities are averaged, so the prediction is less sensitive to the
    order in which the user typed the symptoms.

    Returns a list of up to 3 dicts: {"disease": str, "probability": float},
    sorted by probability descending.
    """
    symptoms = [s.strip() for s in symptoms_text.split(",") if s.strip()]
    agg_probs = defaultdict(float)
    n_shuffles = 10
    for _ in range(n_shuffles):
        random.shuffle(symptoms)
        shuffled_text = ", ".join(symptoms)
        inputs = tokenizer(
            shuffled_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
        )
        with torch.no_grad():
            outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze()
        for i, p in enumerate(probs):
            agg_probs[i] += p.item()
    # Average the accumulated probabilities over all shuffles.
    for k in agg_probs:
        agg_probs[k] /= n_shuffles
    top_3 = sorted(agg_probs.items(), key=lambda x: x[1], reverse=True)[:3]
    return [
        {"disease": name_encoder.classes_[idx], "probability": prob}
        for idx, prob in top_3
    ]


# Page config
st.set_page_config(page_title="Disease API", layout="wide")

# Query parameters.
# BUGFIX: st.query_params values are plain strings (unlike the old
# experimental_get_query_params, which returned lists). Indexing with [0]
# took only the first character ("true"[0] == "t"), so API mode could never
# activate. Read the value directly instead.
query_params = st.query_params
is_api_mode = str(query_params.get("api", "false")).lower() == "true"

# Load model
tokenizer, model, name_encoder = load_model()

# API mode
if is_api_mode:
    # BUGFIX: st.query_params values are plain strings, so the old
    # ``get("symptoms", [""])[0]`` pattern truncated the input to its first
    # character. Read the whole string directly.
    symptoms = query_params.get("symptoms", "")
    if symptoms:
        results = predict_disease(symptoms, tokenizer, model, name_encoder)
        api_response = {
            "status": "success",
            "input": symptoms,
            "predictions": results,
        }
    else:
        api_response = {
            "status": "error",
            "message": "symptoms parameter required",
        }
    # Render the JSON response as a fenced code block and stop rendering
    # the rest of the page.
    st.markdown(
        f"```json\n{json.dumps(api_response, ensure_ascii=False, indent=2)}\n```"
    )
    st.stop()

# Web interface
st.title("🏥 Disease Prediction")
st.success("Model yükləndi!")

# Debug: show the label classes the model can predict
st.write("Available classes:", list(name_encoder.classes_))

# API usage info
st.markdown("### API İstifadəsi")
space_url = "https://your-username-your-space-name.hf.space"
api_example = f"{space_url}/?api=true&symptoms=fever,cough,headache"
st.code(api_example)

# Prediction form
with st.form(key="predict_form"):
    text = st.text_area("Simptomları daxil edin (vergüllə ayırın):")
    submit_button = st.form_submit_button(label="Predict")

if submit_button:
    if not text.strip():
        st.warning("Simptomları daxil edin!")
    else:
        results = predict_disease(text, tokenizer, model, name_encoder)
        st.subheader("🔍 Nəticələr:")
        for result in results:
            st.write(f"**{result['disease']}** — {result['probability']*100:.2f}%")