import os
import sys
import streamlit as st
from transformers import BertTokenizer, BertForSequenceClassification  # Burada BertTokenizer istifadə edirik
import torch
import pickle
import random
from collections import defaultdict
import requests

# Download label_encoder.pkl from GitHub into the working directory.
def download_label_encoder():
    """Fetch label_encoder.pkl from the project's GitHub repo.

    Uses a GITHUB_TOKEN from the environment when one is set; the raw URL
    of a public repository also works unauthenticated.

    Raises:
        Exception: if the download fails (non-200 status code).
    """
    url = "https://github.com/AxundovReyal/nlp-disease/raw/main/label_encoder.pkl"
    token = os.getenv("GITHUB_TOKEN")
    # Only send the Authorization header when a token is actually set.
    # The original always sent it, producing the literal value "token None"
    # when the env var was missing, which can make the request fail.
    headers = {"Authorization": f"token {token}"} if token else {}
    # A timeout keeps the Streamlit app from hanging forever on a stalled request.
    response = requests.get(url, headers=headers, timeout=30)

    if response.status_code == 200:
        with open("label_encoder.pkl", "wb") as f:
            f.write(response.content)
        print("label_encoder.pkl faylı uğurla yükləndi.")
    else:
        raise Exception(f"Fayl yüklənə bilmədi, error kodu: {response.status_code}")

# Load the model and the label encoder (cached once per Streamlit session).
@st.cache_resource
def load_model():
    """Download and load the tokenizer, model, and label encoder.

    Returns:
        tuple: (tokenizer, model, label_encoder) — the BERT tokenizer, the
        sequence-classification model in eval mode, and the unpickled
        sklearn-style label encoder.
    """
    # Fetch label_encoder.pkl from GitHub first.
    download_label_encoder()

    # Load the label encoder up front; its class count sizes the model head below.
    with open("label_encoder.pkl", "rb") as f:
        label_encoder = pickle.load(f)
    
    # BertTokenizer is used here instead of AutoTokenizer.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')  # BERT Tokenizer
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(label_encoder.classes_))  # BERT Model
    # NOTE(review): this loads the *base* pretrained BERT with a freshly
    # initialized classification head — no fine-tuned checkpoint is loaded,
    # so predictions are effectively untrained. Confirm whether a trained
    # model checkpoint should be downloaded and loaded here as well.

    model.eval()

    return tokenizer, model, label_encoder

# Build the heavy resources at import time. Streamlit reruns this script on
# every user interaction; @st.cache_resource makes this call cheap after the
# first run.
tokenizer, model, label_encoder = load_model()

st.title("Disease NLP Classifier")

# Free-text input; the Predict handler splits this on commas.
text = st.text_area("Enter your symptoms separated by commas (e.g. fever, cough, headache):")

def predict(text_input):
    """Classify *text_input* and return a 1-D tensor of class probabilities."""
    encoded = tokenizer(
        text_input,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**encoded).logits
    return torch.softmax(logits, dim=-1).squeeze()

if st.button("Predict"):
    if not text.strip():
        st.warning("Please enter some symptoms!")
    else:
        symptoms = [s.strip() for s in text.split(",") if s.strip()]
        if not symptoms:
            st.warning("Please enter valid symptoms separated by commas!")
        else:
            # Average the class probabilities over several random orderings of
            # the symptom list, so the prediction is less sensitive to the
            # order in which the user typed the symptoms.
            n_shuffles = 10
            class_sums = [0.0] * len(label_encoder.classes_)
            for _ in range(n_shuffles):
                random.shuffle(symptoms)
                probs = predict(", ".join(symptoms))
                for class_idx, p in enumerate(probs):
                    class_sums[class_idx] += p.item()
            averaged = [(i, total / n_shuffles) for i, total in enumerate(class_sums)]
            top_3 = sorted(averaged, key=lambda pair: pair[1], reverse=True)[:3]

            st.subheader("Top 3 Predicted Diseases (averaged over shuffled inputs):")
            for idx, prob in top_3:
                # Map the class index back to its disease name.
                label = label_encoder.classes_[idx]
                st.write(f"**{label}** — Probability: `{prob * 100:.2f}%`")

# Render deployment entry point: when the file is executed directly with
# `python <file>`, re-launch it through the Streamlit CLI on the port that
# Render provides via the PORT environment variable.
if __name__ == "__main__":
    from streamlit import runtime

    # Streamlit itself also runs this script with __name__ == "__main__".
    # Guard on runtime.exists() so the already-running app does not try to
    # re-exec the Streamlit CLI from inside itself (the original recursed).
    if not runtime.exists():
        port = int(os.environ.get("PORT", 8501))
        sys.argv = ["streamlit", "run", "streamlit_app.py", f"--server.port={port}", "--server.address=0.0.0.0"]
        from streamlit.web.cli import main
        sys.exit(main())