Reyall committed on
Commit
072ed30
·
verified ·
1 Parent(s): 51f0272

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +75 -37
src/streamlit_app.py CHANGED
@@ -5,6 +5,7 @@ import torch
5
  import pickle
6
  import random
7
  from collections import defaultdict
 
8
 
9
  # Label encoder faylını yükləmək
10
  def load_label_encoder():
@@ -16,57 +17,94 @@ def load_label_encoder():
16
  label_encoder = pickle.load(f)
17
  return label_encoder
18
 
 
19
  @st.cache_resource
20
  def load_model():
21
  label_encoder = load_label_encoder()
22
  model_path = os.path.join(os.getcwd(), "best_model")
23
 
24
- # Tokenizer və modelin yüklənməsi
25
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
26
  model = BertForSequenceClassification.from_pretrained(model_path, num_labels=len(label_encoder.classes_))
27
  model.eval()
28
  return tokenizer, model, label_encoder
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  tokenizer, model, label_encoder = load_model()
31
 
32
- st.title("Disease NLP Classifier")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- # Siniflərin sırasını göstərək (Debug məqsədilə)
35
- st.write("Available classes:", list(label_encoder.classes_))
 
36
 
37
- text = st.text_area("Enter your symptoms separated by commas (e.g. fever, cough, headache):")
 
38
 
39
- def predict(text_input):
40
- inputs = tokenizer(text_input, return_tensors="pt", truncation=True, padding=True, max_length=128)
41
- with torch.no_grad():
42
- outputs = model(**inputs)
43
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze()
44
- return probs
45
 
46
- if st.button("Predict"):
47
- if not text.strip():
48
- st.warning("Please enter some symptoms!")
49
- else:
50
- symptoms = [s.strip() for s in text.split(",") if s.strip()]
51
- if not symptoms:
52
- st.warning("Please enter valid symptoms separated by commas!")
53
- else:
54
- agg_probs = defaultdict(float)
55
- n_shuffles = 10
56
- for _ in range(n_shuffles):
57
- random.shuffle(symptoms)
58
- shuffled_text = ", ".join(symptoms)
59
- probs = predict(shuffled_text)
60
- for i, p in enumerate(probs):
61
- agg_probs[i] += p.item()
62
- for k in agg_probs:
63
- agg_probs[k] /= n_shuffles
64
- top_3 = sorted(agg_probs.items(), key=lambda x: x[1], reverse=True)[:3]
65
 
66
- st.subheader("Top 3 Predicted Diseases (averaged over shuffled inputs):")
67
- for idx, prob in top_3:
68
- try:
69
- label = label_encoder.classes_[idx]
70
- except IndexError:
71
- label = f"Unknown label idx {idx}"
72
- st.write(f"**{label}** Probability: `{prob * 100:.2f}%`")
 
 
5
  import pickle
6
  import random
7
  from collections import defaultdict
8
+ import json
9
 
10
  # Label encoder faylını yükləmək
11
  def load_label_encoder():
 
17
  label_encoder = pickle.load(f)
18
  return label_encoder
19
 
20
# Load the model and tokenizer
@st.cache_resource
def load_model():
    """Build the tokenizer, classifier and label encoder used for inference.

    The label encoder is obtained from ``load_label_encoder()``; its number
    of classes sizes the classification head. The tokenizer is loaded under
    the name ``bert-base-uncased`` while the classifier is loaded from the
    ``best_model`` directory under the current working directory.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the returned
    objects instead of rebuilding them on every script rerun.

    Returns:
        A ``(tokenizer, model, label_encoder)`` tuple; the model has had
        ``eval()`` called on it.
    """
    encoder = load_label_encoder()
    checkpoint_dir = os.path.join(os.getcwd(), "best_model")

    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    n_classes = len(encoder.classes_)
    classifier = BertForSequenceClassification.from_pretrained(
        checkpoint_dir,
        num_labels=n_classes,
    )
    classifier.eval()

    return bert_tokenizer, classifier, encoder
30
 
31
# Prediction function
def predict_disease(symptoms_text, tokenizer, model, label_encoder, *, n_shuffles=10, top_k=3):
    """Rank diseases for a comma-separated symptom string.

    The text is split on commas, each piece is stripped and empty pieces are
    dropped. The remaining symptoms are shuffled ``n_shuffles`` times; every
    shuffled ordering is joined with ``", "``, tokenized and passed to the
    model, and the softmax probabilities are averaged over all orderings.
    Shuffling uses the ``random`` module, so the orderings differ from call
    to call.

    Args:
        symptoms_text: Comma-separated symptoms, e.g. ``"fever, cough"``.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt",
            truncation=True, padding=True, max_length=128)``; its result is
            unpacked as keyword arguments for ``model``.
        model: Callable whose result exposes a ``logits`` tensor.
            NOTE(review): assumes logits for a single input row -- confirm.
        label_encoder: Object whose ``classes_`` maps class index to label.
        n_shuffles: Number of shuffled orderings to average over (>= 1).
        top_k: Maximum number of predictions to return.

    Returns:
        A list of at most ``top_k`` dicts ``{"disease": label,
        "probability": float}`` ordered by descending averaged probability.
        An empty list is returned when the input holds no symptoms.

    Raises:
        ValueError: If ``n_shuffles`` is smaller than 1.
    """
    if n_shuffles < 1:
        raise ValueError("n_shuffles must be at least 1")

    symptoms = [s.strip() for s in symptoms_text.split(",") if s.strip()]
    if not symptoms:
        # Input such as "" or " , ," carries no symptoms; do not run the
        # model on an empty string and report its output as a prediction.
        return []

    agg_probs = defaultdict(float)
    for _ in range(n_shuffles):
        random.shuffle(symptoms)
        shuffled_text = ", ".join(symptoms)

        inputs = tokenizer(shuffled_text, return_tensors="pt", truncation=True, padding=True, max_length=128)
        with torch.no_grad():
            outputs = model(**inputs)

        # reshape(-1) always yields a 1-D tensor; a bare squeeze() would
        # collapse a single-class output to 0-d, which cannot be iterated.
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1).reshape(-1)
        for i, p in enumerate(probs.tolist()):
            agg_probs[i] += p

    averaged = [(idx, total / n_shuffles) for idx, total in agg_probs.items()]
    # sorted() is stable, so ties keep ascending class-index order.
    best = sorted(averaged, key=lambda x: x[1], reverse=True)[:top_k]

    classes = label_encoder.classes_
    results = []
    for idx, prob in best:
        # Guard against a model head that is wider than the label encoder.
        label = classes[idx] if idx < len(classes) else f"Unknown label {idx}"
        results.append({"disease": label, "probability": prob})

    return results
61
+
62
# Page config
st.set_page_config(page_title="Disease API", layout="wide")

# API mode detection.
# st.experimental_get_query_params is deprecated in favor of st.query_params,
# so use the new API when this Streamlit version provides it and keep the
# legacy call only as a fallback. Either way the first "symptoms" value is
# used, which matches the legacy list-based lookup.
if hasattr(st, "query_params"):
    query_params = st.query_params
    _symptom_values = query_params.get_all("symptoms")
else:
    query_params = st.experimental_get_query_params()
    _symptom_values = query_params.get("symptoms", [""])
is_api_mode = "api" in query_params

# Load model
tokenizer, model, label_encoder = load_model()

if is_api_mode:
    st.markdown("### API Mode")
    symptoms = _symptom_values[0] if _symptom_values else ""
    results = predict_disease(symptoms, tokenizer, model, label_encoder) if symptoms else []
    if results:
        st.json({
            "status": "success",
            "input": symptoms,
            "predictions": results
        })
    else:
        # Missing parameter, or a value that contained no usable symptoms.
        st.json({
            "status": "error",
            "message": "symptoms parameter required"
        })

else:
    st.title("🏥 Disease Prediction")
    st.success("Model yükləndi!")

    # Debug: list the classes known to the label encoder
    st.write("Available classes:", list(label_encoder.classes_))

    # API usage info
    st.markdown("### API İstifadəsi")
    space_url = "https://your-username-your-space-name.hf.space"
    api_example = f"{space_url}/?api=true&symptoms=fever,cough,headache"
    st.code(api_example)

    text = st.text_area("Simptomları daxil edin (vergüllə ayırın):")

    if st.button("Predict"):
        if not text.strip():
            st.warning("Simptomları daxil edin!")
        else:
            results = predict_disease(text, tokenizer, model, label_encoder)
            if not results:
                # Input held only separators/whitespace, so nothing was predicted.
                st.warning("Simptomları daxil edin!")
            else:
                st.subheader("🔍 Nəticələr:")
                for result in results:
                    st.write(f"**{result['disease']}** — {result['probability']*100:.2f}%")