File size: 3,739 Bytes
77e49dc
a8e9ef6
f34c52b
77e49dc
 
 
 
072ed30
a8e9ef6
eaeb0e4
 
 
33715c8
eaeb0e4
33715c8
 
eaeb0e4
 
77e49dc
072ed30
77e49dc
 
eaeb0e4
33715c8
4821ed3
33715c8
633be3e
 
eaeb0e4
633be3e
77e49dc
eaeb0e4
77e49dc
072ed30
eaeb0e4
072ed30
 
 
633be3e
072ed30
 
 
633be3e
 
 
 
 
 
 
072ed30
 
 
633be3e
072ed30
 
633be3e
072ed30
 
633be3e
072ed30
 
 
eaeb0e4
072ed30
633be3e
072ed30
 
 
 
 
633be3e
6e20e5c
633be3e
072ed30
633be3e
eaeb0e4
77e49dc
633be3e
072ed30
 
 
eaeb0e4
633be3e
072ed30
 
 
633be3e
072ed30
633be3e
072ed30
 
633be3e
 
00bd8a4
 
 
633be3e
 
 
 
 
 
 
eaeb0e4
633be3e
 
 
 
 
 
 
eaeb0e4
0b4c6e1
 
 
633be3e
0b4c6e1
633be3e
 
 
eaeb0e4
633be3e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import streamlit as st
from transformers.models.bert import BertTokenizer, BertForSequenceClassification
import torch
import pickle
import random
from collections import defaultdict
import json

# Loader for the fitted label encoder that maps class indices to disease names
def load_name_encoder():
    """Load the pickled name encoder from best_model/; halt the app if the file is missing."""
    file_path = os.path.join(os.getcwd(), "best_model", "name_encoder.pkl")
    if not os.path.exists(file_path):
        st.error(f"Name encoder faylı tapılmadı: {file_path}")
        st.stop()
    # NOTE(review): pickle.load executes arbitrary code from the file —
    # only load artifacts bundled with this app, never user uploads.
    with open(file_path, "rb") as f:
        return pickle.load(f)

# One-time (cached) loading of tokenizer, classifier and label encoder
@st.cache_resource
def load_model():
    """Load the BERT tokenizer, the fine-tuned classifier and the name encoder.

    Cached by Streamlit so the heavy model load happens once per session.
    """
    encoder = load_name_encoder()
    model_dir = os.path.join(os.getcwd(), "best_model")

    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    classifier = BertForSequenceClassification.from_pretrained(
        model_dir,
        num_labels=len(encoder.classes_)
    )
    classifier.eval()  # inference mode: disables dropout etc.
    return bert_tokenizer, classifier, encoder

# Prediction funksiyası
def predict_disease(symptoms_text, tokenizer, model, name_encoder):
    """Predict the top-3 most likely diseases for a comma-separated symptom string.

    The symptom list is shuffled several times and class probabilities are
    averaged across the shuffles, so the prediction does not depend on the
    order in which the user typed the symptoms.

    Args:
        symptoms_text: comma-separated symptoms, e.g. "fever, cough".
        tokenizer: HuggingFace-style tokenizer callable returning model inputs.
        model: sequence-classification model; its output exposes ``.logits``.
        name_encoder: label encoder whose ``.classes_`` maps index -> disease name.

    Returns:
        A list of up to 3 dicts ``{"disease": str, "probability": float}``
        sorted by probability descending; an empty list when no symptoms
        were supplied.
    """
    symptoms = [s.strip() for s in symptoms_text.split(",") if s.strip()]
    if not symptoms:
        # Nothing to predict from — don't run the model on an empty string.
        return []

    agg_probs = defaultdict(float)
    n_shuffles = 10

    for _ in range(n_shuffles):
        # In-place shuffle of the local list makes the result order-invariant.
        random.shuffle(symptoms)
        shuffled_text = ", ".join(symptoms)
        inputs = tokenizer(
            shuffled_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128
        )
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze()

        # Accumulate per-class probability mass; tolist() converts to plain
        # floats in one call instead of a per-element .item().
        for i, p in enumerate(probs.tolist()):
            agg_probs[i] += p

    # Average over the shuffles.
    for k in agg_probs:
        agg_probs[k] /= n_shuffles

    # Keep the 3 classes with the highest averaged probability.
    top_3 = sorted(agg_probs.items(), key=lambda x: x[1], reverse=True)[:3]
    return [
        {"disease": name_encoder.classes_[idx], "probability": prob}
        for idx, prob in top_3
    ]

# Page config
st.set_page_config(page_title="Disease API", layout="wide")

# Query parameters.
# BUG FIX: st.query_params (the current API) returns plain strings, while the
# old st.experimental_get_query_params returned lists. The original
# `query_params.get("api", ["false"])[0]` therefore indexed the FIRST
# CHARACTER of the string ("true"[0] == "t"), so API mode could never
# activate and "symptoms" was truncated to one character. Handle both
# shapes defensively.
query_params = st.query_params


def _get_query_param(name, default=""):
    """Return query parameter `name` as a plain string (handles list-valued params)."""
    value = query_params.get(name, default)
    if isinstance(value, list):
        value = value[0] if value else default
    return str(value)


is_api_mode = _get_query_param("api", "false").lower() == "true"

# Model yüklə (cached by @st.cache_resource)
tokenizer, model, name_encoder = load_model()

# API mode: render the prediction as a JSON code block and stop the script
if is_api_mode:
    symptoms = _get_query_param("symptoms", "")
    if symptoms:
        results = predict_disease(symptoms, tokenizer, model, name_encoder)
        api_response = {
            "status": "success",
            "input": symptoms,
            "predictions": results
        }
    else:
        api_response = {
            "status": "error",
            "message": "symptoms parameter required"
        }

    st.markdown(
        f"```json\n{json.dumps(api_response, ensure_ascii=False, indent=2)}\n```"
    )
    st.stop()

# --- Web interface (non-API mode) ---
st.title("🏥 Disease Prediction")
st.success("Model yükləndi!")

# Debug: list the disease labels known to the encoder.
st.write("Available classes:", list(name_encoder.classes_))

# Document the query-parameter API for consumers of this Space.
st.markdown("### API İstifadəsi")
base_url = "https://your-username-your-space-name.hf.space"
st.code(f"{base_url}/?api=true&symptoms=fever,cough,headache")

# Symptom input form
with st.form(key="predict_form"):
    text = st.text_area("Simptomları daxil edin (vergüllə ayırın):")
    submit_button = st.form_submit_button(label="Predict")

if submit_button:
    # Warn on blank input; otherwise predict and render the top results.
    if not text.strip():
        st.warning("Simptomları daxil edin!")
    else:
        predictions = predict_disease(text, tokenizer, model, name_encoder)
        st.subheader("🔍 Nəticələr:")
        for item in predictions:
            pct = item["probability"] * 100
            st.write(f"**{item['disease']}** — {pct:.2f}%")