sohug commited on
Commit
ccf6609
·
1 Parent(s): 9a20999

targeting 80% recall

Browse files
Files changed (1) hide show
  1. app.py +74 -38
app.py CHANGED
@@ -9,7 +9,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
  # 2. Load Fine-Tuned Model and Processor
11
  model_path = "actorcritic/twak" # <-- Change this back to just the repo name
12
- #model_path = "/home/shohog/Documents/twok"
13
 
14
  try:
15
  # Add subfolder="model" here
@@ -48,8 +48,28 @@ disease_info = {
48
  "Vitiligo": "শ্বেতী (Vitiligo) হলো ত্বকের এমন একটি অবস্থা যেখানে ত্বক থেকে মেলানিন রঞ্জক পদার্থ নষ্ট হয়ে সাদা ছোপের সৃষ্টি হয়।"
49
  }
50
 
51
- # Confidence threshold
52
- CONFIDENCE_THRESHOLD = 80.0 # percent
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  # 4. Prediction Function
55
  def predict(image):
@@ -68,50 +88,66 @@ def predict(image):
68
 
69
  logits = outputs.logits
70
  scores = torch.sigmoid(logits)[0]
 
 
 
71
 
72
- # Get top prediction
73
- top_idx = torch.argmax(scores).item()
74
- confidence = scores[top_idx].item() * 100
75
-
76
- # Check confidence threshold
77
- if confidence < CONFIDENCE_THRESHOLD:
78
- # Show "Unknown" with the closest guess in tooltip
79
- guessed_class = ft_model.config.id2label.get(top_idx, f"Class {top_idx}")
80
- html_output = f"""
81
- <div class="result-container">
82
- <span class="disease-name unknown">UNK</span>
 
 
 
 
 
 
 
83
  <div class="tooltip">
84
- <span class="info-icon">?</span>
85
  <div class="tooltiptext">
86
- <strong>Low Confidence Prediction</strong><br><br>
87
- The model is not confident enough for a reliable diagnosis.<br><br>
88
- Closest match: <b>{guessed_class}</b> ({confidence:.1f}%)<br><br>
89
- <i>Please consult a healthcare professional for proper diagnosis.</i>
90
  </div>
91
  </div>
92
  </div>
93
  """
94
- return html_output
95
-
96
- predicted_class = ft_model.config.id2label.get(top_idx, f"Class {top_idx}")
97
- info_text = disease_info.get(predicted_class, "Fundamental details for this condition are not available.")
98
-
99
- # Build HTML output with tooltip for confident predictions
100
- html_output = f"""
101
- <div class="result-container">
102
- <span class="disease-name">{predicted_class}</span>
103
- <div class="tooltip">
104
- <span class="info-icon">i</span>
105
- <div class="tooltiptext">
106
- <strong>{predicted_class}</strong><br><br>
107
- {info_text}
 
 
 
 
 
 
 
 
 
108
  </div>
109
- </div>
110
- </div>
111
- """
112
  return html_output
113
 
114
-
115
  # 5. Custom CSS
116
  custom_css = """
117
  .result-container {
@@ -263,4 +299,4 @@ with gr.Blocks() as demo:
263
 
264
  predict_btn.click(fn=predict, inputs=image_input, outputs=output_html)
265
 
266
- demo.launch(css=custom_css, theme=gr.themes.Default())
 
9
 
10
  # 2. Load Fine-Tuned Model and Processor
11
  model_path = "actorcritic/twak" # <-- Change this back to just the repo name
12
+ model_path = "/home/shohog/Documents/twok"
13
 
14
  try:
15
  # Add subfolder="model" here
 
48
  "Vitiligo": "শ্বেতী (Vitiligo) হলো ত্বকের এমন একটি অবস্থা যেখানে ত্বক থেকে মেলানিন রঞ্জক পদার্থ নষ্ট হয়ে সাদা ছোপের সৃষ্টি হয়।"
49
  }
50
 
51
+
52
+ # 15 Original Folders (Index 6 is the 'Others/Healthy' folder)
53
+ ORIGINAL_CLASSES =[
54
+ "Acne", "Arsenic", "Atopic_Dermatitis", "Candidal_Intertrigo",
55
+ "Contact_Dermatitis", "Eczema", "Healthy_or_Others", "Psoriasis",
56
+ "Scabies", "Seborrheic_Dermatitis", "Steroid_Modified_Tinea",
57
+ "Tinea_Corporis", "Tinea_Cruris", "Tinea_Faciei", "Vitiligo"
58
+ ]
59
+
60
+ # Optimal Thresholds per Output Neuron (Targeting 80% Recall)
61
+ CLASS_THRESHOLDS = {
62
+ 0: 0.89, 1: 0.88, 2: 0.34, 3: 0.04, 4: 0.66, 5: 0.48,
63
+ 6: 0.28, 7: 0.73, 8: 0.54, 9: 0.44, 10: 0.68, 11: 0.83,
64
+ 12: 0.85, 13: 0.95
65
+ }
66
+
67
+ def get_original_class_idx(neuron_idx):
68
+ """Maps the 14 output neurons back to the original 15 folder indices."""
69
+ if neuron_idx < 6:
70
+ return neuron_idx
71
+ else:
72
+ return neuron_idx + 1 # Shift back to account for the skipped Class 6
73
 
74
  # 4. Prediction Function
75
  def predict(image):
 
88
 
89
  logits = outputs.logits
90
  scores = torch.sigmoid(logits)[0]
91
+ print(logits, scores)
92
+
93
+ detected_diseases =[]
94
 
95
+ # Check each of the 14 neurons against its specific threshold
96
+ for neuron_idx in range(14):
97
+ prob = scores[neuron_idx].item()
98
+ if prob >= CLASS_THRESHOLDS[neuron_idx]:
99
+ orig_idx = get_original_class_idx(neuron_idx)
100
+ detected_diseases.append((orig_idx, prob))
101
+
102
+ html_output = '<div class="result-container">'
103
+
104
+ # SCENARIO A: No thresholds were crossed -> It's Class 6 (Healthy/Others)
105
+ if len(detected_diseases) == 0:
106
+ predicted_class = "Healthy / Others"
107
+ info_text = "কোনো চর্মরোগ শনাক্ত হয়নি। ত্বক সুস্থ অথবা এটি অন্য কোনো সাধারণ অবস্থা হতে পারে।"
108
+
109
+ # Added flex-direction: column here to stack them vertically
110
+ html_output += f"""
111
+ <div style="display: flex; flex-direction: column; align-items: center; justify-content: center; gap: 8px;">
112
+ <span class="disease-name" style="color: #10b981;">{predicted_class}</span>
113
  <div class="tooltip">
114
+ <span class="info-icon">i</span>
115
  <div class="tooltiptext">
116
+ <strong>{predicted_class}</strong><br><br>
117
+ {info_text}
 
 
118
  </div>
119
  </div>
120
  </div>
121
  """
122
+
123
+ # SCENARIO B: One or more diseases detected
124
+ else:
125
+ # Sort by probability (highest confidence first)
126
+ detected_diseases.sort(key=lambda x: x[1], reverse=True)
127
+
128
+ for orig_idx, prob in detected_diseases:
129
+ predicted_class = ORIGINAL_CLASSES[orig_idx]
130
+ info_text = disease_info.get(predicted_class, "Fundamental details for this condition are not available.")
131
+ confidence = prob * 100
132
+ print(confidence)
133
+
134
+ # Added flex-direction: column here to stack them vertically
135
+ html_output += f"""
136
+ <div style="display: flex; flex-direction: column; align-items: center; justify-content: center; gap: 8px; margin-bottom: 15px; width: 100%;">
137
+ <span class="disease-name">{predicted_class} <span style="font-size: 16px; color: #6b7280;">({confidence:.1f}%)</span></span>
138
+ <div class="tooltip">
139
+ <span class="info-icon">i</span>
140
+ <div class="tooltiptext">
141
+ <strong>{predicted_class}</strong><br><br>
142
+ {info_text}
143
+ </div>
144
+ </div>
145
  </div>
146
+ """
147
+
148
+ html_output += '</div>'
149
  return html_output
150
 
 
151
  # 5. Custom CSS
152
  custom_css = """
153
  .result-container {
 
299
 
300
  predict_btn.click(fn=predict, inputs=image_input, outputs=output_html)
301
 
302
+ demo.launch(css=custom_css, theme=gr.themes.Default())