asamasach Claude Sonnet 4.5 committed on
Commit
3972d40
·
1 Parent(s): 0ee3a29

Improve zero-shot detection: simplified CLIP logic and better visualization

Browse files

CLIP Changes:
- Changed from 4-class to simple binary comparison (normal vs defect)
- Lowered default threshold from 0.5 to 0.25 (more sensitive)
- Now logs both normal and defect probabilities
- Returns defect probability directly instead of summing classes
- Added normal_score to detection metadata

Visualization Improvements:
- CLIP: Always shows anomaly score on image (even if no detection)
- CLIP: Shows defect vs normal scores in label
- CLIP: Green text when no anomaly detected
- OWL-ViT: Shows detection count
- OWL-ViT: Numbered detections (#1, #2, etc)
- Both: Show threshold used when no detection
- Thicker bounding boxes (3px instead of 2px)

This makes it much clearer what the models are seeing and why they did/didn't detect.

Test file added: test_zeroshot.py (creates synthetic defect images for testing)

🤖 Generated with Claude Code

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

Files changed (3) hide show
  1. __pycache__/app.cpython-313.pyc +0 -0
  2. app.py +46 -26
  3. test_zeroshot.py +116 -0
__pycache__/app.cpython-313.pyc ADDED
Binary file (30.2 kB). View file
 
app.py CHANGED
@@ -96,12 +96,12 @@ def extract_bboxes_from_heatmap(heatmap_path: str, orig_w: int, orig_h: int, thr
96
  return []
97
 
98
 
99
- def run_clip_anomaly_inference(image_bytes: bytes, confidence: float = 0.5):
100
  """
101
  Run zero-shot anomaly detection using CLIP similarity scoring.
102
 
103
- This uses CLIP to compare image patches against "normal" vs "defect" descriptions.
104
- Simple but effective for general anomaly detection.
105
  """
106
  try:
107
  from transformers import CLIPProcessor, CLIPModel
@@ -123,12 +123,10 @@ def run_clip_anomaly_inference(image_bytes: bytes, confidence: float = 0.5):
123
  processor = run_clip_anomaly_inference.processor
124
  model = run_clip_anomaly_inference.model
125
 
126
- # Text descriptions for anomaly detection
127
  text_descriptions = [
128
- "a photo of a normal product without defects",
129
- "a photo of a defective product with anomalies",
130
- "a photo with cracks or scratches",
131
- "a photo with damage or imperfections"
132
  ]
133
 
134
  # Process inputs
@@ -145,31 +143,34 @@ def run_clip_anomaly_inference(image_bytes: bytes, confidence: float = 0.5):
145
  logits_per_image = outputs.logits_per_image
146
  probs = logits_per_image.softmax(dim=1)
147
 
148
- # Get anomaly probability (sum of defect-related classes)
149
- anomaly_prob = float(probs[0][1:].sum()) # Skip "normal" class
 
 
 
150
 
151
  detections = []
152
 
153
- # If anomaly detected, create detection box
154
- if anomaly_prob >= confidence:
155
- # Create a detection for the whole image
156
- # In a real scenario, you'd segment the anomalous region
157
  detections.append({
158
  "bbox": [0, 0, orig_w, orig_h],
159
- "confidence": anomaly_prob,
160
  "class_id": 0,
161
  "class_name": "anomaly",
162
  "x1": 0,
163
  "y1": 0,
164
  "x2": orig_w,
165
  "y2": orig_h,
166
- "anomaly_score": anomaly_prob,
 
167
  "model_type": "clip",
168
- "description": "CLIP-based anomaly detection"
169
  })
170
 
171
- logger.info(f"CLIP anomaly score: {anomaly_prob:.3f}, detections: {len(detections)}")
172
- return detections, anomaly_prob
173
 
174
  except Exception as e:
175
  logger.error(f"CLIP inference error: {e}")
@@ -465,16 +466,26 @@ def gradio_inference(image, model_display_name, conf_threshold):
465
 
466
  detections, anomaly_score = run_clip_anomaly_inference(image_bytes, confidence=conf_threshold)
467
 
 
 
 
 
 
468
  for det in detections:
469
  x1 = int(det["x1"])
470
  y1 = int(det["y1"])
471
  x2 = int(det["x2"])
472
  y2 = int(det["y2"])
473
  score = det["confidence"]
 
 
 
 
 
474
 
475
- label = f"anomaly:{score:.2f}"
476
- cv2.rectangle(img_bgr, (x1, y1), (x2, y2), (0, 0, 255), 2) # Red for anomalies
477
- cv2.putText(img_bgr, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
478
 
479
  return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
480
 
@@ -485,7 +496,12 @@ def gradio_inference(image, model_display_name, conf_threshold):
485
 
486
  detections = run_owlvit_inference(image_bytes, confidence=conf_threshold)
487
 
488
- for det in detections:
 
 
 
 
 
489
  x1 = int(det["x1"])
490
  y1 = int(det["y1"])
491
  x2 = int(det["x2"])
@@ -493,9 +509,13 @@ def gradio_inference(image, model_display_name, conf_threshold):
493
  score = det["confidence"]
494
  class_name = det.get("class_name", "object")
495
 
496
- label = f"{class_name}:{score:.2f}"
497
- cv2.rectangle(img_bgr, (x1, y1), (x2, y2), (255, 0, 0), 2) # Blue for OWL-ViT
498
- cv2.putText(img_bgr, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
 
 
 
 
499
 
500
  return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
501
 
 
96
  return []
97
 
98
 
99
+ def run_clip_anomaly_inference(image_bytes: bytes, confidence: float = 0.25):
100
  """
101
  Run zero-shot anomaly detection using CLIP similarity scoring.
102
 
103
+ This uses CLIP to compare the image against "normal" vs "defect" descriptions.
104
+ Returns detection if the image is more similar to defect descriptions than normal.
105
  """
106
  try:
107
  from transformers import CLIPProcessor, CLIPModel
 
123
  processor = run_clip_anomaly_inference.processor
124
  model = run_clip_anomaly_inference.model
125
 
126
+ # Simpler binary comparison: normal vs defect
127
  text_descriptions = [
128
+ "a high quality product without any defects or anomalies",
129
+ "a defective product with visible defects, cracks, scratches, or damage"
 
 
130
  ]
131
 
132
  # Process inputs
 
143
  logits_per_image = outputs.logits_per_image
144
  probs = logits_per_image.softmax(dim=1)
145
 
146
+ # Get probabilities
147
+ normal_prob = float(probs[0][0])
148
+ defect_prob = float(probs[0][1])
149
+
150
+ logger.info(f"CLIP probabilities - Normal: {normal_prob:.3f}, Defect: {defect_prob:.3f}")
151
 
152
  detections = []
153
 
154
+ # If defect probability is higher than threshold, create detection
155
+ # This means the image looks more like a defect than normal
156
+ if defect_prob >= confidence:
 
157
  detections.append({
158
  "bbox": [0, 0, orig_w, orig_h],
159
+ "confidence": defect_prob,
160
  "class_id": 0,
161
  "class_name": "anomaly",
162
  "x1": 0,
163
  "y1": 0,
164
  "x2": orig_w,
165
  "y2": orig_h,
166
+ "anomaly_score": defect_prob,
167
+ "normal_score": normal_prob,
168
  "model_type": "clip",
169
+ "description": f"CLIP anomaly detection (defect:{defect_prob:.2f} vs normal:{normal_prob:.2f})"
170
  })
171
 
172
+ logger.info(f"CLIP result - Defect score: {defect_prob:.3f}, Detections: {len(detections)}")
173
+ return detections, defect_prob
174
 
175
  except Exception as e:
176
  logger.error(f"CLIP inference error: {e}")
 
466
 
467
  detections, anomaly_score = run_clip_anomaly_inference(image_bytes, confidence=conf_threshold)
468
 
469
+ # Add text showing anomaly score even if no detection
470
+ status_text = f"Anomaly Score: {anomaly_score:.3f}"
471
+ cv2.putText(img_bgr, status_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
472
+ cv2.putText(img_bgr, status_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 1)
473
+
474
  for det in detections:
475
  x1 = int(det["x1"])
476
  y1 = int(det["y1"])
477
  x2 = int(det["x2"])
478
  y2 = int(det["y2"])
479
  score = det["confidence"]
480
+ normal_score = det.get("normal_score", 0)
481
+
482
+ label = f"DEFECT:{score:.2f} (vs normal:{normal_score:.2f})"
483
+ cv2.rectangle(img_bgr, (x1, y1), (x2, y2), (0, 0, 255), 3) # Red for anomalies
484
+ cv2.putText(img_bgr, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
485
 
486
+ if not detections:
487
+ no_detect_text = f"No anomaly detected (threshold: {conf_threshold:.2f})"
488
+ cv2.putText(img_bgr, no_detect_text, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
489
 
490
  return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
491
 
 
496
 
497
  detections = run_owlvit_inference(image_bytes, confidence=conf_threshold)
498
 
499
+ # Add detection count
500
+ status_text = f"OWL-ViT Detections: {len(detections)}"
501
+ cv2.putText(img_bgr, status_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
502
+ cv2.putText(img_bgr, status_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 0, 0), 1)
503
+
504
+ for i, det in enumerate(detections):
505
  x1 = int(det["x1"])
506
  y1 = int(det["y1"])
507
  x2 = int(det["x2"])
 
509
  score = det["confidence"]
510
  class_name = det.get("class_name", "object")
511
 
512
+ label = f"#{i+1} {class_name}:{score:.2f}"
513
+ cv2.rectangle(img_bgr, (x1, y1), (x2, y2), (255, 0, 0), 3) # Blue for OWL-ViT
514
+ cv2.putText(img_bgr, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
515
+
516
+ if not detections:
517
+ no_detect_text = f"No objects detected (threshold: {conf_threshold:.2f})"
518
+ cv2.putText(img_bgr, no_detect_text, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
519
 
520
  return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
521
 
test_zeroshot.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test script to verify zero-shot models are working properly.
3
+ """
4
+ import cv2
5
+ import numpy as np
6
+ import sys
7
+ import os
8
+
9
+ # Add parent directory to path
10
+ sys.path.insert(0, os.path.dirname(__file__))
11
+
12
+ from app import run_clip_anomaly_inference, run_owlvit_inference
13
+
14
+ def create_test_image_with_defect():
15
+ """Create a simple test image with a visible defect."""
16
+ # Create white background
17
+ img = np.ones((640, 640, 3), dtype=np.uint8) * 255
18
+
19
+ # Draw a normal grid pattern
20
+ for i in range(0, 640, 80):
21
+ cv2.line(img, (i, 0), (i, 640), (200, 200, 200), 2)
22
+ cv2.line(img, (0, i), (640, i), (200, 200, 200), 2)
23
+
24
+ # Draw a "defect" - irregular shapes
25
+ cv2.circle(img, (320, 320), 50, (0, 0, 0), -1) # Black circle (defect)
26
+ cv2.rectangle(img, (100, 100), (150, 180), (50, 50, 50), -1) # Dark rectangle (scratch)
27
+
28
+ # Save the test image
29
+ cv2.imwrite("test_defect_image.jpg", img)
30
+
31
+ # Convert to bytes
32
+ _, img_encoded = cv2.imencode('.jpg', img)
33
+ return img_encoded.tobytes()
34
+
35
+ def create_normal_test_image():
36
+ """Create a simple test image without defects."""
37
+ # Create white background
38
+ img = np.ones((640, 640, 3), dtype=np.uint8) * 255
39
+
40
+ # Draw a normal grid pattern only
41
+ for i in range(0, 640, 80):
42
+ cv2.line(img, (i, 0), (i, 640), (200, 200, 200), 2)
43
+ cv2.line(img, (0, i), (640, i), (200, 200, 200), 2)
44
+
45
+ # Save the test image
46
+ cv2.imwrite("test_normal_image.jpg", img)
47
+
48
+ # Convert to bytes
49
+ _, img_encoded = cv2.imencode('.jpg', img)
50
+ return img_encoded.tobytes()
51
+
52
+ def test_clip():
53
+ """Test CLIP anomaly detection."""
54
+ print("\n" + "="*60)
55
+ print("Testing CLIP Anomaly Detection")
56
+ print("="*60)
57
+
58
+ # Test with defect image
59
+ print("\n1. Testing with DEFECT image (should detect anomaly)...")
60
+ defect_image = create_test_image_with_defect()
61
+ detections, score = run_clip_anomaly_inference(defect_image, confidence=0.3)
62
+ print(f" Anomaly Score: {score:.4f}")
63
+ print(f" Detections: {len(detections)}")
64
+ if detections:
65
+ for i, det in enumerate(detections):
66
+ print(f" Detection {i+1}: {det}")
67
+ else:
68
+ print(" ⚠️ NO DETECTIONS (this is the problem!)")
69
+
70
+ # Test with normal image
71
+ print("\n2. Testing with NORMAL image (should NOT detect anomaly)...")
72
+ normal_image = create_normal_test_image()
73
+ detections, score = run_clip_anomaly_inference(normal_image, confidence=0.3)
74
+ print(f" Anomaly Score: {score:.4f}")
75
+ print(f" Detections: {len(detections)}")
76
+ if detections:
77
+ print(" ⚠️ False positive detected!")
78
+ else:
79
+ print(" ✓ Correctly identified as normal")
80
+
81
+ def test_owlvit():
82
+ """Test OWL-ViT object detection."""
83
+ print("\n" + "="*60)
84
+ print("Testing OWL-ViT Object Detection")
85
+ print("="*60)
86
+
87
+ # Test with defect image
88
+ print("\n1. Testing with DEFECT image...")
89
+ defect_image = create_test_image_with_defect()
90
+ detections = run_owlvit_inference(defect_image, confidence=0.05)
91
+ print(f" Detections: {len(detections)}")
92
+ if detections:
93
+ for i, det in enumerate(detections):
94
+ print(f" Detection {i+1}: bbox={det['bbox']}, conf={det['confidence']:.4f}, class={det['class_name']}")
95
+ else:
96
+ print(" ⚠️ NO DETECTIONS (this is the problem!)")
97
+
98
+ if __name__ == "__main__":
99
+ print("Testing Zero-Shot Models")
100
+ print("This will create test images and run inference")
101
+
102
+ try:
103
+ test_clip()
104
+ test_owlvit()
105
+
106
+ print("\n" + "="*60)
107
+ print("Test Complete!")
108
+ print("="*60)
109
+ print("\nTest images saved:")
110
+ print(" - test_defect_image.jpg (has defects)")
111
+ print(" - test_normal_image.jpg (normal)")
112
+
113
+ except Exception as e:
114
+ print(f"\n❌ ERROR: {e}")
115
+ import traceback
116
+ traceback.print_exc()