HMWCS committed on
Commit
ddca90f
·
2 Parent(s): 77cc30a f0a45ec

feat: add confidence score feature to model predictions

Browse files
Files changed (3) hide show
  1. app.py +13 -9
  2. classifier.py +99 -26
  3. requirements.txt +1 -1
app.py CHANGED
@@ -14,7 +14,6 @@ import os
14
  from classifier import GarbageClassifier
15
  from config import Config
16
 
17
-
18
  # Initialize classifier
19
  config = Config()
20
  classifier = GarbageClassifier(config)
@@ -30,14 +29,14 @@ def classify_garbage_impl(image):
30
  Actual classification implementation
31
  """
32
  if image is None:
33
- return "Please upload an image", "No image provided"
34
 
35
  try:
36
- classification, full_response = classifier.classify_image(image)
37
- return classification, full_response
 
38
  except Exception as e:
39
- return "Error", f"Classification failed: {str(e)}"
40
-
41
 
42
  # Apply GPU decorator based on environment
43
  if HF_SPACES:
@@ -78,6 +77,11 @@ with gr.Blocks(title="Garbage Classification System") as demo:
78
  placeholder="Upload an image and click classify",
79
  )
80
 
 
 
 
 
 
81
  full_response_output = gr.Textbox(
82
  label="Detailed Analysis",
83
  placeholder="Detailed reasoning will appear here",
@@ -102,15 +106,15 @@ with gr.Blocks(title="Garbage Classification System") as demo:
102
  classify_btn.click(
103
  fn=classify_garbage,
104
  inputs=image_input,
105
- outputs=[classification_output, full_response_output],
106
  )
107
 
108
  # Auto-classify on image upload
109
  image_input.change(
110
  fn=classify_garbage,
111
  inputs=image_input,
112
- outputs=[classification_output, full_response_output],
113
  )
114
 
115
  if __name__ == "__main__":
116
- demo.launch()
 
14
  from classifier import GarbageClassifier
15
  from config import Config
16
 
 
17
  # Initialize classifier
18
  config = Config()
19
  classifier = GarbageClassifier(config)
 
29
  Actual classification implementation
30
  """
31
  if image is None:
32
+ return "Please upload an image", "No image provided", "N/A"
33
 
34
  try:
35
+ classification, full_response, confidence_score = classifier.classify_image(image)
36
+ confidence_display = f"{confidence_score}/10"
37
+ return classification, full_response, confidence_display
38
  except Exception as e:
39
+ return "Error", f"Classification failed: {str(e)}", "0/10"
 
40
 
41
  # Apply GPU decorator based on environment
42
  if HF_SPACES:
 
77
  placeholder="Upload an image and click classify",
78
  )
79
 
80
+ confidence_output = gr.Textbox(
81
+ label="Confidence Score",
82
+ placeholder="Confidence score will appear here",
83
+ )
84
+
85
  full_response_output = gr.Textbox(
86
  label="Detailed Analysis",
87
  placeholder="Detailed reasoning will appear here",
 
106
  classify_btn.click(
107
  fn=classify_garbage,
108
  inputs=image_input,
109
+ outputs=[classification_output, full_response_output, confidence_output]
110
  )
111
 
112
  # Auto-classify on image upload
113
  image_input.change(
114
  fn=classify_garbage,
115
  inputs=image_input,
116
+ outputs=[classification_output, full_response_output, confidence_output]
117
  )
118
 
119
  if __name__ == "__main__":
120
+ demo.launch()
classifier.py CHANGED
@@ -5,7 +5,7 @@ import logging
5
  from typing import Union, Tuple
6
  from config import Config
7
  from knowledge_base import GarbageClassificationKnowledge
8
-
9
 
10
  class GarbageClassifier:
11
  def __init__(self, config: Config = None):
@@ -86,7 +86,7 @@ class GarbageClassifier:
86
 
87
  return processed_image
88
 
89
- def classify_image(self, image: Union[str, Image.Image]) -> Tuple[str, str]:
90
  """
91
  Classify garbage in the image
92
 
@@ -94,7 +94,7 @@ class GarbageClassifier:
94
  image: PIL Image or path to image file
95
 
96
  Returns:
97
- Tuple of (classification_result, detailed_analysis)
98
  """
99
  if self.model is None or self.processor is None:
100
  raise RuntimeError("Model not loaded. Call load_model() first.")
@@ -126,7 +126,7 @@ class GarbageClassifier:
126
  {"type": "image", "image": processed_image},
127
  {
128
  "type": "text",
129
- "text": "Please classify what you see in this image. If it shows garbage/waste items, classify them according to the garbage classification standards. If it shows people, living things, or other non-waste items, classify it as 'Unable to classify' and explain why it's not garbage.",
130
  },
131
  ],
132
  },
@@ -158,14 +158,87 @@ class GarbageClassifier:
158
  # Extract reasoning from response
159
  reasoning = self._extract_reasoning(response)
160
 
161
- return classification, reasoning
 
 
 
162
 
163
  except Exception as e:
164
  self.logger.error(f"Error during classification: {str(e)}")
165
  import traceback
166
 
167
  traceback.print_exc()
168
- return "Error", f"Classification failed: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  def _extract_classification(self, response: str) -> str:
171
  """Extract the main classification from the response"""
@@ -268,43 +341,43 @@ class GarbageClassifier:
268
  def _extract_reasoning(self, response: str) -> str:
269
  """Extract only the reasoning content, removing all formatting markers and classification info"""
270
  import re
271
-
272
  # Remove all formatting markers
273
  cleaned_response = response.replace("**Classification**:", "")
274
  cleaned_response = cleaned_response.replace("**Reasoning**:", "")
275
  cleaned_response = re.sub(r'\*\*.*?\*\*:', '', cleaned_response) # Remove any **text**: patterns
276
  cleaned_response = cleaned_response.replace("**", "") # Remove remaining ** markers
277
-
278
  # Remove category names that might appear at the beginning
279
  categories = self.knowledge.get_categories()
280
  for category in categories:
281
  if cleaned_response.strip().startswith(category):
282
  cleaned_response = cleaned_response.replace(category, "", 1)
283
  break
284
-
285
  # Remove common material names that might appear at the beginning
286
  material_names = [
287
- "Glass", "Plastic", "Metal", "Paper", "Cardboard", "Aluminum",
288
  "Steel", "Iron", "Tin", "Foil", "Wood", "Ceramic", "Fabric",
289
  "Recyclable Waste", "Food/Kitchen Waste", "Hazardous Waste", "Other Waste"
290
  ]
291
-
292
  # Clean the response
293
  cleaned_response = cleaned_response.strip()
294
-
295
  # Remove material names at the beginning
296
  for material in material_names:
297
  if cleaned_response.startswith(material):
298
  # Remove the material name and any following punctuation/whitespace
299
  cleaned_response = cleaned_response[len(material):].lstrip(" .,;:")
300
  break
301
-
302
  # Split into sentences and clean up
303
  sentences = []
304
-
305
  # Split by common sentence endings, but keep the endings
306
  parts = re.split(r'([.!?])\s+', cleaned_response)
307
-
308
  # Rejoin parts to maintain sentence structure
309
  reconstructed_parts = []
310
  for i in range(0, len(parts), 2):
@@ -313,49 +386,49 @@ class GarbageClassifier:
313
  if i + 1 < len(parts):
314
  sentence += parts[i + 1] # Add the punctuation back
315
  reconstructed_parts.append(sentence)
316
-
317
  for part in reconstructed_parts:
318
  part = part.strip()
319
  if not part:
320
  continue
321
-
322
  # Skip parts that are just category names or material names
323
  if part in categories or part.rstrip(".,;:") in material_names:
324
  continue
325
-
326
  # Skip parts that start with category names or material names
327
  is_category_line = False
328
  for item in categories + material_names:
329
  if part.startswith(item):
330
  is_category_line = True
331
  break
332
-
333
  if is_category_line:
334
  continue
335
-
336
  # Clean up the sentence
337
  part = re.sub(r'^[A-Za-z\s]+:', '', part).strip() # Remove "Category:" type prefixes
338
-
339
  if part and len(part) > 3: # Only keep meaningful content
340
  sentences.append(part)
341
-
342
  # Join sentences
343
  reasoning = ' '.join(sentences)
344
-
345
  # Final cleanup - remove any remaining standalone material words at the beginning
346
  reasoning_words = reasoning.split()
347
  if reasoning_words and reasoning_words[0] in [m.lower() for m in material_names]:
348
  reasoning_words = reasoning_words[1:]
349
  reasoning = ' '.join(reasoning_words)
350
-
351
  # Ensure proper capitalization
352
  if reasoning:
353
  reasoning = reasoning[0].upper() + reasoning[1:] if len(reasoning) > 1 else reasoning.upper()
354
-
355
  # Ensure proper punctuation
356
  if not reasoning.endswith(('.', '!', '?')):
357
  reasoning += '.'
358
-
359
  return reasoning if reasoning else "Analysis not available"
360
 
361
  def get_categories_info(self):
 
5
  from typing import Union, Tuple
6
  from config import Config
7
  from knowledge_base import GarbageClassificationKnowledge
8
+ import re
9
 
10
  class GarbageClassifier:
11
  def __init__(self, config: Config = None):
 
86
 
87
  return processed_image
88
 
89
+ def classify_image(self, image: Union[str, Image.Image]) -> Tuple[str, str, int]:
90
  """
91
  Classify garbage in the image
92
 
 
94
  image: PIL Image or path to image file
95
 
96
  Returns:
97
+ Tuple of (classification_result, detailed_analysis, confidence_score)
98
  """
99
  if self.model is None or self.processor is None:
100
  raise RuntimeError("Model not loaded. Call load_model() first.")
 
126
  {"type": "image", "image": processed_image},
127
  {
128
  "type": "text",
129
+ "text": "Please classify what you see in this image. If it shows garbage/waste items, classify them according to the garbage classification standards. If it shows people, living things, or other non-waste items, classify it as 'Unable to classify' and explain why it's not garbage. Also provide a confidence score from 1-10 indicating how certain you are about your classification.",
130
  },
131
  ],
132
  },
 
158
  # Extract reasoning from response
159
  reasoning = self._extract_reasoning(response)
160
 
161
+ # Extract confidence score from response
162
+ confidence_score = self._extract_confidence_score(response, classification)
163
+
164
+ return classification, reasoning, confidence_score
165
 
166
  except Exception as e:
167
  self.logger.error(f"Error during classification: {str(e)}")
168
  import traceback
169
 
170
  traceback.print_exc()
171
+ return "Error", f"Classification failed: {str(e)}", 0
172
+
173
+
174
+ def _calculate_confidence_heuristic(self, response_lower: str, classification: str) -> int:
175
+ """Calculate confidence based on response content and classification type"""
176
+ base_confidence = 5
177
+
178
+ # Confidence indicators (increase confidence)
179
+ high_confidence_words = ["clearly", "obviously", "definitely", "certainly", "exactly"]
180
+ medium_confidence_words = ["appears", "seems", "likely", "probably"]
181
+
182
+ # Uncertainty indicators (decrease confidence)
183
+ uncertainty_words = ["might", "could", "possibly", "maybe", "unclear", "difficult"]
184
+
185
+ # Adjust based on confidence words
186
+ for word in high_confidence_words:
187
+ if word in response_lower:
188
+ base_confidence += 2
189
+ break
190
+
191
+ for word in medium_confidence_words:
192
+ if word in response_lower:
193
+ base_confidence += 1
194
+ break
195
+
196
+ for word in uncertainty_words:
197
+ if word in response_lower:
198
+ base_confidence -= 2
199
+ break
200
+
201
+ # Classification-specific adjustments
202
+ if classification == "Unable to classify":
203
+ if any(indicator in response_lower for indicator in ["person", "people", "human", "living"]):
204
+ base_confidence += 1 # High confidence when clearly not waste
205
+ else:
206
+ base_confidence -= 1 # Lower confidence for unclear items
207
+
208
+ elif classification == "Error":
209
+ base_confidence = 1
210
+
211
+ else:
212
+ # Check for specific material mentions (increases confidence)
213
+ specific_materials = ["aluminum", "plastic", "glass", "metal", "cardboard", "paper"]
214
+ if any(material in response_lower for material in specific_materials):
215
+ base_confidence += 1
216
+
217
+ return min(max(base_confidence, 1), 10)
218
+
219
+ def _extract_confidence_score(self, response: str, classification: str) -> int:
220
+ """Extract confidence score from response or calculate based on classification"""
221
+ response_lower = response.lower()
222
+
223
+ # Look for explicit confidence scores in the response
224
+ confidence_patterns = [
225
+ r'confidence[:\s]*(\d+)',
226
+ r'confident[:\s]*(\d+)',
227
+ r'certainty[:\s]*(\d+)',
228
+ r'score[:\s]*(\d+)',
229
+ r'(\d+)/10',
230
+ r'(\d+)\s*out\s*of\s*10'
231
+ ]
232
+
233
+ for pattern in confidence_patterns:
234
+ match = re.search(pattern, response_lower)
235
+ if match:
236
+ score = int(match.group(1))
237
+ return min(max(score, 1), 10) # Clamp between 1-10
238
+
239
+ # If no explicit score found, calculate based on classification indicators
240
+ return self._calculate_confidence_heuristic(response_lower, classification)
241
+
242
 
243
  def _extract_classification(self, response: str) -> str:
244
  """Extract the main classification from the response"""
 
341
  def _extract_reasoning(self, response: str) -> str:
342
  """Extract only the reasoning content, removing all formatting markers and classification info"""
343
  import re
344
+
345
  # Remove all formatting markers
346
  cleaned_response = response.replace("**Classification**:", "")
347
  cleaned_response = cleaned_response.replace("**Reasoning**:", "")
348
  cleaned_response = re.sub(r'\*\*.*?\*\*:', '', cleaned_response) # Remove any **text**: patterns
349
  cleaned_response = cleaned_response.replace("**", "") # Remove remaining ** markers
350
+
351
  # Remove category names that might appear at the beginning
352
  categories = self.knowledge.get_categories()
353
  for category in categories:
354
  if cleaned_response.strip().startswith(category):
355
  cleaned_response = cleaned_response.replace(category, "", 1)
356
  break
357
+
358
  # Remove common material names that might appear at the beginning
359
  material_names = [
360
+ "Glass", "Plastic", "Metal", "Paper", "Cardboard", "Aluminum",
361
  "Steel", "Iron", "Tin", "Foil", "Wood", "Ceramic", "Fabric",
362
  "Recyclable Waste", "Food/Kitchen Waste", "Hazardous Waste", "Other Waste"
363
  ]
364
+
365
  # Clean the response
366
  cleaned_response = cleaned_response.strip()
367
+
368
  # Remove material names at the beginning
369
  for material in material_names:
370
  if cleaned_response.startswith(material):
371
  # Remove the material name and any following punctuation/whitespace
372
  cleaned_response = cleaned_response[len(material):].lstrip(" .,;:")
373
  break
374
+
375
  # Split into sentences and clean up
376
  sentences = []
377
+
378
  # Split by common sentence endings, but keep the endings
379
  parts = re.split(r'([.!?])\s+', cleaned_response)
380
+
381
  # Rejoin parts to maintain sentence structure
382
  reconstructed_parts = []
383
  for i in range(0, len(parts), 2):
 
386
  if i + 1 < len(parts):
387
  sentence += parts[i + 1] # Add the punctuation back
388
  reconstructed_parts.append(sentence)
389
+
390
  for part in reconstructed_parts:
391
  part = part.strip()
392
  if not part:
393
  continue
394
+
395
  # Skip parts that are just category names or material names
396
  if part in categories or part.rstrip(".,;:") in material_names:
397
  continue
398
+
399
  # Skip parts that start with category names or material names
400
  is_category_line = False
401
  for item in categories + material_names:
402
  if part.startswith(item):
403
  is_category_line = True
404
  break
405
+
406
  if is_category_line:
407
  continue
408
+
409
  # Clean up the sentence
410
  part = re.sub(r'^[A-Za-z\s]+:', '', part).strip() # Remove "Category:" type prefixes
411
+
412
  if part and len(part) > 3: # Only keep meaningful content
413
  sentences.append(part)
414
+
415
  # Join sentences
416
  reasoning = ' '.join(sentences)
417
+
418
  # Final cleanup - remove any remaining standalone material words at the beginning
419
  reasoning_words = reasoning.split()
420
  if reasoning_words and reasoning_words[0] in [m.lower() for m in material_names]:
421
  reasoning_words = reasoning_words[1:]
422
  reasoning = ' '.join(reasoning_words)
423
+
424
  # Ensure proper capitalization
425
  if reasoning:
426
  reasoning = reasoning[0].upper() + reasoning[1:] if len(reasoning) > 1 else reasoning.upper()
427
+
428
  # Ensure proper punctuation
429
  if not reasoning.endswith(('.', '!', '?')):
430
  reasoning += '.'
431
+
432
  return reasoning if reasoning else "Analysis not available"
433
 
434
  def get_categories_info(self):
requirements.txt CHANGED
@@ -5,4 +5,4 @@ torchvision
5
  transformers >= 4.53
6
  accelerate
7
  timm
8
- gradio
 
5
  transformers >= 4.53
6
  accelerate
7
  timm
8
+ gradio