Reyall commited on
Commit
daf2580
·
verified ·
1 Parent(s): 89c2d96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -64
app.py CHANGED
@@ -5,34 +5,31 @@ import pickle
5
  from collections import defaultdict
6
  import random
7
  import os
8
- from safetensors.torch import load_file
9
 
10
  # Model və label_encoder yüklənməsi
11
  def load_model():
12
  try:
13
  # Label encoder
14
- with open("best_model/label_encoder.pkl", "rb") as f:
15
  label_encoder = pickle.load(f)
16
 
17
- # Tokenizer
18
- tokenizer = BertTokenizer.from_pretrained("best_model")
 
 
 
 
 
 
 
19
 
20
- # Model (safetensors avtomatik dəstəklənir)
21
- model = BertForSequenceClassification.from_pretrained("best_model", use_safetensors=True)
22
  model.eval()
23
-
24
- print(f"Model uğurla yükləndi. Label sayı: {len(label_encoder.classes_)}")
25
  return tokenizer, model, label_encoder
26
 
27
  except Exception as e:
28
- print(f"Model yüklənmə xətası: {e}")
29
- # Faylları yoxla
30
- if os.path.exists("best_model"):
31
- files = os.listdir("best_model")
32
- print(f"best_model qovluğundakı fayllar: {files}")
33
- else:
34
- print("best_model qovluğu mövcud deyil")
35
-
36
  return None, None, None
37
 
38
  # Model yükləmə
@@ -41,18 +38,18 @@ tokenizer, model, label_encoder = load_model()
41
  # Prediction funksiyası
42
  def predict_disease(text):
43
  if tokenizer is None or model is None or label_encoder is None:
44
- return "❌ Model yüklənməyib! Xəta var."
45
 
46
  if not text.strip():
47
- return "⚠️ Please enter some symptoms!"
48
 
49
  symptoms = [s.strip() for s in text.split(",") if s.strip()]
50
  if not symptoms:
51
- return "⚠️ Please enter valid symptoms separated by commas!"
52
 
53
  try:
54
  agg_probs = defaultdict(float)
55
- n_shuffles = 10
56
 
57
  for i in range(n_shuffles):
58
  random.shuffle(symptoms)
@@ -80,63 +77,44 @@ def predict_disease(text):
80
  # Top 3 nəticə
81
  top_3 = sorted(agg_probs.items(), key=lambda x: x[1], reverse=True)[:3]
82
 
83
- results = ["🏥 Top 3 Predicted Diseases:\n"]
84
  for idx, prob in top_3:
85
  label = label_encoder.classes_[idx]
86
- results.append(f" **{label}** — Probability: {prob*100:.2f}%")
87
 
88
- return "\n".join(results)
89
 
90
  except Exception as e:
91
- return f"❌ Prediction xətası: {str(e)}"
92
 
93
  # Gradio interface
94
- iface = gr.Interface(
95
  fn=predict_disease,
96
  inputs=gr.Textbox(
97
- lines=2,
98
- placeholder="fever, cough, headache, shortness of breath",
99
- label="Enter your symptoms (comma separated)"
 
 
 
 
100
  ),
101
- outputs=gr.Textbox(label="Predicted Diseases"),
102
- title="🏥 Disease NLP Classifier",
103
- description="Enter your symptoms separated by commas and get top 3 predicted diseases with confidence scores.",
104
  examples=[
105
  ["fever, cough, headache"],
106
- ["stomach pain, nausea, vomiting"],
107
  ["chest pain, shortness of breath"],
108
- ["dizziness, fatigue, weakness"],
109
- ["skin rash, itching, redness"]
110
- ]
 
 
111
  )
112
 
113
- # Launch
114
  if __name__ == "__main__":
115
- if tokenizer and model and label_encoder:
116
- print("✅ Model hazırdır, Gradio başladılır...")
117
- iface.launch(
118
- server_name="0.0.0.0",
119
- server_port=int(os.environ.get("PORT", 7860)),
120
- share=True # Public link yaradır
121
- )
122
- else:
123
- print("❌ Model yüklənmədi, Gradio başladıla bilmir!")
124
- print("\nDebug məlumatları:")
125
- print(f"Hazırkı qovluq: {os.getcwd()}")
126
- print(f"Qovluq məzmunu: {os.listdir('.')}")
127
-
128
- # Sadə debug interface
129
- def debug_info():
130
- return f"Debug məlumatları:\nHazırkı qovluq: {os.getcwd()}\nFayllar: {os.listdir('.')}"
131
-
132
- debug_iface = gr.Interface(
133
- fn=debug_info,
134
- inputs=gr.Textbox(placeholder="Debug üçün hər hansı mətn yazın"),
135
- outputs=gr.Textbox(),
136
- title="🔧 Debug Interface"
137
- )
138
-
139
- debug_iface.launch(
140
- server_name="0.0.0.0",
141
- server_port=int(os.environ.get("PORT", 7860))
142
- )
 
5
  from collections import defaultdict
6
  import random
7
  import os
 
8
 
9
  # Model və label_encoder yüklənməsi
10
  def load_model():
11
  try:
12
  # Label encoder
13
+ with open("label_encoder.pkl", "rb") as f:
14
  label_encoder = pickle.load(f)
15
 
16
+ # Tokenizer və Model - iki variant sınayırıq
17
+ try:
18
+ # Variant 1: best_model qovluğu
19
+ tokenizer = BertTokenizer.from_pretrained("best_model")
20
+ model = BertForSequenceClassification.from_pretrained("best_model")
21
+ except:
22
+ # Variant 2: Ana qovluq
23
+ tokenizer = BertTokenizer.from_pretrained(".")
24
+ model = BertForSequenceClassification.from_pretrained(".")
25
 
 
 
26
  model.eval()
27
+ print(f"✅ Model yükləndi. Label sayı: {len(label_encoder.classes_)}")
 
28
  return tokenizer, model, label_encoder
29
 
30
  except Exception as e:
31
+ print(f"Model yüklənmə xətası: {e}")
32
+ print(f"📁 Mövcud fayllar: {os.listdir('.')}")
 
 
 
 
 
 
33
  return None, None, None
34
 
35
  # Model yükləmə
 
38
  # Prediction funksiyası
39
  def predict_disease(text):
40
  if tokenizer is None or model is None or label_encoder is None:
41
+ return "❌ Model yüklənməyib!"
42
 
43
  if not text.strip():
44
+ return "⚠️ Simptomları daxil edin!"
45
 
46
  symptoms = [s.strip() for s in text.split(",") if s.strip()]
47
  if not symptoms:
48
+ return "⚠️ Düzgün simptomlar yazın (vergüllə ayırın)!"
49
 
50
  try:
51
  agg_probs = defaultdict(float)
52
+ n_shuffles = 5 # Sürəti artırmaq üçün azaltdım
53
 
54
  for i in range(n_shuffles):
55
  random.shuffle(symptoms)
 
77
  # Top 3 nəticə
78
  top_3 = sorted(agg_probs.items(), key=lambda x: x[1], reverse=True)[:3]
79
 
80
+ result = "🏥 **Mümkün xəstəliklər:**\n\n"
81
  for idx, prob in top_3:
82
  label = label_encoder.classes_[idx]
83
+ result += f"🔸 **{label}** — %{prob*100:.1f}\n"
84
 
85
+ return result
86
 
87
  except Exception as e:
88
+ return f"❌ Xəta: {str(e)}"
89
 
90
  # Gradio interface
91
+ demo = gr.Interface(
92
  fn=predict_disease,
93
  inputs=gr.Textbox(
94
+ lines=3,
95
+ placeholder="Məsələn: fever, cough, headache",
96
+ label="🩺 Simptomlarınızı yazın (vergüllə ayırın)"
97
+ ),
98
+ outputs=gr.Textbox(
99
+ label="📋 Nəticələr",
100
+ lines=10
101
  ),
102
+ title="🏥 Xəstəlik Təyin Edici AI",
103
+ description="Simptomlarınızı yazın ən mümkün xəstəlikləri görün. ⚠️ Bu sadəcə kömək vasitəsidir, həkim məsləhəti əvəz etmir!",
 
104
  examples=[
105
  ["fever, cough, headache"],
106
+ ["stomach pain, nausea"],
107
  ["chest pain, shortness of breath"],
108
+ ["dizziness, fatigue"],
109
+ ["skin rash, itching"]
110
+ ],
111
+ theme=gr.themes.Soft(),
112
+ allow_flagging="never"
113
  )
114
 
 
115
  if __name__ == "__main__":
116
+ demo.launch(
117
+ server_name="0.0.0.0",
118
+ server_port=7860,
119
+ show_api=True # API dokumentasiyası göstərir
120
+ )