asamasach committed on
Commit
80280f8
·
1 Parent(s): d181e61

Add GroundingDINO and YOLO-World zero-shot models - Added GroundingDINO and YOLO-World for better zero-shot detection - Updated requirements.txt with ultralytics - Added visualization with distinct colors

Browse files
Files changed (2) hide show
  1. app.py +242 -0
  2. requirements.txt +1 -0
app.py CHANGED
@@ -278,6 +278,162 @@ def run_florence2_inference(image_bytes: bytes, confidence: float = 0.3):
278
  return []
279
 
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  def run_owlvit_inference(image_bytes: bytes, text_queries: list = None, confidence: float = 0.1):
282
  """
283
  Run zero-shot object detection using OWL-ViT (Open World Localization - Vision Transformer).
@@ -387,6 +543,16 @@ MODELS = {
387
  "type": "owlvit",
388
  "description": "Zero-shot object detection using Google's OWL-ViT - detects objects based on text descriptions"
389
  },
 
 
 
 
 
 
 
 
 
 
390
  }
391
 
392
  # AdaCLIP configuration
@@ -651,6 +817,66 @@ def gradio_inference(image, model_display_name, conf_threshold):
651
 
652
  return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
653
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
654
  # Handle YOLO models (default)
655
  session = get_session(model_key)
656
  if session is None:
@@ -727,6 +953,22 @@ def api_inference(image, model_display_name, conf_threshold):
727
  detections = run_owlvit_inference(image_bytes, confidence=conf_threshold)
728
  return detections
729
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
730
  # Handle YOLO models (default)
731
  session = get_session(model_key)
732
  if session is None:
 
278
  return []
279
 
280
 
281
def run_groundingdino_inference(image_bytes: bytes, text_queries: list = None, confidence: float = 0.3):
    """
    Run zero-shot object detection using GroundingDINO (IDEA Research).

    GroundingDINO performs open-set detection: it locates objects described
    by free-text queries rather than a fixed label set, and tends to be more
    accurate than OWL-ViT for text-guided detection.

    Args:
        image_bytes: Raw encoded image bytes (e.g. JPEG/PNG).
        text_queries: Phrases to search for; defaults to defect-related terms.
        confidence: Threshold applied to both box and text matching scores.

    Returns:
        List of detection dicts (bbox, confidence, class_id, class_name,
        x1/y1/x2/y2, model_type); empty list when nothing is found or on error.
    """
    try:
        from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
        from PIL import Image
        import torch
        import io

        if text_queries is None:
            text_queries = ["defect", "anomaly", "crack", "scratch", "damage", "error", "imperfection"]

        # Load image and force RGB so grayscale/RGBA/palette inputs don't
        # break the processor or arrive with the wrong channel count.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        orig_w, orig_h = image.size
        logger.info(f"GroundingDINO: Processing image {orig_w}x{orig_h}")

        # Lazily initialize model/processor and cache them on the function
        # object so weights are only downloaded/loaded once per process.
        if not hasattr(run_groundingdino_inference, 'processor'):
            logger.info("Loading GroundingDINO model (first time only)...")
            run_groundingdino_inference.processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
            run_groundingdino_inference.model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny")
            logger.info("GroundingDINO model loaded successfully")

        processor = run_groundingdino_inference.processor
        model = run_groundingdino_inference.model

        # GroundingDINO expects phrases separated (and terminated) by periods.
        text_prompt = ". ".join(text_queries) + "."

        # Prepare inputs
        inputs = processor(images=image, text=text_prompt, return_tensors="pt")

        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = model(**inputs)

        # Map raw outputs back to pixel coordinates of the original image.
        # NOTE(review): newer transformers releases renamed `box_threshold`
        # to `threshold` — confirm against the pinned transformers version.
        results = processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=confidence,
            text_threshold=confidence,
            target_sizes=[(orig_h, orig_w)]
        )[0]

        detections = []

        if len(results["boxes"]) > 0:
            boxes = results["boxes"].cpu().numpy()
            scores = results["scores"].cpu().numpy()
            labels = results["labels"]  # matched text phrase per box

            logger.info(f"GroundingDINO found {len(boxes)} objects")

            for box, score, label in zip(boxes, scores, labels):
                x1, y1, x2, y2 = box

                detections.append({
                    "bbox": [float(x1), float(y1), float(x2), float(y2)],
                    "confidence": float(score),
                    "class_id": 0,
                    "class_name": str(label),
                    "x1": float(x1),
                    "y1": float(y1),
                    "x2": float(x2),
                    "y2": float(y2),
                    "model_type": "groundingdino"
                })

        logger.info(f"GroundingDINO detected {len(detections)} objects: {[d['class_name'] for d in detections]}")
        return detections

    except Exception as e:
        # Best-effort: any failure (missing deps, download error, bad image)
        # degrades to "no detections" so the caller's pipeline keeps running.
        logger.error(f"GroundingDINO inference error: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return []
363
+
364
+
365
def run_yoloworld_inference(image_bytes: bytes, text_queries: list = None, confidence: float = 0.3):
    """
    Run zero-shot object detection using YOLO-World.

    YOLO-World combines YOLO speed with open-vocabulary detection: the class
    list is set at runtime from the given text queries, making it fast and
    effective for real-time anomaly detection.

    Args:
        image_bytes: Raw encoded image bytes (e.g. JPEG/PNG).
        text_queries: Class phrases to detect; defaults to defect-related terms.
        confidence: Minimum confidence for returned boxes.

    Returns:
        List of detection dicts (bbox, confidence, class_id, class_name,
        x1/y1/x2/y2, model_type); empty list when nothing is found or on error.
    """
    try:
        from ultralytics import YOLOWorld
        from PIL import Image
        import io

        if text_queries is None:
            text_queries = ["defect", "anomaly", "crack", "scratch", "damage"]

        # Load image and force RGB so grayscale/RGBA inputs are handled uniformly.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        orig_w, orig_h = image.size
        logger.info(f"YOLO-World: Processing image {orig_w}x{orig_h}")

        # Lazily load the model and cache it on the function object so the
        # checkpoint is only loaded once per process.
        if not hasattr(run_yoloworld_inference, 'model'):
            logger.info("Loading YOLO-World model (first time only)...")
            run_yoloworld_inference.model = YOLOWorld("yolov8s-world.pt")  # Small model
            logger.info("YOLO-World model loaded successfully")

        model = run_yoloworld_inference.model

        # Restrict the open vocabulary to our query phrases.
        model.set_classes(text_queries)

        # Pass the PIL image directly: ultralytics treats raw numpy arrays as
        # BGR (cv2 convention), so feeding np.array(image) — which is RGB —
        # would channel-swap the input and degrade accuracy. PIL input is
        # converted correctly by the predictor.
        results = model.predict(image, conf=confidence, verbose=False)

        detections = []

        if len(results) > 0 and results[0].boxes is not None:
            boxes = results[0].boxes
            logger.info(f"YOLO-World found {len(boxes)} objects")

            for box in boxes:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                conf = float(box.conf[0].cpu().numpy())
                cls = int(box.cls[0].cpu().numpy())
                # Class ids index into the custom query list set above.
                class_name = text_queries[cls] if cls < len(text_queries) else "object"

                detections.append({
                    "bbox": [float(x1), float(y1), float(x2), float(y2)],
                    "confidence": conf,
                    "class_id": cls,
                    "class_name": class_name,
                    "x1": float(x1),
                    "y1": float(y1),
                    "x2": float(x2),
                    "y2": float(y2),
                    "model_type": "yoloworld"
                })

        logger.info(f"YOLO-World detected {len(detections)} objects: {[d['class_name'] for d in detections]}")
        return detections

    except Exception as e:
        # Best-effort: any failure (missing deps, checkpoint download error,
        # bad image) degrades to "no detections" so the caller keeps running.
        logger.error(f"YOLO-World inference error: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return []
435
+
436
+
437
  def run_owlvit_inference(image_bytes: bytes, text_queries: list = None, confidence: float = 0.1):
438
  """
439
  Run zero-shot object detection using OWL-ViT (Open World Localization - Vision Transformer).
 
543
  "type": "owlvit",
544
  "description": "Zero-shot object detection using Google's OWL-ViT - detects objects based on text descriptions"
545
  },
546
+ "zero-shot-groundingdino": {
547
+ "name": "Zero Shot (GroundingDINO)",
548
+ "type": "groundingdino",
549
+ "description": "IDEA Research's open-set object detection - better than OWL-ViT for text-guided detection"
550
+ },
551
+ "zero-shot-yoloworld": {
552
+ "name": "Zero Shot (YOLO-World)",
553
+ "type": "yoloworld",
554
+ "description": "Fast open-vocabulary detection using YOLO architecture - combines speed with zero-shot capability"
555
+ },
556
  }
557
 
558
  # AdaCLIP configuration
 
817
 
818
  return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
819
 
820
+ # Handle GroundingDINO (zero-shot object detection)
821
+ if model_type == "groundingdino":
822
+ _, img_encoded = cv2.imencode('.jpg', img_bgr)
823
+ image_bytes = img_encoded.tobytes()
824
+
825
+ detections = run_groundingdino_inference(image_bytes, confidence=conf_threshold)
826
+
827
+ # Add detection count
828
+ status_text = f"GroundingDINO: {len(detections)} objects"
829
+ cv2.putText(img_bgr, status_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
830
+ cv2.putText(img_bgr, status_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 165, 255), 1) # Orange
831
+
832
+ for i, det in enumerate(detections):
833
+ x1 = int(det["x1"])
834
+ y1 = int(det["y1"])
835
+ x2 = int(det["x2"])
836
+ y2 = int(det["y2"])
837
+ score = det["confidence"]
838
+ class_name = det.get("class_name", "object")
839
+
840
+ label = f"#{i+1} {class_name}:{score:.2f}"
841
+ cv2.rectangle(img_bgr, (x1, y1), (x2, y2), (0, 165, 255), 3) # Orange
842
+ cv2.putText(img_bgr, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 165, 255), 2)
843
+
844
+ if not detections:
845
+ no_detect_text = f"No objects detected (threshold: {conf_threshold:.2f})"
846
+ cv2.putText(img_bgr, no_detect_text, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 165, 255), 2)
847
+
848
+ return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
849
+
850
+ # Handle YOLO-World (zero-shot object detection)
851
+ if model_type == "yoloworld":
852
+ _, img_encoded = cv2.imencode('.jpg', img_bgr)
853
+ image_bytes = img_encoded.tobytes()
854
+
855
+ detections = run_yoloworld_inference(image_bytes, confidence=conf_threshold)
856
+
857
+ # Add detection count
858
+ status_text = f"YOLO-World: {len(detections)} objects"
859
+ cv2.putText(img_bgr, status_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
860
+ cv2.putText(img_bgr, status_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 0), 1) # Cyan
861
+
862
+ for i, det in enumerate(detections):
863
+ x1 = int(det["x1"])
864
+ y1 = int(det["y1"])
865
+ x2 = int(det["x2"])
866
+ y2 = int(det["y2"])
867
+ score = det["confidence"]
868
+ class_name = det.get("class_name", "object")
869
+
870
+ label = f"#{i+1} {class_name}:{score:.2f}"
871
+ cv2.rectangle(img_bgr, (x1, y1), (x2, y2), (255, 255, 0), 3) # Cyan
872
+ cv2.putText(img_bgr, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 2)
873
+
874
+ if not detections:
875
+ no_detect_text = f"No objects detected (threshold: {conf_threshold:.2f})"
876
+ cv2.putText(img_bgr, no_detect_text, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 0), 2)
877
+
878
+ return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
879
+
880
  # Handle YOLO models (default)
881
  session = get_session(model_key)
882
  if session is None:
 
953
  detections = run_owlvit_inference(image_bytes, confidence=conf_threshold)
954
  return detections
955
 
956
+ # Handle GroundingDINO (zero-shot object detection)
957
+ if model_type == "groundingdino":
958
+ _, img_encoded = cv2.imencode('.jpg', img_bgr)
959
+ image_bytes = img_encoded.tobytes()
960
+
961
+ detections = run_groundingdino_inference(image_bytes, confidence=conf_threshold)
962
+ return detections
963
+
964
+ # Handle YOLO-World (zero-shot object detection)
965
+ if model_type == "yoloworld":
966
+ _, img_encoded = cv2.imencode('.jpg', img_bgr)
967
+ image_bytes = img_encoded.tobytes()
968
+
969
+ detections = run_yoloworld_inference(image_bytes, confidence=conf_threshold)
970
+ return detections
971
+
972
  # Handle YOLO models (default)
973
  session = get_session(model_key)
974
  if session is None:
requirements.txt CHANGED
@@ -11,3 +11,4 @@ transformers
11
  torch
12
  torchvision
13
  pillow
 
 
11
  torch
12
  torchvision
13
  pillow
14
+ ultralytics