File size: 3,739 Bytes
77e49dc a8e9ef6 f34c52b 77e49dc 072ed30 a8e9ef6 eaeb0e4 33715c8 eaeb0e4 33715c8 eaeb0e4 77e49dc 072ed30 77e49dc eaeb0e4 33715c8 4821ed3 33715c8 633be3e eaeb0e4 633be3e 77e49dc eaeb0e4 77e49dc 072ed30 eaeb0e4 072ed30 633be3e 072ed30 633be3e 072ed30 633be3e 072ed30 633be3e 072ed30 633be3e 072ed30 eaeb0e4 072ed30 633be3e 072ed30 633be3e 6e20e5c 633be3e 072ed30 633be3e eaeb0e4 77e49dc 633be3e 072ed30 eaeb0e4 633be3e 072ed30 633be3e 072ed30 633be3e 072ed30 633be3e 00bd8a4 633be3e eaeb0e4 633be3e eaeb0e4 0b4c6e1 633be3e 0b4c6e1 633be3e eaeb0e4 633be3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import os
import streamlit as st
from transformers.models.bert import BertTokenizer, BertForSequenceClassification
import torch
import pickle
import random
from collections import defaultdict
import json
# Name encoder loading function
def load_name_encoder():
    """Load the pickled label encoder that maps class indices to disease names.

    The file is looked up at ``best_model/name_encoder.pkl`` relative to the
    current working directory (not relative to this script). When it is
    missing, an error is rendered on the page and ``st.stop()`` is called.

    Returns:
        The unpickled object; the rest of this app reads its ``classes_``.
    """
    encoder_path = os.path.join(os.getcwd(), "best_model", "name_encoder.pkl")
    missing = not os.path.exists(encoder_path)
    if missing:
        st.error(f"Name encoder faylı tapılmadı: {encoder_path}")
        st.stop()
    # NOTE(review): pickle.load can execute arbitrary code; only load a file
    # you trust (presumably a locally produced training artifact).
    with open(encoder_path, "rb") as handle:
        return pickle.load(handle)
# Model and tokenizer loading
@st.cache_resource
def load_model():
    """Build the tokenizer, the fine-tuned classifier and the label encoder.

    Decorated with ``st.cache_resource`` so Streamlit keeps the returned
    objects cached instead of reloading the weights on every rerun.

    Returns:
        tuple: ``(tokenizer, model, name_encoder)``. The classifier is sized
        with one output per entry of ``name_encoder.classes_``.
    """
    encoder = load_name_encoder()
    weights_dir = os.path.join(os.getcwd(), "best_model")
    # The tokenizer is taken from the public base checkpoint, not weights_dir.
    # NOTE(review): assumes the fine-tuned model kept the bert-base-uncased
    # vocabulary — confirm.
    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    classifier = BertForSequenceClassification.from_pretrained(
        weights_dir,
        num_labels=len(encoder.classes_),
    )
    # Evaluation mode (e.g. disables dropout) for deterministic inference.
    classifier.eval()
    return bert_tokenizer, classifier, encoder
# Prediction function
def predict_disease(symptoms_text, tokenizer, model, name_encoder,
                    n_shuffles=10, top_k=3, seed=None):
    """Rank the most likely diseases for a comma-separated symptom list.

    The symptoms are shuffled ``n_shuffles`` times and every ordering is
    classified; averaging the softmax probabilities over the orderings makes
    the ranking less sensitive to the order in which symptoms were typed.

    Args:
        symptoms_text: Symptoms separated by commas, e.g. ``"fever, cough"``.
            Surrounding whitespace and blank entries are ignored.
        tokenizer: Tokenizer callable accepting a list of strings and
            returning model keyword arguments (as built by ``load_model``).
        model: Sequence-classification model whose output has ``.logits``.
        name_encoder: Object whose ``classes_`` maps a class index to a label.
        n_shuffles: Number of random symptom orderings to average over.
        top_k: Maximum number of predictions to return.
        seed: Optional seed for the shuffles; pass one for reproducible
            output, leave ``None`` for fresh randomness on every call.

    Returns:
        A list of at most ``top_k`` dicts ``{"disease": label,
        "probability": float}``, most probable first. Empty when
        ``symptoms_text`` contains no symptoms.

    Raises:
        ValueError: If ``n_shuffles`` is less than 1.
    """
    if n_shuffles < 1:
        raise ValueError("n_shuffles must be at least 1")

    symptoms = [s.strip() for s in symptoms_text.split(",") if s.strip()]
    if not symptoms:
        # Nothing to classify: don't run the model on an empty string.
        return []

    # A private RNG so a seed makes results reproducible without touching
    # the global random state.
    rng = random.Random(seed)
    orderings = []
    for _ in range(n_shuffles):
        rng.shuffle(symptoms)
        orderings.append(", ".join(symptoms))

    # All orderings go through the model as one batch instead of
    # n_shuffles separate forward passes.
    inputs = tokenizer(
        orderings,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    # Assumes logits are (batch, num_labels) as for HF sequence classifiers;
    # averaging over the batch axis gives one mean probability per class.
    # (Unlike .squeeze(), this stays 1-D even when there is a single label.)
    avg_probs = torch.nn.functional.softmax(logits, dim=-1).mean(dim=0)

    k = min(top_k, avg_probs.numel())
    top_probs, top_idx = torch.topk(avg_probs, k)
    return [
        {"disease": name_encoder.classes_[idx], "probability": prob}
        for prob, idx in zip(top_probs.tolist(), top_idx.tolist())
    ]
# Page config
st.set_page_config(page_title="Disease API", layout="wide")

# Query parameters.
# st.query_params is a dict-like mapping whose values are plain strings (the
# last value when a key is repeated), so .get() returns a str, not a list.
# Indexing that result with [0] would keep only its first character, which
# made "?api=true" read as "t" and silently disabled API mode.
query_params = st.query_params
is_api_mode = query_params.get("api", "false").lower() == "true"

# Load the model (cached by st.cache_resource inside load_model)
tokenizer, model, name_encoder = load_model()

# API mode: render the prediction as a JSON code block and stop, so the web
# UI below is not drawn. NOTE(review): Streamlit still serves an app page
# here, so clients must extract the JSON from the rendered markdown.
if is_api_mode:
    symptoms = query_params.get("symptoms", "")
    # Require at least one non-blank comma-separated item, matching how
    # predict_disease splits its input.
    if any(part.strip() for part in symptoms.split(",")):
        results = predict_disease(symptoms, tokenizer, model, name_encoder)
        api_response = {
            "status": "success",
            "input": symptoms,
            "predictions": results,
        }
    else:
        api_response = {
            "status": "error",
            "message": "symptoms parameter required",
        }
    st.markdown(
        f"```json\n{json.dumps(api_response, ensure_ascii=False, indent=2)}\n```"
    )
    st.stop()
# Web interface
st.title("🏥 Disease Prediction")
st.success("Model yükləndi!")

# Debug: list every class label known to the encoder.
st.write("Available classes:", list(name_encoder.classes_))

# API usage info
st.markdown("### API İstifadəsi")
# Placeholder URL: replace with the actual Hugging Face Space address.
space_url = "https://your-username-your-space-name.hf.space"
api_example = f"{space_url}/?api=true&symptoms=fever,cough,headache"
st.code(api_example)

# Prediction form
with st.form(key="predict_form"):
    text = st.text_area("Simptomları daxil edin (vergüllə ayırın):")
    submit_button = st.form_submit_button(label="Predict")

if submit_button:
    # Validate the same way predict_disease parses input: text such as ", ,"
    # is non-blank yet contains no symptoms.
    if not any(part.strip() for part in text.split(",")):
        st.warning("Simptomları daxil edin!")
    else:
        results = predict_disease(text, tokenizer, model, name_encoder)
        st.subheader("🔍 Nəticələr:")
        for result in results:
            st.write(f"**{result['disease']}** — {result['probability']*100:.2f}%")
|