Reyall commited on
Commit
633be3e
·
verified ·
1 Parent(s): 6e20e5c

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +53 -44
src/streamlit_app.py CHANGED
@@ -7,7 +7,7 @@ import random
7
  from collections import defaultdict
8
  import json
9
 
10
- # Label encoder faylını yükləmək
11
  def load_label_encoder():
12
  file_path = os.path.join(os.getcwd(), "best_model", "label_encoder.pkl")
13
  if not os.path.exists(file_path):
@@ -24,88 +24,97 @@ def load_model():
24
  model_path = os.path.join(os.getcwd(), "best_model")
25
 
26
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
27
- model = BertForSequenceClassification.from_pretrained(model_path, num_labels=len(label_encoder.classes_))
 
 
 
28
  model.eval()
29
  return tokenizer, model, label_encoder
30
 
31
  # Prediction funksiyası
32
  def predict_disease(symptoms_text, tokenizer, model, label_encoder):
33
  symptoms = [s.strip() for s in symptoms_text.split(",") if s.strip()]
34
-
35
  agg_probs = defaultdict(float)
36
  n_shuffles = 10
37
-
38
  for _ in range(n_shuffles):
39
  random.shuffle(symptoms)
40
  shuffled_text = ", ".join(symptoms)
41
-
42
- inputs = tokenizer(shuffled_text, return_tensors="pt", truncation=True, padding=True, max_length=128)
 
 
 
 
 
43
  with torch.no_grad():
44
  outputs = model(**inputs)
45
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze()
46
-
47
  for i, p in enumerate(probs):
48
  agg_probs[i] += p.item()
49
-
50
  for k in agg_probs:
51
  agg_probs[k] /= n_shuffles
52
-
53
  top_3 = sorted(agg_probs.items(), key=lambda x: x[1], reverse=True)[:3]
54
-
55
  results = []
56
  for idx, prob in top_3:
57
  label = label_encoder.classes_[idx] if idx < len(label_encoder.classes_) else f"Unknown label {idx}"
58
  results.append({"disease": label, "probability": prob})
59
-
60
  return results
61
 
62
  # Page config
63
  st.set_page_config(page_title="Disease API", layout="wide")
64
 
65
- # API mode detection
66
- # API mode detection
67
  query_params = st.query_params
68
- is_api_mode = query_params.get("api", ["false"])[0].lower() == "true"
69
 
70
- # Load model
71
  tokenizer, model, label_encoder = load_model()
72
 
 
73
  if is_api_mode:
74
- st.markdown("### API Mode")
75
  symptoms = query_params.get("symptoms", [""])[0]
76
  if symptoms:
77
  results = predict_disease(symptoms, tokenizer, model, label_encoder)
78
- st.json({
79
  "status": "success",
80
  "input": symptoms,
81
  "predictions": results
82
- })
83
  else:
84
- st.json({
85
  "status": "error",
86
  "message": "symptoms parameter required"
87
- })
88
-
89
- else:
90
- st.title("🏥 Disease Prediction")
91
- st.success("Model yükləndi!")
92
-
93
- # Debug: Siniflər
94
- st.write("Available classes:", list(label_encoder.classes_))
95
-
96
- # API usage info
97
- st.markdown("### API İstifadəsi")
98
- space_url = "https://your-username-your-space-name.hf.space"
99
- api_example = f"{space_url}/?api=true&symptoms=fever,cough,headache"
100
- st.code(api_example)
101
-
102
- text = st.text_area("Simptomları daxil edin (vergüllə ayırın):")
103
-
104
- if st.button("Predict"):
105
- if not text.strip():
106
- st.warning("Simptomları daxil edin!")
107
- else:
108
- results = predict_disease(text, tokenizer, model, label_encoder)
109
- st.subheader("🔍 Nəticələr:")
110
- for result in results:
111
- st.write(f"**{result['disease']}** — {result['probability']*100:.2f}%")
 
 
 
 
 
7
  from collections import defaultdict
8
  import json
9
 
10
+ # Label encoder yükləmə funksiyası
11
  def load_label_encoder():
12
  file_path = os.path.join(os.getcwd(), "best_model", "label_encoder.pkl")
13
  if not os.path.exists(file_path):
 
24
  model_path = os.path.join(os.getcwd(), "best_model")
25
 
26
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
27
+ model = BertForSequenceClassification.from_pretrained(
28
+ model_path,
29
+ num_labels=len(label_encoder.classes_)
30
+ )
31
  model.eval()
32
  return tokenizer, model, label_encoder
33
 
34
  # Prediction funksiyası
35
  def predict_disease(symptoms_text, tokenizer, model, label_encoder):
36
  symptoms = [s.strip() for s in symptoms_text.split(",") if s.strip()]
 
37
  agg_probs = defaultdict(float)
38
  n_shuffles = 10
39
+
40
  for _ in range(n_shuffles):
41
  random.shuffle(symptoms)
42
  shuffled_text = ", ".join(symptoms)
43
+ inputs = tokenizer(
44
+ shuffled_text,
45
+ return_tensors="pt",
46
+ truncation=True,
47
+ padding=True,
48
+ max_length=128
49
+ )
50
  with torch.no_grad():
51
  outputs = model(**inputs)
52
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze()
53
+
54
  for i, p in enumerate(probs):
55
  agg_probs[i] += p.item()
56
+
57
  for k in agg_probs:
58
  agg_probs[k] /= n_shuffles
59
+
60
  top_3 = sorted(agg_probs.items(), key=lambda x: x[1], reverse=True)[:3]
 
61
  results = []
62
  for idx, prob in top_3:
63
  label = label_encoder.classes_[idx] if idx < len(label_encoder.classes_) else f"Unknown label {idx}"
64
  results.append({"disease": label, "probability": prob})
65
+
66
  return results
67
 
68
  # Page config
69
  st.set_page_config(page_title="Disease API", layout="wide")
70
 
71
+ # Query parametrlər
 
72
  query_params = st.query_params
73
+ is_api_mode = str(query_params.get("api", ["false"])[0]).lower() == "true"
74
 
75
+ # Model yüklə
76
  tokenizer, model, label_encoder = load_model()
77
 
78
+ # API mode
79
  if is_api_mode:
 
80
  symptoms = query_params.get("symptoms", [""])[0]
81
  if symptoms:
82
  results = predict_disease(symptoms, tokenizer, model, label_encoder)
83
+ api_response = {
84
  "status": "success",
85
  "input": symptoms,
86
  "predictions": results
87
+ }
88
  else:
89
+ api_response = {
90
  "status": "error",
91
  "message": "symptoms parameter required"
92
+ }
93
+
94
+ # JSON olaraq qaytar (raw)
95
+ st.write(json.dumps(api_response, ensure_ascii=False))
96
+ st.stop()
97
+
98
+ # Web interfeys
99
+ st.title("🏥 Disease Prediction")
100
+ st.success("Model yükləndi!")
101
+
102
+ # Debug: Siniflər
103
+ st.write("Available classes:", list(label_encoder.classes_))
104
+
105
+ # API usage info
106
+ st.markdown("### API İstifadəsi")
107
+ space_url = "https://your-username-your-space-name.hf.space"
108
+ api_example = f"{space_url}/?api=true&symptoms=fever,cough,headache"
109
+ st.code(api_example)
110
+
111
+ text = st.text_area("Simptomları daxil edin (vergüllə ayırın):")
112
+
113
+ if st.button("Predict"):
114
+ if not text.strip():
115
+ st.warning("Simptomları daxil edin!")
116
+ else:
117
+ results = predict_disease(text, tokenizer, model, label_encoder)
118
+ st.subheader("🔍 Nəticələr:")
119
+ for result in results:
120
+ st.write(f"**{result['disease']}** — {result['probability']*100:.2f}%")