Reyall commited on
Commit
eaeb0e4
·
verified ·
1 Parent(s): 7cf20b1

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +16 -26
src/streamlit_app.py CHANGED
@@ -7,32 +7,32 @@ import random
7
  from collections import defaultdict
8
  import json
9
 
10
- # Label encoder yükləmə funksiyası
11
- def load_label_encoder():
12
- file_path = os.path.join(os.getcwd(), "best_model", "label_encoder.pkl")
13
  if not os.path.exists(file_path):
14
- st.error(f"Label encoder faylı tapılmadı: {file_path}")
15
  st.stop()
16
  with open(file_path, "rb") as f:
17
- label_encoder = pickle.load(f)
18
- return label_encoder
19
 
20
  # Model və tokenizer yükləmə
21
  @st.cache_resource
22
  def load_model():
23
- label_encoder = load_label_encoder()
24
  model_path = os.path.join(os.getcwd(), "best_model")
25
 
26
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
27
  model = BertForSequenceClassification.from_pretrained(
28
  model_path,
29
- num_labels=len(label_encoder.classes_) if hasattr(label_encoder, "classes_") else len(label_encoder)
30
  )
31
  model.eval()
32
- return tokenizer, model, label_encoder
33
 
34
  # Prediction funksiyası
35
- def predict_disease(symptoms_text, tokenizer, model, label_encoder):
36
  symptoms = [s.strip() for s in symptoms_text.split(",") if s.strip()]
37
  agg_probs = defaultdict(float)
38
  n_shuffles = 10
@@ -60,13 +60,7 @@ def predict_disease(symptoms_text, tokenizer, model, label_encoder):
60
  top_3 = sorted(agg_probs.items(), key=lambda x: x[1], reverse=True)[:3]
61
  results = []
62
  for idx, prob in top_3:
63
- # LabelEncoder obyekti olub-olmamasını yoxla
64
- if hasattr(label_encoder, "classes_"):
65
- label = label_encoder.classes_[idx]
66
- elif isinstance(label_encoder, dict):
67
- label = label_encoder.get(idx, f"Unknown label {idx}")
68
- else:
69
- label = f"Unknown label {idx}"
70
  results.append({"disease": label, "probability": prob})
71
 
72
  return results
@@ -79,13 +73,13 @@ query_params = st.query_params
79
  is_api_mode = str(query_params.get("api", ["false"])[0]).lower() == "true"
80
 
81
  # Model yüklə
82
- tokenizer, model, label_encoder = load_model()
83
 
84
  # API mode
85
  if is_api_mode:
86
  symptoms = query_params.get("symptoms", [""])[0]
87
  if symptoms:
88
- results = predict_disease(symptoms, tokenizer, model, label_encoder)
89
  api_response = {
90
  "status": "success",
91
  "input": symptoms,
@@ -97,7 +91,6 @@ if is_api_mode:
97
  "message": "symptoms parameter required"
98
  }
99
 
100
- # JSON olaraq qaytar (raw)
101
  st.markdown(
102
  f"```json\n{json.dumps(api_response, ensure_ascii=False, indent=2)}\n```"
103
  )
@@ -108,10 +101,7 @@ st.title("🏥 Disease Prediction")
108
  st.success("Model yükləndi!")
109
 
110
  # Debug: Siniflər
111
- if hasattr(label_encoder, "classes_"):
112
- st.write("Available classes:", list(label_encoder.classes_))
113
- elif isinstance(label_encoder, dict):
114
- st.write("Available classes:", list(label_encoder.values()))
115
 
116
  # API usage info
117
  st.markdown("### API İstifadəsi")
@@ -119,6 +109,7 @@ space_url = "https://your-username-your-space-name.hf.space"
119
  api_example = f"{space_url}/?api=true&symptoms=fever,cough,headache"
120
  st.code(api_example)
121
 
 
122
  with st.form(key="predict_form"):
123
  text = st.text_area("Simptomları daxil edin (vergüllə ayırın):")
124
  submit_button = st.form_submit_button(label="Predict")
@@ -127,8 +118,7 @@ if submit_button:
127
  if not text.strip():
128
  st.warning("Simptomları daxil edin!")
129
  else:
130
- results = predict_disease(text, tokenizer, model, label_encoder)
131
  st.subheader("🔍 Nəticələr:")
132
  for result in results:
133
  st.write(f"**{result['disease']}** — {result['probability']*100:.2f}%")
134
-
 
7
  from collections import defaultdict
8
  import json
9
 
10
+ # Name encoder yükləmə funksiyası
11
+ def load_name_encoder():
12
+ file_path = os.path.join(os.getcwd(), "best_model", "name_encoder.pkl")
13
  if not os.path.exists(file_path):
14
+ st.error(f"Name encoder faylı tapılmadı: {file_path}")
15
  st.stop()
16
  with open(file_path, "rb") as f:
17
+ name_encoder = pickle.load(f)
18
+ return name_encoder
19
 
20
  # Model və tokenizer yükləmə
21
  @st.cache_resource
22
  def load_model():
23
+ name_encoder = load_name_encoder()
24
  model_path = os.path.join(os.getcwd(), "best_model")
25
 
26
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
27
  model = BertForSequenceClassification.from_pretrained(
28
  model_path,
29
+ num_labels=len(name_encoder.classes_)
30
  )
31
  model.eval()
32
+ return tokenizer, model, name_encoder
33
 
34
  # Prediction funksiyası
35
+ def predict_disease(symptoms_text, tokenizer, model, name_encoder):
36
  symptoms = [s.strip() for s in symptoms_text.split(",") if s.strip()]
37
  agg_probs = defaultdict(float)
38
  n_shuffles = 10
 
60
  top_3 = sorted(agg_probs.items(), key=lambda x: x[1], reverse=True)[:3]
61
  results = []
62
  for idx, prob in top_3:
63
+ label = name_encoder.classes_[idx]
 
 
 
 
 
 
64
  results.append({"disease": label, "probability": prob})
65
 
66
  return results
 
73
  is_api_mode = str(query_params.get("api", ["false"])[0]).lower() == "true"
74
 
75
  # Model yüklə
76
+ tokenizer, model, name_encoder = load_model()
77
 
78
  # API mode
79
  if is_api_mode:
80
  symptoms = query_params.get("symptoms", [""])[0]
81
  if symptoms:
82
+ results = predict_disease(symptoms, tokenizer, model, name_encoder)
83
  api_response = {
84
  "status": "success",
85
  "input": symptoms,
 
91
  "message": "symptoms parameter required"
92
  }
93
 
 
94
  st.markdown(
95
  f"```json\n{json.dumps(api_response, ensure_ascii=False, indent=2)}\n```"
96
  )
 
101
  st.success("Model yükləndi!")
102
 
103
  # Debug: Siniflər
104
+ st.write("Available classes:", list(name_encoder.classes_))
 
 
 
105
 
106
  # API usage info
107
  st.markdown("### API İstifadəsi")
 
109
  api_example = f"{space_url}/?api=true&symptoms=fever,cough,headache"
110
  st.code(api_example)
111
 
112
+ # Form
113
  with st.form(key="predict_form"):
114
  text = st.text_area("Simptomları daxil edin (vergüllə ayırın):")
115
  submit_button = st.form_submit_button(label="Predict")
 
118
  if not text.strip():
119
  st.warning("Simptomları daxil edin!")
120
  else:
121
+ results = predict_disease(text, tokenizer, model, name_encoder)
122
  st.subheader("🔍 Nəticələr:")
123
  for result in results:
124
  st.write(f"**{result['disease']}** — {result['probability']*100:.2f}%")