Reyall committed on
Commit
c123726
·
verified ·
1 Parent(s): 025f941

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -42
app.py CHANGED
@@ -7,64 +7,136 @@ import random
7
  import os
8
  from safetensors.torch import load_file
9
 
10
-
11
# Load the tokenizer, model, and label encoder
def load_model():
    """Load the fine-tuned BERT classifier and its companions from disk.

    Returns:
        A ``(tokenizer, model, label_encoder)`` triple read from the
        ``best_model`` directory.
    """
    # Label encoder (pickled sklearn-style encoder with a .classes_ attribute)
    with open("best_model/label_encoder.pkl", "rb") as fh:
        label_encoder = pickle.load(fh)

    # Tokenizer and model; safetensors weights are supported automatically
    tokenizer = BertTokenizer.from_pretrained("best_model")
    model = BertForSequenceClassification.from_pretrained("best_model", use_safetensors=True)
    model.eval()  # inference mode — disables dropout etc.

    return tokenizer, model, label_encoder
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
 
26
# Load everything once at import time; the prediction callback reads these globals.
tokenizer, model, label_encoder = load_model()
27
 
28
# Prediction function
def predict_disease(text):
    """Return the top-3 predicted diseases for comma-separated symptoms.

    The symptom list is shuffled several times and the softmax outputs are
    averaged so the result does not depend on the order the user typed them.
    """
    if not text.strip():
        return "Please enter some symptoms!"

    symptoms = [part.strip() for part in text.split(",") if part.strip()]
    if not symptoms:
        return "Please enter valid symptoms separated by commas!"

    totals = defaultdict(float)
    n_shuffles = 10
    for _ in range(n_shuffles):
        random.shuffle(symptoms)
        encoded = tokenizer(", ".join(symptoms), return_tensors="pt", truncation=True, padding=True, max_length=128)
        with torch.no_grad():
            outputs = model(**encoded)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze()
        for class_idx, p in enumerate(probabilities):
            totals[class_idx] += p.item()

    # Average over the shuffles, then keep the three highest classes
    averaged = {k: v / n_shuffles for k, v in totals.items()}
    top_3 = sorted(averaged.items(), key=lambda kv: kv[1], reverse=True)[:3]

    return "\n".join(
        f"{label_encoder.classes_[class_idx]} — Probability: {prob*100:.2f}%"
        for class_idx, prob in top_3
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
# Gradio interface
iface = gr.Interface(
    fn=predict_disease,
    title="Disease NLP Classifier",
    description="Enter your symptoms (comma separated) and get top 3 predicted diseases.",
    inputs=gr.Textbox(lines=2, placeholder="Enter your symptoms separated by commas"),
    outputs=gr.Textbox(),
)

# Launch
if __name__ == "__main__":
    # Honor the platform-provided PORT, defaulting to Gradio's usual 7860
    port = int(os.environ.get("PORT", 7860))
    iface.launch(server_name="0.0.0.0", server_port=port)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import os
8
  from safetensors.torch import load_file
9
 
 
10
# Load the model and label encoder
def load_model():
    """Load the fine-tuned BERT classifier, tokenizer, and label encoder.

    Returns:
        ``(tokenizer, model, label_encoder)`` on success, or
        ``(None, None, None)`` when loading fails — the error and the
        contents of the model directory are printed for debugging.
    """
    try:
        # Label encoder
        with open("best_model/label_encoder.pkl", "rb") as fh:
            label_encoder = pickle.load(fh)

        # Tokenizer and model; safetensors weights are picked up automatically
        tokenizer = BertTokenizer.from_pretrained("best_model")
        model = BertForSequenceClassification.from_pretrained("best_model", use_safetensors=True)
        model.eval()  # inference mode

        print(f"Model uğurla yükləndi. Label sayı: {len(label_encoder.classes_)}")
        return tokenizer, model, label_encoder

    except Exception as e:
        # Best-effort diagnostics: report the failure and what is on disk
        print(f"Model yüklənmə xətası: {e}")
        if os.path.exists("best_model"):
            files = os.listdir("best_model")
            print(f"best_model qovluğundakı fayllar: {files}")
        else:
            print("best_model qovluğu mövcud deyil")
        return None, None, None

# Load once at import time; the Gradio callback reads these globals.
tokenizer, model, label_encoder = load_model()
40
 
41
# Prediction callback for the Gradio interface
def predict_disease(text):
    """Predict the three most likely diseases for comma-separated symptoms.

    The symptom order is shuffled several times and the softmax outputs are
    averaged, so the prediction does not depend on input order. Returns a
    formatted result string, or an error/warning message on bad input.
    """
    # Guard: load_model() returns (None, None, None) on failure
    if tokenizer is None or model is None or label_encoder is None:
        return "❌ Model yüklənməyib! Xəta var."

    if not text.strip():
        return "⚠️ Please enter some symptoms!"

    symptoms = [part.strip() for part in text.split(",") if part.strip()]
    if not symptoms:
        return "⚠️ Please enter valid symptoms separated by commas!"

    try:
        totals = defaultdict(float)
        n_shuffles = 10

        for _ in range(n_shuffles):
            random.shuffle(symptoms)
            encoded = tokenizer(
                ", ".join(symptoms),
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=128,
            )
            with torch.no_grad():
                outputs = model(**encoded)
                probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze()

            for class_idx, p in enumerate(probabilities):
                totals[class_idx] += p.item()

        # Average over the shuffles
        averaged = {k: v / n_shuffles for k, v in totals.items()}

        # Keep the three highest-probability classes
        top_3 = sorted(averaged.items(), key=lambda kv: kv[1], reverse=True)[:3]

        lines = ["🏥 Top 3 Predicted Diseases:\n"]
        lines.extend(
            f"• **{label_encoder.classes_[class_idx]}** — Probability: {prob*100:.2f}%"
            for class_idx, prob in top_3
        )
        return "\n".join(lines)

    except Exception as e:
        return f"❌ Prediction xətası: {str(e)}"
92
 
93
# Gradio interface
# Clickable example inputs shown under the textbox
EXAMPLE_SYMPTOMS = [
    ["fever, cough, headache"],
    ["stomach pain, nausea, vomiting"],
    ["chest pain, shortness of breath"],
    ["dizziness, fatigue, weakness"],
    ["skin rash, itching, redness"],
]

iface = gr.Interface(
    fn=predict_disease,
    title="🏥 Disease NLP Classifier",
    description="Enter your symptoms separated by commas and get top 3 predicted diseases with confidence scores.",
    inputs=gr.Textbox(
        lines=2,
        placeholder="fever, cough, headache, shortness of breath",
        label="Enter your symptoms (comma separated)",
    ),
    outputs=gr.Textbox(label="Predicted Diseases"),
    examples=EXAMPLE_SYMPTOMS,
)
112
 
113
# Launch
if __name__ == "__main__":
    if tokenizer and model and label_encoder:
        print("✅ Model hazırdır, Gradio başladılır...")
        iface.launch(
            server_name="0.0.0.0",
            server_port=int(os.environ.get("PORT", 7860)),
            share=True  # Creates a public link
        )
    else:
        # Model failed to load — print diagnostics and serve a debug UI instead
        print("❌ Model yüklənmədi, Gradio başladıla bilmir!")
        print("\nDebug məlumatları:")
        print(f"Hazırkı qovluq: {os.getcwd()}")
        print(f"Qovluq məzmunu: {os.listdir('.')}")

        # Minimal fallback interface that reports the working directory.
        # BUG FIX: the function must accept one argument — the interface
        # declares a Textbox input, so Gradio passes its value on every
        # submit; a zero-argument fn raised TypeError at request time.
        def debug_info(_text):
            return f"Debug məlumatları:\nHazırkı qovluq: {os.getcwd()}\nFayllar: {os.listdir('.')}"

        debug_iface = gr.Interface(
            fn=debug_info,
            inputs=gr.Textbox(placeholder="Debug üçün hər hansı mətn yazın"),
            outputs=gr.Textbox(),
            title="🔧 Debug Interface"
        )

        debug_iface.launch(
            server_name="0.0.0.0",
            server_port=int(os.environ.get("PORT", 7860))
        )