|
|
""" |
|
|
Trash Detection MCP Tool |
|
|
|
|
|
Wraps the trash detection model for use as an MCP tool. |
|
|
""" |
|
|
|
|
|
from typing import Any |
|
|
from PIL import Image |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
import json |
|
|
|
|
|
from trash_model import detect_trash, Detection |
|
|
|
|
|
|
|
|
def detect_trash_mcp(image_data: str | dict) -> dict[str, Any]:
    """
    MCP tool wrapper for trash detection.

    Loads the image described by ``image_data``, runs ``detect_trash`` on it
    and packages the detections together with a few aggregate statistics.

    Args:
        image_data: Either:
            - Base64 encoded image string
            - Dict with 'path' key pointing to image file
            - Dict with 'base64' key containing base64 image

    Returns:
        Dict containing:
            - detections: The detections exactly as returned by
              ``detect_trash``
            - count: Total number of items detected
            - categories: Unique ``"label"`` values, in the order they first
              appear in ``detections``
            - average_confidence: Mean of the ``"score"`` values as a float
              (0.0 when nothing was detected)
            - summary: Human-readable summary
            - image_dimensions: Dict with the image 'width' and 'height'

    Raises:
        ValueError: If ``image_data`` is in none of the supported formats
            (raised by ``_load_image_from_input``).
    """
    image = _load_image_from_input(image_data)
    detections = detect_trash(image)
    count = len(detections)

    # dict.fromkeys de-duplicates while keeping first-seen order. The earlier
    # list(set(...)) produced an order that could change from run to run
    # because string hashing is randomized per process.
    categories = list(dict.fromkeys(d["label"] for d in detections))

    # Guard the division for the empty case; 0.0 keeps the value a float in
    # both branches.
    avg_confidence = (
        sum(d["score"] for d in detections) / count if detections else 0.0
    )

    summary = (
        f"Detected {count} trash items across {len(categories)} categories. "
        f"Average confidence: {avg_confidence:.1%}"
    )

    return {
        "detections": detections,
        "count": count,
        "categories": categories,
        "average_confidence": avg_confidence,
        "summary": summary,
        "image_dimensions": {"width": image.width, "height": image.height},
    }
|
|
|
|
|
|
|
|
def _load_image_from_input(image_data: str | dict) -> Image.Image: |
|
|
"""Load PIL Image from various input formats.""" |
|
|
if isinstance(image_data, str): |
|
|
|
|
|
if image_data.startswith('data:image'): |
|
|
|
|
|
image_data = image_data.split(',', 1)[1] |
|
|
image_bytes = base64.b64decode(image_data) |
|
|
return Image.open(BytesIO(image_bytes)) |
|
|
|
|
|
elif isinstance(image_data, dict): |
|
|
if 'path' in image_data: |
|
|
return Image.open(image_data['path']) |
|
|
elif 'base64' in image_data: |
|
|
image_bytes = base64.b64decode(image_data['base64']) |
|
|
return Image.open(BytesIO(image_bytes)) |
|
|
|
|
|
raise ValueError("Invalid image_data format. Provide base64 string or dict with 'path' or 'base64' key") |
|
|
|
|
|
|
|
|
def format_detections_for_display(detections: list[Detection]) -> str:
    """Render detection results as a short, per-category text report.

    Args:
        detections: Detection records; each is indexed by ``"label"`` (the
            category) and ``"score"`` (the confidence).

    Returns:
        A fixed "nothing found" message when ``detections`` is empty.
        Otherwise a header line with the total count, a blank line, and one
        bullet per category (sorted by category) giving the item count and
        the mean score formatted as a percentage.
    """
    if not detections:
        return "No trash detected in the image."

    # Bucket the records by label, preserving their original order inside
    # each bucket.
    grouped = {}
    for detection in detections:
        grouped.setdefault(detection["label"], []).append(detection)

    report = [f"Found {len(detections)} trash items:\n"]
    for label in sorted(grouped):
        members = grouped[label]
        mean_score = sum(member["score"] for member in members) / len(members)
        report.append(f" • {label}: {len(members)} item(s) (confidence: {mean_score:.1%})")

    return "\n".join(report)
|
|
|