File size: 3,098 Bytes
5369733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
Trash Detection MCP Tool

Wraps the trash detection model for use as an MCP tool.
"""

from typing import Any
from PIL import Image
import base64
from io import BytesIO
import json

from trash_model import detect_trash, Detection


def detect_trash_mcp(image_data: str | dict) -> dict[str, Any]:
    """
    MCP tool wrapper for trash detection.

    Args:
        image_data: Either:
            - Base64 encoded image string (a ``data:image...`` URL prefix is stripped)
            - Dict with 'path' key pointing to image file
            - Dict with 'base64' key containing base64 image

    Returns:
        Dict containing:
            - detections: List of trash objects found, as returned by ``detect_trash``
            - count: Total number of items detected
            - categories: Unique trash categories found, sorted
            - average_confidence: Mean detection score (0.0 when nothing is detected)
            - summary: Human-readable summary
            - image_dimensions: ``{"width": ..., "height": ...}`` of the input image

    Raises:
        ValueError: If ``image_data`` is not in one of the accepted formats.
        Exceptions raised while decoding/opening the image propagate unchanged.
    """
    # Use the image as a context manager so any file handle opened for a
    # 'path' input is released once detection is done; only the dimensions
    # are needed afterwards.
    with _load_image_from_input(image_data) as image:
        detections = detect_trash(image)
        width, height = image.width, image.height

    count = len(detections)
    # Sorted so the category list is deterministic across runs (set iteration
    # order is not), matching the label ordering used for display output.
    categories = sorted({d["label"] for d in detections})
    # Guard against division by zero; keep the type a float in both cases.
    avg_confidence = sum(d["score"] for d in detections) / count if count else 0.0

    summary = (
        f"Detected {count} trash items across {len(categories)} categories. "
        f"Average confidence: {avg_confidence:.1%}"
    )

    return {
        "detections": detections,
        "count": count,
        "categories": categories,
        "average_confidence": avg_confidence,
        "summary": summary,
        "image_dimensions": {"width": width, "height": height},
    }


def _load_image_from_input(image_data: str | dict) -> Image.Image:
    """Load a PIL Image from any of the input formats accepted by the MCP tool.

    Args:
        image_data: A base64 string (optionally prefixed as a ``data:image...``
            URL), or a dict with a ``'path'`` key (image file path) or a
            ``'base64'`` key (base64 payload; a data-URL prefix is also
            accepted). ``'path'`` takes precedence if both keys are present.

    Returns:
        The opened image. ``Image.open`` is lazy, so the underlying file or
        buffer stays open until the caller loads or closes the image.

    Raises:
        ValueError: If ``image_data`` is not a str or a dict with a recognised
            key, if a data URL has no ``,`` separator, or if the base64
            payload is malformed (``binascii.Error`` subclasses ``ValueError``).
    """
    if isinstance(image_data, str):
        return _open_base64_image(image_data)

    if isinstance(image_data, dict):
        if 'path' in image_data:
            return Image.open(image_data['path'])
        if 'base64' in image_data:
            return _open_base64_image(image_data['base64'])

    raise ValueError("Invalid image_data format. Provide base64 string or dict with 'path' or 'base64' key")


def _open_base64_image(encoded: str | bytes) -> Image.Image:
    """Decode a base64 payload (optionally a ``data:image...`` URL) and open it as an image."""
    # Only str payloads can carry a data-URL prefix; bytes go straight to b64decode.
    if isinstance(encoded, str) and encoded.startswith('data:image'):
        # Drop the "data:image/<type>;base64," header, keep only the payload.
        _, separator, encoded = encoded.partition(',')
        if not separator:
            raise ValueError("Malformed data URL: missing ',' before the base64 payload")
    return Image.open(BytesIO(base64.b64decode(encoded)))


def format_detections_for_display(detections: list[Detection]) -> str:
    """Render detection results as a human-readable, per-category report.

    Args:
        detections: Detection records; each is read through its ``"label"``
            and ``"score"`` keys.

    Returns:
        ``"No trash detected in the image."`` when ``detections`` is empty.
        Otherwise a header line with the total count, followed by a blank
        line, then one bullet per label (in sorted label order) giving that
        label's item count and mean score formatted as a percentage.
    """
    if not detections:
        return "No trash detected in the image."

    # Bucket detections by their label, preserving encounter order per bucket.
    grouped = {}
    for detection in detections:
        grouped.setdefault(detection["label"], []).append(detection)

    # Header ends in "\n", so joining with "\n" leaves a blank line after it.
    report = [f"Found {len(detections)} trash items:\n"]
    for label in sorted(grouped):
        members = grouped[label]
        mean_score = sum(member["score"] for member in members) / len(members)
        report.append(f"  • {label}: {len(members)} item(s) (confidence: {mean_score:.1%})")

    return "\n".join(report)