|
|
""" |
|
|
Object Detector Plugin |
|
|
|
|
|
Detects objects in images using CLIP model. |
|
|
""" |
|
|
|
|
|
from typing import Dict, Any, List |
|
|
from pathlib import Path |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from loguru import logger |
|
|
|
|
|
from plugins.base import BasePlugin, PluginMetadata |
|
|
|
|
|
|
|
|
class ObjectDetectorPlugin(BasePlugin):
    """
    Detect objects in images using CLIP.

    Uses zero-shot classification to identify objects
    without requiring training data: every candidate label is scored
    against the image, and labels whose probability reaches a threshold
    are reported.
    """

    # Minimum softmax probability for a label to be reported by ``analyze``.
    DEFAULT_THRESHOLD = 0.15
    # Maximum number of detected objects returned by ``analyze``.
    MAX_OBJECTS = 10

    # Category -> labels that belong to it. Dict order is the order in which
    # categories appear in ``_categorize_objects`` output. "objects" is the
    # catch-all for any label not listed under another category.
    _CATEGORY_LABELS: Dict[str, tuple] = {
        "people": ("person", "people", "man", "woman", "child", "baby"),
        "animals": ("dog", "cat", "bird", "animal"),
        "vehicles": ("car", "vehicle", "bicycle", "motorcycle"),
        "nature": (
            "tree", "plant", "flower", "nature", "landscape",
            "mountain", "ocean", "beach",
            "sky", "cloud", "sunset", "sunrise",
        ),
        "objects": (),
        "places": ("indoor", "outdoor", "room", "street", "building", "house"),
    }
    # Reverse lookup (label -> category) derived from _CATEGORY_LABELS.
    _LABEL_TO_CATEGORY: Dict[str, str] = {
        label: category
        for category, labels in _CATEGORY_LABELS.items()
        for label in labels
    }

    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch32",
        device: str = "cpu",
    ):
        """
        Initialize ObjectDetectorPlugin.

        The CLIP model is not loaded here; ``initialize`` does that (and
        ``analyze`` calls it lazily on first use).

        Args:
            model_name: Model identifier passed to ``from_pretrained``.
            device: Torch device the model and its inputs are moved to.
        """
        super().__init__()
        self.model_name = model_name
        self.device = device
        self.model = None
        self.processor = None
        self.candidate_labels = [
            "person", "people", "man", "woman", "child", "baby",
            "dog", "cat", "bird", "animal",
            "car", "vehicle", "bicycle", "motorcycle",
            "building", "house", "tree", "plant", "flower",
            "food", "plate", "cup", "bottle",
            "computer", "phone", "keyboard", "screen",
            "furniture", "chair", "table", "bed",
            "nature", "landscape", "mountain", "ocean", "beach",
            "sky", "cloud", "sunset", "sunrise",
            "indoor", "outdoor", "room", "street",
        ]

    @property
    def metadata(self) -> PluginMetadata:
        """Return plugin metadata."""
        return PluginMetadata(
            name="object_detector",
            version="0.1.0",
            description="Detects objects using CLIP zero-shot classification",
            author="AI Dev Collective",
            requires=["transformers", "torch"],
            category="detection",
            priority=10,
        )

    def initialize(self) -> None:
        """
        Initialize the plugin and load the CLIP model and processor.

        ``self.model`` and ``self.processor`` are assigned only after both
        have loaded successfully, and ``self._initialized`` is then set to
        True.

        Raises:
            Exception: Any error raised while importing ``transformers`` or
                loading the model/processor is logged and re-raised.
        """
        try:
            from transformers import CLIPModel, CLIPProcessor

            logger.info("Loading CLIP model...")
            model = CLIPModel.from_pretrained(self.model_name)
            processor = CLIPProcessor.from_pretrained(self.model_name)
            # Inference only: switch off train-time behavior such as dropout.
            model.eval()
            model.to(self.device)
        except Exception as e:
            logger.error(f"Failed to initialize ObjectDetectorPlugin: {e}")
            raise

        self.model = model
        self.processor = processor
        self._initialized = True
        logger.info(f"CLIP model loaded successfully on {self.device}")

    def _detect_objects(
        self,
        image: Image.Image,
        labels: List[str],
        threshold: float = 0.3,
    ) -> List[Dict[str, Any]]:
        """
        Detect objects in image using CLIP.

        Scores are a softmax over ``labels``, so they are relative to the
        label set and sum to 1. With N labels, at most ``1 / threshold``
        labels can pass the threshold.

        Args:
            image: PIL Image
            labels: List of candidate labels
            threshold: Minimum (unrounded) probability for a label to be kept

        Returns:
            List of ``{"name", "confidence", "index"}`` dicts, sorted by
            confidence (descending). ``index`` is the label's position in
            ``labels``. Empty when ``labels`` is empty.
        """
        import torch

        if not labels:
            return []

        inputs = self.processor(
            text=labels,
            images=image,
            return_tensors="pt",
            padding=True,
        )
        # The model was moved to self.device; its inputs must live there too.
        inputs = inputs.to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # One image is passed, so row 0 of logits_per_image holds the scores
        # of that image against every label.
        probs = outputs.logits_per_image.softmax(dim=1)[0].tolist()

        detected = [
            {"name": label, "confidence": round(prob, 4), "index": idx}
            for idx, (label, prob) in enumerate(zip(labels, probs))
            if prob >= threshold
        ]
        detected.sort(key=lambda x: x["confidence"], reverse=True)
        return detected

    @staticmethod
    def _to_pil_image(media: Any) -> Image.Image:
        """
        Convert ``media`` to a PIL image.

        Non-array inputs are returned unchanged (expected to already be a
        PIL image). Float arrays whose maximum is at most 1, and boolean
        arrays, are treated as normalized and scaled by 255. All arrays are
        then clipped to [0, 255] before the uint8 cast, so out-of-range
        values saturate instead of wrapping around.

        Raises:
            ValueError: If ``media`` is an empty array.
        """
        if not isinstance(media, np.ndarray):
            return media
        if media.size == 0:
            raise ValueError("Cannot build an image from an empty array")

        array = media
        is_normalized = array.dtype == np.bool_ or (
            np.issubdtype(array.dtype, np.floating) and array.max() <= 1
        )
        if is_normalized:
            array = array.astype(np.float64) * 255
        return Image.fromarray(np.clip(array, 0, 255).astype(np.uint8))

    def analyze(
        self,
        media: Any,
        media_path: Path,
    ) -> Dict[str, Any]:
        """
        Detect objects in the image.

        Loads the model on first use. Errors are not raised; they are
        logged and reported in the returned dict.

        Args:
            media: PIL Image or numpy array
            media_path: Path to image file (not used by this plugin)

        Returns:
            On success: ``objects`` (at most ``MAX_OBJECTS`` entries),
            ``total_detected``, ``categories``, ``candidate_labels_count``
            and ``status == "success"``.
            On failure: ``error`` and ``status == "failed"``.
        """
        try:
            if not self._initialized:
                self.initialize()

            if not self.validate_input(media):
                return {"error": "Invalid input type", "status": "failed"}

            image = self._to_pil_image(media)

            objects = self._detect_objects(
                image,
                self.candidate_labels,
                threshold=self.DEFAULT_THRESHOLD,
            )
            top_objects = objects[: self.MAX_OBJECTS]
            categories = self._categorize_objects(top_objects)

            result = {
                "objects": top_objects,
                "total_detected": len(objects),
                "categories": categories,
                "candidate_labels_count": len(self.candidate_labels),
                "status": "success",
            }

            logger.debug(
                f"Object detection complete: {len(top_objects)} objects found"
            )
            return result

        except Exception as e:
            logger.error(f"Object detection failed: {e}")
            return {
                "error": str(e),
                "status": "failed",
            }

    def _categorize_objects(
        self,
        objects: List[Dict[str, Any]],
    ) -> Dict[str, List[str]]:
        """
        Categorize detected objects.

        Args:
            objects: List of detected objects (each with a ``"name"`` key)

        Returns:
            Mapping of category -> detected label names, in input order.
            Only non-empty categories are included. Labels not assigned to
            any category in ``_CATEGORY_LABELS`` go under ``"objects"``.
        """
        categories: Dict[str, List[str]] = {
            category: [] for category in self._CATEGORY_LABELS
        }
        for obj in objects:
            name = obj["name"]
            categories[self._LABEL_TO_CATEGORY.get(name, "objects")].append(name)

        # Report only categories that actually received a label.
        return {k: v for k, v in categories.items() if v}

    def cleanup(self) -> None:
        """
        Release the CLIP model and processor.

        Also clears ``self._initialized`` so that a later ``analyze`` call
        reloads the model instead of calling the released (``None``)
        processor.
        """
        self.model = None
        self.processor = None
        self._initialized = False
        logger.info("ObjectDetectorPlugin cleanup complete")
|
|
|