File size: 4,455 Bytes
5369733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
Trash Detection Model Wrapper

This module provides an interface for trash detection in images using YOLOv8.
Loads a trained model from Weights/best.pt for real trash detection.
"""

from typing import TypedDict
from PIL import Image
from pathlib import Path
import numpy as np
import os

# Import YOLO from ultralytics
try:
    from ultralytics import YOLO
    YOLO_AVAILABLE = True
except ImportError:
    YOLO_AVAILABLE = False
    print("⚠️  Ultralytics not available. Install with: pip install ultralytics")


class Detection(TypedDict):
    """One object reported by the trash detector.

    Keys:
        bbox:  Box corners as ``[x1, y1, x2, y2]``, measured in pixels.
        label: Trash category name.
        score: Confidence score in the range 0-1.
    """

    bbox: list[float]
    label: str
    score: float


# Process-wide cache for the YOLO model; stays None until load_model()
# successfully loads weights, after which every caller shares that instance.
_model = None

# Anchor file lookups to this module's own directory rather than the
# process's current working directory.
SCRIPT_DIR = Path(__file__).parent.resolve()
DEFAULT_MODEL_PATH = SCRIPT_DIR.joinpath("Weights", "best.pt")


# NOTE: annotations are quoted so that defining this function does not raise
# NameError when the ultralytics import failed and ``YOLO`` is undefined;
# callers then get the intended ImportError from inside the function instead.
def load_model(model_path: "str | Path | None" = None) -> "YOLO":
    """
    Load the YOLOv8 trash detection model, caching it for later calls.

    The model is loaded at most once and stored in the module-level
    ``_model``. After a successful load, later calls return that cached
    instance and ``model_path`` is ignored, even if it names another file.

    Args:
        model_path: Path to the model weights file. ``None`` uses
            ``DEFAULT_MODEL_PATH``. A relative path is resolved against the
            directory containing this script, not the current working
            directory.

    Returns:
        Loaded YOLO model instance.

    Raises:
        ImportError: If ultralytics is not installed.
        FileNotFoundError: If the resolved weights file does not exist.
    """
    global _model

    # Fast path: reuse the model loaded by an earlier call.
    if _model is not None:
        return _model

    if not YOLO_AVAILABLE:
        raise ImportError("Ultralytics not installed. Run: pip install ultralytics")

    if model_path is None:
        model_file = DEFAULT_MODEL_PATH
    else:
        model_file = Path(model_path)
        # Anchor relative paths to the script directory so loading does not
        # depend on where the process was started from.
        if not model_file.is_absolute():
            model_file = SCRIPT_DIR / model_file

    if not model_file.exists():
        raise FileNotFoundError(
            f"Model file not found: {model_file}\n"
            f"Expected location: {DEFAULT_MODEL_PATH}\n"
            f"Current directory: {os.getcwd()}"
        )

    print(f"🔄 Loading YOLO model from {model_file}...")
    _model = YOLO(str(model_file))
    print("✅ Model loaded successfully!")
    print(f"   Classes: {_model.names}")

    return _model


def detect_trash(image: Image.Image, conf_threshold: float = 0.25) -> list[Detection]:
    """
    Run the YOLOv8 trash model on one image and collect its detections.

    Args:
        image: PIL Image to analyze.
        conf_threshold: Minimum confidence (0-1) passed to the model as ``conf``.

    Returns:
        One ``Detection`` per box, each with ``bbox`` ([x1, y1, x2, y2]),
        ``label`` and ``score``. Any exception raised while loading the model,
        running inference or converting results is printed, and an empty list
        is returned instead of propagating the error.
    """
    try:
        yolo = load_model()  # cached after the first successful load
        outputs = yolo(image, conf=conf_threshold, verbose=False)

        # A single image goes in, so only the first result is relevant.
        if not len(outputs):
            return []
        found = outputs[0].boxes
        if found is None or not len(found):
            return []

        corners = found.xyxy.cpu().numpy()  # rows of [x1, y1, x2, y2]
        scores = found.conf.cpu().numpy()
        class_indices = found.cls.cpu().numpy().astype(int)

        return [
            {
                "bbox": corner.tolist(),
                "label": yolo.names[idx],
                "score": float(score),
            }
            for corner, score, idx in zip(corners, scores, class_indices)
        ]

    except Exception as exc:
        print(f"❌ Error during detection: {exc}")
        print("   Falling back to empty detection list")
        return []


def get_model_info():
    """
    Summarise the trash detection model.

    Returns:
        On success, a dict with ``model_type``, ``classes`` (the model's
        ``names`` mapping), ``num_classes`` and ``model_path``. If loading
        or inspecting the model raises an ``Exception``, a dict with
        ``error`` (the exception text), ``model_type`` and ``status`` is
        returned instead of propagating the exception.
    """
    try:
        yolo = load_model()
        info = {
            "model_type": "YOLOv8",
            "classes": yolo.names,
            "num_classes": len(yolo.names),
            # NOTE(review): fixed string; it does not reflect a custom path
            # given to load_model() earlier -- confirm this is intended.
            "model_path": "Weights/best.pt",
        }
    except Exception as exc:
        info = {
            "error": str(exc),
            "model_type": "None",
            "status": "Model not loaded",
        }
    return info