Karthikraj Sivakumar committed on
Commit
3072360
·
1 Parent(s): e3ce74a

add confidence scoring

Browse files
Files changed (1) hide show
  1. app.py +70 -19
app.py CHANGED
@@ -253,11 +253,51 @@ model.eval()
253
  print(f"Model loaded successfully! Using device: {device}")
254
 
255
  # ==========================================
256
- # 4. Prediction Function
257
  # ==========================================
258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  def predict_captcha(image):
260
- """Predict CAPTCHA text from image"""
261
 
262
  # Preprocess
263
  img_tensor = preprocess_image(image).to(device)
@@ -266,23 +306,29 @@ def predict_captcha(image):
266
  with torch.no_grad():
267
  log_probs = model(img_tensor)
268
 
269
- # Greedy decoding
270
- _, max_indices = torch.max(log_probs, dim=2)
271
- max_indices = max_indices.squeeze(1).cpu().numpy()
272
 
273
- # CTC collapse (remove blanks and repeated tokens)
274
- collapsed = []
275
- prev = None
276
- for token in max_indices:
277
- if token != 0 and token != prev:
278
- collapsed.append(token)
279
- prev = token
280
 
281
- # Decode to text
282
- prediction = ''.join([idx_to_char.get(t, '') for t in collapsed])
 
 
 
 
 
 
 
283
 
284
- # Return with confidence info
285
- return prediction
 
 
 
 
 
286
 
287
  # ==========================================
288
  # 5. Gradio Interface
@@ -291,12 +337,12 @@ def predict_captcha(image):
291
  demo = gr.Interface(
292
  fn=predict_captcha,
293
  inputs=gr.Image(type="pil", label="Upload CAPTCHA Image"),
294
- outputs=gr.Textbox(label="Predicted CAPTCHA Text", scale=2),
295
  title="CAPTCHA Recognition System",
296
  description="""
297
  **CS4243 Mini Project - CAPTCHA Recognition using CRNN + CTC Loss**
298
 
299
- Upload a CAPTCHA image to see the model's prediction.
300
 
301
  **Model Architecture:**
302
  - ResNet-based CNN feature extraction (4 layers, 2 blocks each)
@@ -308,8 +354,13 @@ demo = gr.Interface(
308
  - Character Accuracy: 85.82%
309
  - Trained on 7,777 samples with heavy augmentation
310
 
 
 
 
 
 
311
  **Training Details:**
312
- - 14 iterations of experimentation
313
  - Data augmentation: rotation, shear, black lines, noise
314
  - Regularization: dropout, weight decay, early stopping
315
  """,
 
253
  print(f"Model loaded successfully! Using device: {device}")
254
 
255
  # ==========================================
256
+ # 4. Prediction Functions
257
  # ==========================================
258
 
259
def ctc_decode_with_confidence(log_probs, idx_to_char, blank_idx=0):
    """
    Greedily decode one CTC output sequence and score its confidence.

    At every time step the most likely class is taken; blanks and
    consecutive repeats are then collapsed into the final token sequence.

    Args:
        log_probs: Log probabilities from the model for exactly one sample,
            shaped (T, 1, C) or (T, C).
        idx_to_char: Character mapping dictionary (class index -> character).
            Indices missing from the mapping contribute no text, but their
            probability still counts toward the confidence.
        blank_idx: Class index of the CTC blank symbol (default 0).

    Returns:
        prediction: Decoded text string.
        confidence: Average probability score (0-1) of the emitted tokens,
            each measured at the first time step of its run; 0.0 when no
            token is emitted.

    Raises:
        ValueError: If log_probs is not 2-D, or is 3-D with a batch
            dimension other than 1.
    """
    # Normalise the input to (T, C). A batch larger than one cannot be
    # decoded as a single sequence, so fail loudly instead of letting the
    # reduction below run over the wrong axis.
    if log_probs.dim() == 3:
        if log_probs.size(1) != 1:
            raise ValueError(
                "ctc_decode_with_confidence expects a batch of one, got "
                f"shape {tuple(log_probs.shape)}"
            )
        log_probs = log_probs[:, 0, :]
    elif log_probs.dim() != 2:
        raise ValueError(
            "ctc_decode_with_confidence expects (T, 1, C) or (T, C), got "
            f"shape {tuple(log_probs.shape)}"
        )

    # Greedy decoding - best class per time step. detach() keeps this usable
    # even when the caller did not wrap inference in torch.no_grad().
    best_log_probs, best_indices = log_probs.detach().max(dim=1)

    # exp() is monotonic, so exponentiating only the winning log-probability
    # gives the same value as taking the max of the full probability matrix.
    tokens = best_indices.cpu().tolist()
    probs = best_log_probs.exp().cpu().tolist()

    # CTC collapse (remove blanks and repeated tokens)
    chars = []
    kept_probs = []
    prev = None
    for token, prob in zip(tokens, probs):
        if token != blank_idx and token != prev:  # Not blank and not repeat
            chars.append(idx_to_char.get(token, ''))
            kept_probs.append(prob)
        prev = token

    # Decode to text
    prediction = ''.join(chars)

    # Calculate average confidence (0.0 when nothing was emitted)
    confidence = sum(kept_probs) / len(kept_probs) if kept_probs else 0.0

    return prediction, confidence
297
+
298
+
299
  def predict_captcha(image):
300
+ """Predict CAPTCHA text from image with confidence score"""
301
 
302
  # Preprocess
303
  img_tensor = preprocess_image(image).to(device)
 
306
  with torch.no_grad():
307
  log_probs = model(img_tensor)
308
 
309
+ # Decode with confidence
310
+ prediction, confidence = ctc_decode_with_confidence(log_probs, idx_to_char)
 
311
 
312
+ # Format output with confidence indicator
313
+ confidence_pct = confidence * 100
 
 
 
 
 
314
 
315
+ if confidence < 0.6:
316
+ status = "⚠️ Low Confidence"
317
+ note = "Result may be uncertain due to visual ambiguity (e.g., 0/o, i/1/l confusion)"
318
+ elif confidence < 0.75:
319
+ status = "⚡ Medium Confidence"
320
+ note = "Result is reasonably reliable"
321
+ else:
322
+ status = "✓ High Confidence"
323
+ note = "Result is highly reliable"
324
 
325
+ # Return formatted string
326
+ output = f"Prediction: {prediction}\n\n"
327
+ output += f"{status}\n"
328
+ output += f"Confidence: {confidence_pct:.1f}%\n\n"
329
+ output += f"{note}"
330
+
331
+ return output
332
 
333
  # ==========================================
334
  # 5. Gradio Interface
 
337
  demo = gr.Interface(
338
  fn=predict_captcha,
339
  inputs=gr.Image(type="pil", label="Upload CAPTCHA Image"),
340
+ outputs=gr.Textbox(label="Prediction Results", lines=6, scale=2),
341
  title="CAPTCHA Recognition System",
342
  description="""
343
  **CS4243 Mini Project - CAPTCHA Recognition using CRNN + CTC Loss**
344
 
345
+ Upload a CAPTCHA image to see the model's prediction with confidence score.
346
 
347
  **Model Architecture:**
348
  - ResNet-based CNN feature extraction (4 layers, 2 blocks each)
 
354
  - Character Accuracy: 85.82%
355
  - Trained on 7,777 samples with heavy augmentation
356
 
357
+ **Features:**
358
+ - **Confidence scoring**: Shows prediction reliability
359
+ - **Low confidence warnings**: Alerts when visual ambiguity exists (0/o, i/1/l confusion)
360
+ - **Real-time inference**: Results in <1 second
361
+
362
  **Training Details:**
363
+ - 14 iterations of systematic experimentation
364
  - Data augmentation: rotation, shear, black lines, noise
365
  - Regularization: dropout, weight decay, early stopping
366
  """,