File size: 7,143 Bytes
acdcf49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
"""
Vehicle Detection Configuration Module
=======================================

Manages configuration settings for vehicle detection, tracking, and speed estimation.

Authors:
    - Abhay Gupta (0205CC221005)
    - Aditi Lakhera (0205CC221011)
    - Balraj Patel (0205CC221049)
    - Bhumika Patel (0205CC221050)
"""

import os
from dataclasses import dataclass, field
from typing import List, Tuple, Optional
import logging

logger = logging.getLogger(__name__)


@dataclass
class VehicleDetectionConfig:
    """
    Configuration class for the vehicle detection and speed estimation system.

    Encapsulates all parameters needed by the detection pipeline: video
    paths, model settings, detection-zone geometry, perspective
    transformation, display/annotation options, and speed-estimation
    settings.  Validation and model-path resolution run automatically
    after construction (see ``__post_init__``).

    Raises:
        ValueError: On construction, if any parameter fails validation.
    """

    # Video Configuration
    input_video: str = "./data/vehicles.mp4"
    output_video: str = "./data/vehicles_output.mp4"

    # Model Configuration
    model_name: str = "yolov8n"
    model_path: Optional[str] = None  # resolved in __post_init__ when None
    confidence_threshold: float = 0.3
    iou_threshold: float = 0.7

    # Detection Zone Configuration
    line_y: int = 480           # y position (pixels) of the counting line
    line_offset: int = 55       # tolerance band (pixels) around the line
    crossing_threshold: int = 1

    # Perspective Transformation Configuration
    # Source points define the region in the original video frame,
    # ordered top-left, top-right, bottom-right, bottom-left.
    source_points: List[List[int]] = field(default_factory=lambda: [
        [450, 300],   # Top-left
        [860, 300],   # Top-right
        [1900, 720],  # Bottom-right
        [-660, 720]   # Bottom-left
    ])

    # Target points define the transformed top-down view dimensions (in meters)
    target_width_meters: float = 25.0
    target_height_meters: float = 100.0

    # Display Configuration (disabled by default for headless environments like HF Spaces)
    window_name: str = "Vehicle Speed Estimation - Traffic Analysis"
    display_enabled: bool = False

    # Annotation Configuration
    enable_boxes: bool = True
    enable_labels: bool = True
    enable_traces: bool = True
    enable_line_zones: bool = True
    trace_length: int = 20

    # Speed Estimation Configuration
    speed_history_seconds: int = 1
    speed_unit: str = "km/h"  # Options: "km/h", "mph", "m/s"

    def __post_init__(self) -> None:
        """Validate configuration and resolve the model path after initialization."""
        self._validate_config()
        self._setup_model_path()

    def _validate_config(self) -> None:
        """
        Validate configuration parameters.

        Checks video paths, threshold ranges, detection-zone geometry,
        perspective-transformation points, and the speed unit.

        Raises:
            ValueError: If any configuration parameter is invalid.
        """
        # Validate video paths
        if not self.input_video:
            raise ValueError("Input video path cannot be empty")

        # Validate model configuration
        if not 0.0 <= self.confidence_threshold <= 1.0:
            raise ValueError(f"Confidence threshold must be between 0 and 1, got {self.confidence_threshold}")

        if not 0.0 <= self.iou_threshold <= 1.0:
            raise ValueError(f"IOU threshold must be between 0 and 1, got {self.iou_threshold}")

        # Validate detection zone.
        # NOTE: zero is accepted by both checks, so the messages say
        # "non-negative" (the previous "positive" wording was inaccurate).
        if self.line_y < 0:
            raise ValueError(f"Line Y position must be non-negative, got {self.line_y}")

        if self.line_offset < 0:
            raise ValueError(f"Line offset must be non-negative, got {self.line_offset}")

        # Validate perspective transformation: exactly 4 points of 2 coordinates each
        if len(self.source_points) != 4:
            raise ValueError(f"Source points must contain exactly 4 points, got {len(self.source_points)}")

        for i, point in enumerate(self.source_points):
            if len(point) != 2:
                raise ValueError(f"Source point {i} must have 2 coordinates, got {len(point)}")

        if self.target_width_meters <= 0 or self.target_height_meters <= 0:
            raise ValueError("Target dimensions must be positive")

        # Validate speed configuration against the supported units
        if self.speed_unit not in ["km/h", "mph", "m/s"]:
            raise ValueError(f"Invalid speed unit: {self.speed_unit}. Must be 'km/h', 'mph', or 'm/s'")

        logger.info("Configuration validation successful")

    def _setup_model_path(self) -> None:
        """
        Set up the model path based on the model name.

        If ``model_path`` was not supplied, search the local models
        directory first; fall back to the bare model name so ultralytics
        can download it from the hub.
        """
        if self.model_path is None:
            # Try to find model in models directory
            model_dir = "./models"
            potential_paths = [
                f"{model_dir}/{self.model_name}.pt",
                f"{model_dir}/VisDrone_YOLO_x2.pt",  # Custom trained model
                self.model_name  # Let ultralytics download from hub
            ]

            for path in potential_paths:
                if os.path.exists(path):
                    self.model_path = path
                    logger.info(f"Using model from: {path}")
                    return

            # Use model name directly (will be downloaded by ultralytics)
            self.model_path = self.model_name
            logger.info(f"Model will be downloaded: {self.model_name}")

    @property
    def target_points(self) -> List[List[float]]:
        """
        Generate target points for perspective transformation.

        Returns:
            List of 4 points (top-left, top-right, bottom-right,
            bottom-left) defining the target perspective in meters.
        """
        w, h = self.target_width_meters, self.target_height_meters
        return [
            [0, 0],      # Top-left
            [w, 0],      # Top-right
            [w, h],      # Bottom-right
            [0, h]       # Bottom-left
        ]

    def get_speed_conversion_factor(self) -> float:
        """
        Get the conversion factor for the configured speed unit.

        Returns:
            Multiplicative factor converting m/s to the desired unit.
            Safe to index directly: speed_unit is validated in __post_init__.
        """
        conversions = {
            "km/h": 3.6,
            "mph": 2.23694,
            "m/s": 1.0
        }
        return conversions[self.speed_unit]

    def to_dict(self) -> dict:
        """
        Convert the key configuration values to a dictionary.

        Returns:
            Dictionary representation of the configuration (a summary of
            the most commonly needed fields, not every attribute).
        """
        return {
            "input_video": self.input_video,
            "output_video": self.output_video,
            "model_name": self.model_name,
            "model_path": self.model_path,
            "confidence_threshold": self.confidence_threshold,
            # iou_threshold was previously omitted; included for parity
            # with confidence_threshold (both are validated model params).
            "iou_threshold": self.iou_threshold,
            "line_y": self.line_y,
            "speed_unit": self.speed_unit,
        }

    def __repr__(self) -> str:
        """Concise string representation of the configuration."""
        return f"VehicleDetectionConfig(model={self.model_name}, input={self.input_video})"


# Default configuration instance for backward compatibility.
# NOTE: constructing this at import time runs validation and model-path
# resolution (filesystem checks in _setup_model_path) as a side effect.
DEFAULT_CONFIG = VehicleDetectionConfig()

# Export commonly used configuration values as module-level constants so
# legacy callers can import them directly instead of going through the
# dataclass instance.
IN_VIDEO_PATH = DEFAULT_CONFIG.input_video
OUT_VIDEO_PATH = DEFAULT_CONFIG.output_video
YOLO_MODEL_PATH = DEFAULT_CONFIG.model_path  # resolved path, never None after __post_init__
LINE_Y = DEFAULT_CONFIG.line_y
SOURCE_POINTS = DEFAULT_CONFIG.source_points
TARGET_POINTS = DEFAULT_CONFIG.target_points  # computed property, snapshotted here
WINDOW_NAME = DEFAULT_CONFIG.window_name