| """ | |
| Differentiable Vehicle Detection Model using Soft Segmentation and Geometric Learning | |
| This module provides a differentiable alternative to the traditional contour detection approach. | |
| It uses soft attention mechanisms, differentiable color space operations, and learned geometric | |
| primitives to enable end-to-end gradient-based optimization. | |
| Key differences from traditional approach: | |
| 1. Soft thresholding instead of hard color masking | |
| 2. Attention-based "soft contours" instead of discrete contours | |
| 3. Differentiable geometric operations using PyTorch | |
| 4. Learnable color ranges and geometric parameters | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import cv2 | |
| import math | |
| from typing import List, Dict, Tuple, Optional | |
| from dataclasses import dataclass | |

@dataclass
class DifferentiableDetectionConfig:
    """Configuration for differentiable vehicle detection."""
    # Learnable color parameters
    learn_color_ranges: bool = True
    color_softness: float = 10.0  # Controls sharpness of the soft thresholding

    # Attention mechanism
    attention_hidden_dim: int = 32
    num_attention_heads: int = 4

    # Geometric estimation
    min_blob_size: float = 0.01  # Minimum relative area for a valid detection
    position_smoothing: float = 0.1  # Smoothing factor for position estimation

    # Training parameters
    temperature: float = 1.0  # Temperature for soft operations

class DifferentiableContourDetection(nn.Module):
    """
    A differentiable vehicle detection model that replaces traditional CV operations
    with learnable, gradient-friendly alternatives.
    """

    def __init__(self, config: Optional[DifferentiableDetectionConfig] = None):
        super().__init__()
        self.config = config or DifferentiableDetectionConfig()

        # Color ranges in HSV, normalized per channel to [0, 1]:
        # OpenCV-style H in [0, 180] is divided by 180; S and V in [0, 255] by 255.
        green_lower = torch.tensor([50. / 180., 100. / 255., 100. / 255.])
        green_upper = torch.tensor([70. / 180., 1., 1.])
        blue_lower = torch.tensor([80. / 180., 80. / 255., 80. / 255.])
        blue_upper = torch.tensor([115. / 180., 1., 1.])

        if self.config.learn_color_ranges:
            # Initialize with reasonable defaults, then make them learnable
            self.green_hsv_lower = nn.Parameter(green_lower)
            self.green_hsv_upper = nn.Parameter(green_upper)
            self.blue_hsv_lower = nn.Parameter(blue_lower)
            self.blue_hsv_upper = nn.Parameter(blue_upper)
        else:
            # Fixed color ranges
            self.register_buffer('green_hsv_lower', green_lower)
            self.register_buffer('green_hsv_upper', green_upper)
            self.register_buffer('blue_hsv_lower', blue_lower)
            self.register_buffer('blue_hsv_upper', blue_upper)

        # Attention mechanism for spatial reasoning
        self.spatial_attention = SpatialAttentionModule(
            hidden_dim=self.config.attention_hidden_dim,
            num_heads=self.config.num_attention_heads
        )

        # Geometric parameter estimator (currently a placeholder hook)
        self.geometry_estimator = GeometricParameterEstimator()

    def forward(self, image: torch.Tensor) -> Tuple[torch.Tensor, List[Dict]]:
        """
        Forward pass for differentiable vehicle detection.

        Args:
            image: Input image tensor [B, 3, H, W] in BGR format (0-1 range)

        Returns:
            - attention_maps: Soft segmentation maps [B, 2, H, W] for [green, blue]
            - vehicle_states: List of estimated vehicle states
        """
        # 1. Convert BGR to HSV (differentiable)
        hsv_image = self.bgr_to_hsv_differentiable(image)

        # 2. Create soft color masks
        green_mask = self.soft_color_mask(hsv_image, self.green_hsv_lower, self.green_hsv_upper)
        blue_mask = self.soft_color_mask(hsv_image, self.blue_hsv_lower, self.blue_hsv_upper)

        # 3. Apply spatial attention to refine the masks
        color_masks = torch.stack([green_mask, blue_mask], dim=1)  # [B, 2, H, W]
        attention_maps = self.spatial_attention(color_masks, image)

        # 4. Extract vehicle states from the attention maps
        vehicle_states = self.extract_states_from_attention(attention_maps)

        return attention_maps, vehicle_states

    def bgr_to_hsv_differentiable(self, bgr: torch.Tensor) -> torch.Tensor:
        """
        Differentiable BGR to HSV conversion.

        Args:
            bgr: Input tensor [B, 3, H, W] in BGR format

        Returns:
            hsv: Output tensor [B, 3, H, W] in HSV format, all channels in [0, 1]
        """
        b, g, r = bgr[:, 0:1], bgr[:, 1:2], bgr[:, 2:3]

        max_rgb, _ = torch.max(torch.cat([r, g, b], dim=1), dim=1, keepdim=True)
        min_rgb, _ = torch.min(torch.cat([r, g, b], dim=1), dim=1, keepdim=True)
        diff = max_rgb - min_rgb

        # Value
        v = max_rgb

        # Saturation (epsilon keeps the division, and its gradient, finite)
        s = diff / (max_rgb + 1e-8)

        # Hue (approximate differentiable version). The arg-max masks are made
        # mutually exclusive (R, then G, then B priority) so ties do not
        # double-count; gradients still flow through the channel differences.
        red_max = (max_rgb == r).float()
        green_max = (max_rgb == g).float() * (1 - red_max)
        blue_max = (max_rgb == b).float() * (1 - red_max) * (1 - green_max)

        h = red_max * (((g - b) / (diff + 1e-8)) % 6)
        h = h + green_max * (((b - r) / (diff + 1e-8)) + 2)
        h = h + blue_max * (((r - g) / (diff + 1e-8)) + 4)
        h = h * 60. / 360.  # Each unit is 60 degrees; normalize to [0, 1]

        return torch.cat([h, s, v], dim=1)

    def soft_color_mask(self, hsv: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
        """
        Create a soft color mask using sigmoid functions instead of hard thresholding.

        Args:
            hsv: HSV image [B, 3, H, W]
            lower: Lower HSV bounds [3]
            upper: Upper HSV bounds [3]

        Returns:
            mask: Soft mask [B, H, W]
        """
        h, s, v = hsv[:, 0], hsv[:, 1], hsv[:, 2]

        # Soft thresholding: each pair of sigmoids approximates an interval test
        softness = self.config.color_softness
        h_mask = torch.sigmoid(softness * (h - lower[0])) * torch.sigmoid(softness * (upper[0] - h))
        s_mask = torch.sigmoid(softness * (s - lower[1])) * torch.sigmoid(softness * (upper[1] - s))
        v_mask = torch.sigmoid(softness * (v - lower[2])) * torch.sigmoid(softness * (upper[2] - v))

        return h_mask * s_mask * v_mask

    def extract_states_from_attention(self, attention_maps: torch.Tensor) -> List[Dict]:
        """
        Extract vehicle states from soft attention maps.

        Note: this step uses a hard threshold and `.item()` calls, so it is for
        reporting only; gradients flow through the attention maps themselves.

        Args:
            attention_maps: Attention maps [B, 2, H, W] for [green, blue]

        Returns:
            List of vehicle state dictionaries
        """
        states = []
        batch_size = attention_maps.shape[0]

        for b in range(batch_size):
            green_map = attention_maps[b, 0]  # [H, W]
            blue_map = attention_maps[b, 1]   # [H, W]

            # Process each color channel
            for color_map, class_name in [(green_map, "ego_vehicle"), (blue_map, "other_vehicle")]:
                # Find significant blobs via thresholding
                threshold = 0.5
                binary_mask = (color_map > threshold).float()

                # Skip if the blob is too small
                blob_size = binary_mask.sum() / (color_map.shape[0] * color_map.shape[1])
                if blob_size < self.config.min_blob_size:
                    continue

                # Estimate position using a weighted centroid
                pos_y, pos_x = self.estimate_position_differentiable(color_map)

                # Estimate heading using spatial moments
                heading = self.estimate_heading_differentiable(color_map, pos_x, pos_y)

                states.append({
                    "class": class_name,
                    "position_x": pos_x.item(),
                    "position_y": pos_y.item(),
                    "heading": heading.item(),
                    "speed": 0.0,  # Placeholder
                    "confidence": blob_size.item()
                })

        return states

    def estimate_position_differentiable(self, attention_map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Estimate position as the attention-weighted centroid (differentiable)."""
        h, w = attention_map.shape

        # Coordinate grids
        y_coords = torch.arange(h, dtype=attention_map.dtype, device=attention_map.device).view(-1, 1)
        x_coords = torch.arange(w, dtype=attention_map.dtype, device=attention_map.device).view(1, -1)

        # Weighted centroid
        total_weight = attention_map.sum() + 1e-8
        pos_y = (attention_map * y_coords).sum() / total_weight
        pos_x = (attention_map * x_coords).sum() / total_weight

        return pos_y, pos_x

    def estimate_heading_differentiable(self, attention_map: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor) -> torch.Tensor:
        """Estimate heading (degrees) from the principal axis of the second spatial moments."""
        h, w = attention_map.shape

        # Coordinate grids relative to the centroid
        y_coords = torch.arange(h, dtype=attention_map.dtype, device=attention_map.device).view(-1, 1) - pos_y
        x_coords = torch.arange(w, dtype=attention_map.dtype, device=attention_map.device).view(1, -1) - pos_x

        # Central second moments
        total_weight = attention_map.sum() + 1e-8
        mu_20 = (attention_map * x_coords.pow(2)).sum() / total_weight
        mu_02 = (attention_map * y_coords.pow(2)).sum() / total_weight
        mu_11 = (attention_map * x_coords * y_coords).sum() / total_weight

        # Principal axis angle
        theta = 0.5 * torch.atan2(2 * mu_11, mu_20 - mu_02)
        return theta * 180. / math.pi
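

# A minimal sanity-check sketch (an illustrative addition, not part of the
# model): as `color_softness` grows, the soft mask should approach the hard
# interval test that cv2.inRange would apply. The random HSV tensor below is a
# placeholder, not a real frame.
def _soft_mask_sanity_check(softness: float = 200.0) -> None:
    config = DifferentiableDetectionConfig(learn_color_ranges=False, color_softness=softness)
    model = DifferentiableContourDetection(config)
    hsv = torch.rand(1, 3, 16, 16)  # random HSV values in [0, 1]
    soft = model.soft_color_mask(hsv, model.green_hsv_lower, model.green_hsv_upper)
    # Hard interval test, all three channels in range
    hard = ((hsv > model.green_hsv_lower.view(1, 3, 1, 1)) &
            (hsv < model.green_hsv_upper.view(1, 3, 1, 1))).all(dim=1).float()
    # Away from the band edges, the two masks should agree closely
    print(f"softness={softness}: mean |soft - hard| = {(soft - hard).abs().mean().item():.4f}")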

class SpatialAttentionModule(nn.Module):
    """Spatial attention module for refining color-based segmentation."""

    def __init__(self, hidden_dim: int = 64, num_heads: int = 4):  # num_heads kept for config compatibility; unused here
        super().__init__()
        self.hidden_dim = hidden_dim

        # Feature extraction
        self.feature_conv = nn.Sequential(
            nn.Conv2d(5, hidden_dim, 3, padding=1),  # 2 color masks + 3 image channels (BGR)
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.ReLU()
        )

        # Channel attention (squeeze-and-excitation style), far cheaper than
        # full multi-head attention over all spatial positions
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(hidden_dim, hidden_dim // 4, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim // 4, hidden_dim, 1),
            nn.Sigmoid()
        )

        # Spatial attention using a depth-wise separable convolution
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 7, padding=3, groups=hidden_dim),  # Depth-wise
            nn.Conv2d(hidden_dim, 1, 1),  # Point-wise
            nn.Sigmoid()
        )

        # Output projection
        self.output_conv = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim // 2, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim // 2, 2, 3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, color_masks: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
        """
        Args:
            color_masks: [B, 2, H, W] soft color masks
            image: [B, 3, H, W] original image

        Returns:
            attention_maps: [B, 2, H, W] refined attention maps
        """
        # Combine inputs
        features = torch.cat([color_masks, image], dim=1)  # [B, 5, H, W]

        # Extract features
        features = self.feature_conv(features)  # [B, hidden_dim, H, W]

        # Apply channel attention
        channel_weights = self.channel_attention(features)
        features = features * channel_weights

        # Apply spatial attention
        spatial_weights = self.spatial_attention(features)
        features = features * spatial_weights

        # Generate refined attention maps
        return self.output_conv(features)

class GeometricParameterEstimator(nn.Module):
    """Placeholder network for estimating geometric parameters from attention maps."""

    def __init__(self):
        super().__init__()
        # Could be expanded for more sophisticated geometric estimation

# Training utilities
def contour_detection_loss(attention_maps: torch.Tensor,
                           ground_truth_masks: torch.Tensor,
                           vehicle_states: List[Dict],
                           gt_states: List[Dict]) -> torch.Tensor:
    """
    Combined loss function for training the differentiable contour detection model.

    Args:
        attention_maps: Predicted attention maps [B, 2, H, W]
        ground_truth_masks: GT segmentation masks [B, 2, H, W]
        vehicle_states: Predicted vehicle states
        gt_states: Ground-truth vehicle states

    Returns:
        total_loss: Combined loss value
    """
    # Segmentation loss
    seg_loss = F.binary_cross_entropy(attention_maps, ground_truth_masks)

    # State estimation loss (placeholder: would require matching predicted
    # states to ground-truth states, e.g. by class and nearest position)
    state_loss = torch.tensor(0.0, device=attention_maps.device)

    return seg_loss + 0.1 * state_loss
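

# A minimal training-loop sketch (an illustrative addition, not part of the
# original pipeline). It assumes you can supply ground-truth segmentation
# masks; the batch tensors below are random placeholders standing in for real
# frames. It demonstrates that the learnable HSV bounds and attention weights
# receive gradients end to end.
def _training_sketch(num_steps: int = 10) -> None:
    config = DifferentiableDetectionConfig(learn_color_ranges=True)
    model = DifferentiableContourDetection(config)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    images = torch.rand(2, 3, 64, 64)    # placeholder BGR frames in [0, 1]
    gt_masks = torch.rand(2, 2, 64, 64)  # placeholder soft GT masks

    for step in range(num_steps):
        attention_maps, states = model(images)
        loss = contour_detection_loss(attention_maps, gt_masks, states, gt_states=[])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"step {step}: loss = {loss.item():.4f}")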

# Compatibility wrapper to match the original interface
class DifferentiableContourDetectionModel:
    """Wrapper providing the same interface as the original ContourDetectionModel."""

    def __init__(self, config: Optional[DifferentiableDetectionConfig] = None):
        self.model = DifferentiableContourDetection(config)
        self.model.eval()  # Eval mode by default

    def detect_vehicles(self, image_path: str) -> Tuple[Optional[np.ndarray], List[Dict]]:
        """
        Detect vehicles using the differentiable model.
        Compatible with the original interface.
        """
        # Load and preprocess the image (OpenCV loads BGR, matching the model's input format)
        img = cv2.imread(image_path)
        if img is None:
            return None, []

        # Convert to tensor
        img_tensor = torch.from_numpy(img).float() / 255.0
        img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]

        with torch.no_grad():
            attention_maps, vehicle_states = self.model(img_tensor)

        # Create visualization
        annotated_img = self._create_visualization(img, attention_maps, vehicle_states)

        return annotated_img, vehicle_states

    def _create_visualization(self, original_img: np.ndarray,
                              attention_maps: torch.Tensor,
                              vehicle_states: List[Dict]) -> np.ndarray:
        """Create a visualization of the detection results."""
        annotated_img = original_img.copy()

        # Overlay attention maps (BGR colors: green for ego, blue for other)
        attention_np = attention_maps[0].cpu().numpy()  # [2, H, W]
        for attention_map, color in zip(attention_np, [(0, 255, 0), (255, 0, 0)]):
            # Convert attention to a color overlay
            overlay = np.zeros_like(original_img)
            overlay[:, :] = color

            # Apply attention as alpha (explicit uint8 cast avoids lossy float assignment)
            alpha = (attention_map * 0.3).clip(0, 1)
            for c in range(3):
                annotated_img[:, :, c] = ((1 - alpha) * annotated_img[:, :, c]
                                          + alpha * overlay[:, :, c]).astype(np.uint8)

        # Draw vehicle states
        for state in vehicle_states:
            pos_x, pos_y = int(state['position_x']), int(state['position_y'])
            heading = state['heading']

            # Draw the center point
            cv2.circle(annotated_img, (pos_x, pos_y), 5, (0, 0, 255), -1)

            # Draw the heading vector
            length = 40
            angle_rad = np.deg2rad(heading)
            end_x = int(pos_x + length * np.cos(angle_rad))
            end_y = int(pos_y + length * np.sin(angle_rad))
            cv2.line(annotated_img, (pos_x, pos_y), (end_x, end_y), (255, 0, 0), 2)

            # Add a label
            label = f"H: {heading:.1f}"
            cv2.putText(annotated_img, label, (pos_x + 10, pos_y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 2)

        return annotated_img

if __name__ == '__main__':
    # Example usage
    config = DifferentiableDetectionConfig(
        learn_color_ranges=True,
        color_softness=10.0
    )
    model = DifferentiableContourDetectionModel(config)

    # Test with an image
    input_image_path = '/home/alienware3/Documents/diamond/frames/frame_0.png'
    annotated_image, states = model.detect_vehicles(input_image_path)

    if annotated_image is not None:
        cv2.imwrite('differentiable_detection_output.png', annotated_image)
        print(f"Detected {len(states)} vehicles")
        for i, state in enumerate(states):
            print(f"Vehicle {i+1}: {state}")