Spaces:
Sleeping
Sleeping
| """ | |
| Trash Detection MCP Tool | |
| Wraps the trash detection model for use as an MCP tool. | |
| """ | |
| from typing import Any | |
| from PIL import Image | |
| import base64 | |
| from io import BytesIO | |
| import json | |
| from trash_model import detect_trash, Detection | |
def detect_trash_mcp(image_data: str | dict) -> dict[str, Any]:
    """
    MCP tool wrapper for trash detection.

    Args:
        image_data: Either:
            - Base64 encoded image string
            - Dict with 'path' key pointing to image file
            - Dict with 'base64' key containing base64 image

    Returns:
        Dict containing:
            - detections: List of trash objects found
            - count: Total number of items detected
            - categories: Unique trash categories found, in the order they
              first appear in the detections
            - average_confidence: Mean of the detections' scores
              (0 when nothing was detected)
            - summary: Human-readable summary
            - image_dimensions: Dict with the image 'width' and 'height'

    Raises:
        ValueError: Propagated from _load_image_from_input when image_data
            is not in one of the supported formats.
    """
    # Parse input
    image = _load_image_from_input(image_data)

    # Run detection
    detections = detect_trash(image)

    # Analyze results.
    # dict.fromkeys de-duplicates while keeping first-seen order; a set would
    # make the order of string labels vary between interpreter runs.
    categories = list(dict.fromkeys(d["label"] for d in detections))
    avg_confidence = (
        sum(d["score"] for d in detections) / len(detections) if detections else 0
    )

    summary = (
        f"Detected {len(detections)} trash items across {len(categories)} categories. "
        f"Average confidence: {avg_confidence:.1%}"
    )

    return {
        "detections": detections,
        "count": len(detections),
        "categories": categories,
        "average_confidence": avg_confidence,
        "summary": summary,
        "image_dimensions": {"width": image.width, "height": image.height},
    }
| def _load_image_from_input(image_data: str | dict) -> Image.Image: | |
| """Load PIL Image from various input formats.""" | |
| if isinstance(image_data, str): | |
| # Assume base64 encoded | |
| if image_data.startswith('data:image'): | |
| # Remove data URL prefix | |
| image_data = image_data.split(',', 1)[1] | |
| image_bytes = base64.b64decode(image_data) | |
| return Image.open(BytesIO(image_bytes)) | |
| elif isinstance(image_data, dict): | |
| if 'path' in image_data: | |
| return Image.open(image_data['path']) | |
| elif 'base64' in image_data: | |
| image_bytes = base64.b64decode(image_data['base64']) | |
| return Image.open(BytesIO(image_bytes)) | |
| raise ValueError("Invalid image_data format. Provide base64 string or dict with 'path' or 'base64' key") | |
def format_detections_for_display(detections: list[Detection]) -> str:
    """Render detection results as a short plain-text report.

    Args:
        detections: Detections to summarise; each is read through its
            "label" and "score" keys.

    Returns:
        A fixed message when there are no detections; otherwise a header
        line with the total count followed by one line per label, sorted by
        label, giving the item count and mean score for that label.
    """
    if not detections:
        return "No trash detected in the image."

    report = [f"Found {len(detections)} trash items:\n"]

    # Bucket the detections under their label.
    grouped = {}
    for detection in detections:
        grouped.setdefault(detection["label"], []).append(detection)

    # One report line per label, in sorted label order.
    for label in sorted(grouped):
        members = grouped[label]
        mean_score = sum(member["score"] for member in members) / len(members)
        report.append(f" • {label}: {len(members)} item(s) (confidence: {mean_score:.1%})")

    return "\n".join(report)