""" Differentiable Vehicle Detection Model using Soft Segmentation and Geometric Learning This module provides a differentiable alternative to the traditional contour detection approach. It uses soft attention mechanisms, differentiable color space operations, and learned geometric primitives to enable end-to-end gradient-based optimization. Key differences from traditional approach: 1. Soft thresholding instead of hard color masking 2. Attention-based "soft contours" instead of discrete contours 3. Differentiable geometric operations using PyTorch 4. Learnable color ranges and geometric parameters """ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import cv2 import math from typing import List, Dict, Tuple, Optional from dataclasses import dataclass @dataclass class DifferentiableDetectionConfig: """Configuration for differentiable vehicle detection.""" # Learnable color parameters learn_color_ranges: bool = True color_softness: float = 10.0 # Controls sharpness of soft thresholding # Attention mechanism attention_hidden_dim: int = 32 num_attention_heads: int = 4 # Geometric estimation min_blob_size: float = 0.01 # Minimum relative area for valid detection position_smoothing: float = 0.1 # Smoothing factor for position estimation # Training parameters temperature: float = 1.0 # Temperature for soft operations class DifferentiableContourDetection(nn.Module): """ A differentiable vehicle detection model that replaces traditional CV operations with learnable, gradient-friendly alternatives. """ def __init__(self, config: DifferentiableDetectionConfig = None): super().__init__() self.config = config or DifferentiableDetectionConfig() # Learnable color ranges (HSV format) if self.config.learn_color_ranges: # Initialize with reasonable defaults, then make them learnable self.green_hsv_lower = nn.Parameter(torch.tensor([50., 100., 100.]) / 180.) # Normalize to [0,1] self.green_hsv_upper = nn.Parameter(torch.tensor([70., 255., 255.]) / 255.) self.blue_hsv_lower = nn.Parameter(torch.tensor([80., 80., 80.]) / 180.) self.blue_hsv_upper = nn.Parameter(torch.tensor([115., 255., 255.]) / 255.) else: # Fixed color ranges self.register_buffer('green_hsv_lower', torch.tensor([50., 100., 100.]) / 180.) self.register_buffer('green_hsv_upper', torch.tensor([70., 255., 255.]) / 255.) self.register_buffer('blue_hsv_lower', torch.tensor([80., 80., 80.]) / 180.) self.register_buffer('blue_hsv_upper', torch.tensor([115., 255., 255.]) / 255.) # Attention mechanism for spatial reasoning self.spatial_attention = SpatialAttentionModule( hidden_dim=self.config.attention_hidden_dim, num_heads=self.config.num_attention_heads ) # Geometric parameter estimator self.geometry_estimator = GeometricParameterEstimator() def forward(self, image: torch.Tensor) -> Tuple[torch.Tensor, List[Dict]]: """ Forward pass for differentiable vehicle detection. Args: image: Input image tensor [B, C, H, W] in BGR format (0-1 range) Returns: - attention_maps: Soft segmentation maps [B, 2, H, W] for [green, blue] - vehicle_states: List of estimated vehicle states """ batch_size, channels, height, width = image.shape # 1. Convert BGR to HSV (differentiable) hsv_image = self.bgr_to_hsv_differentiable(image) # 2. Create soft color masks green_mask = self.soft_color_mask(hsv_image, self.green_hsv_lower, self.green_hsv_upper) blue_mask = self.soft_color_mask(hsv_image, self.blue_hsv_lower, self.blue_hsv_upper) # 3. Apply spatial attention to refine masks color_masks = torch.stack([green_mask, blue_mask], dim=1) # [B, 2, H, W] attention_maps = self.spatial_attention(color_masks, image) # 4. Extract vehicle states from attention maps vehicle_states = self.extract_states_from_attention(attention_maps) return attention_maps, vehicle_states def bgr_to_hsv_differentiable(self, bgr: torch.Tensor) -> torch.Tensor: """ Differentiable BGR to HSV conversion. Args: bgr: Input tensor [B, 3, H, W] in BGR format Returns: hsv: Output tensor [B, 3, H, W] in HSV format """ # Assume input is already BGR ordered b, g, r = bgr[:, 0:1], bgr[:, 1:2], bgr[:, 2:3] max_rgb, _ = torch.max(torch.cat([r, g, b], dim=1), dim=1, keepdim=True) min_rgb, _ = torch.min(torch.cat([r, g, b], dim=1), dim=1, keepdim=True) diff = max_rgb - min_rgb # Value v = max_rgb # Saturation s = torch.where(max_rgb > 0, diff / max_rgb, torch.zeros_like(max_rgb)) # Hue (simplified, approximate differentiable version) h = torch.zeros_like(max_rgb) # Red is max red_max = (max_rgb == r).float() h = h + red_max * (((g - b) / (diff + 1e-8)) % 6) # Green is max green_max = (max_rgb == g).float() h = h + green_max * (((b - r) / (diff + 1e-8)) + 2) # Blue is max blue_max = (max_rgb == b).float() h = h + blue_max * (((r - g) / (diff + 1e-8)) + 4) h = h * 60 / 360 # Normalize to [0, 1] return torch.cat([h, s, v], dim=1) def soft_color_mask(self, hsv: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor: """ Create soft color mask using sigmoid functions instead of hard thresholding. Args: hsv: HSV image [B, 3, H, W] lower: Lower HSV bounds [3] upper: Upper HSV bounds [3] Returns: mask: Soft mask [B, H, W] """ h, s, v = hsv[:, 0], hsv[:, 1], hsv[:, 2] # Soft thresholding using sigmoid softness = self.config.color_softness h_mask = torch.sigmoid(softness * (h - lower[0])) * torch.sigmoid(softness * (upper[0] - h)) s_mask = torch.sigmoid(softness * (s - lower[1])) * torch.sigmoid(softness * (upper[1] - s)) v_mask = torch.sigmoid(softness * (v - lower[2])) * torch.sigmoid(softness * (upper[2] - v)) return h_mask * s_mask * v_mask def extract_states_from_attention(self, attention_maps: torch.Tensor) -> List[Dict]: """ Extract vehicle states from soft attention maps. Args: attention_maps: Attention maps [B, 2, H, W] for [green, blue] Returns: List of vehicle state dictionaries """ states = [] batch_size = attention_maps.shape[0] for b in range(batch_size): green_map = attention_maps[b, 0] # [H, W] blue_map = attention_maps[b, 1] # [H, W] # Process each color channel for color_idx, (color_map, class_name) in enumerate([(green_map, "ego_vehicle"), (blue_map, "other_vehicle")]): # Find significant blobs using thresholding threshold = 0.5 binary_mask = (color_map > threshold).float() # Skip if blob is too small blob_size = binary_mask.sum() / (color_map.shape[0] * color_map.shape[1]) if blob_size < self.config.min_blob_size: continue # Estimate position using weighted centroid pos_y, pos_x = self.estimate_position_differentiable(color_map) # Estimate heading using spatial gradients heading = self.estimate_heading_differentiable(color_map, pos_x, pos_y) states.append({ "class": class_name, "position_x": pos_x.item(), "position_y": pos_y.item(), "heading": heading.item(), "speed": 0.0, # Placeholder "confidence": blob_size.item() }) return states def estimate_position_differentiable(self, attention_map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Estimate position using differentiable weighted centroid. """ h, w = attention_map.shape # Create coordinate grids y_coords = torch.arange(h, dtype=attention_map.dtype, device=attention_map.device).view(-1, 1) x_coords = torch.arange(w, dtype=attention_map.dtype, device=attention_map.device).view(1, -1) # Weighted centroid total_weight = attention_map.sum() + 1e-8 pos_y = (attention_map * y_coords).sum() / total_weight pos_x = (attention_map * x_coords).sum() / total_weight return pos_y, pos_x def estimate_heading_differentiable(self, attention_map: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor) -> torch.Tensor: """ Estimate heading using spatial moment analysis. """ h, w = attention_map.shape # Create coordinate grids relative to centroid y_coords = torch.arange(h, dtype=attention_map.dtype, device=attention_map.device).view(-1, 1) - pos_y x_coords = torch.arange(w, dtype=attention_map.dtype, device=attention_map.device).view(1, -1) - pos_x # Second moments total_weight = attention_map.sum() + 1e-8 mu_20 = (attention_map * x_coords.pow(2)).sum() / total_weight mu_02 = (attention_map * y_coords.pow(2)).sum() / total_weight mu_11 = (attention_map * x_coords * y_coords).sum() / total_weight # Principal axis angle theta = 0.5 * torch.atan2(2 * mu_11, mu_20 - mu_02) heading_degrees = theta * 180 / math.pi return heading_degrees class SpatialAttentionModule(nn.Module): """Spatial attention module for refining color-based segmentation.""" def __init__(self, hidden_dim: int = 64, num_heads: int = 4): super().__init__() self.hidden_dim = hidden_dim # Feature extraction self.feature_conv = nn.Sequential( nn.Conv2d(5, hidden_dim, 3, padding=1), # 2 color masks + 3 RGB nn.ReLU(), nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), nn.ReLU() ) # Replace expensive MultiheadAttention with efficient channel attention self.channel_attention = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Conv2d(hidden_dim, hidden_dim // 4, 1), nn.ReLU(), nn.Conv2d(hidden_dim // 4, hidden_dim, 1), nn.Sigmoid() ) # Spatial attention using depth-wise separable convolutions self.spatial_attention = nn.Sequential( nn.Conv2d(hidden_dim, hidden_dim, 7, padding=3, groups=hidden_dim), # Depth-wise nn.Conv2d(hidden_dim, 1, 1), # Point-wise nn.Sigmoid() ) # Output projection self.output_conv = nn.Sequential( nn.Conv2d(hidden_dim, hidden_dim // 2, 3, padding=1), nn.ReLU(), nn.Conv2d(hidden_dim // 2, 2, 3, padding=1), nn.Sigmoid() ) def forward(self, color_masks: torch.Tensor, image: torch.Tensor) -> torch.Tensor: """ Args: color_masks: [B, 2, H, W] soft color masks image: [B, 3, H, W] original image Returns: attention_maps: [B, 2, H, W] refined attention maps """ # Combine inputs features = torch.cat([color_masks, image], dim=1) # [B, 5, H, W] # Extract features features = self.feature_conv(features) # [B, hidden_dim, H, W] # Apply channel attention channel_weights = self.channel_attention(features) features = features * channel_weights # Apply spatial attention spatial_weights = self.spatial_attention(features) features = features * spatial_weights # Generate refined attention maps attention_maps = self.output_conv(features) return attention_maps class GeometricParameterEstimator(nn.Module): """Network for estimating geometric parameters from attention maps.""" def __init__(self): super().__init__() # This could be expanded for more sophisticated geometric estimation pass # Training utilities def contour_detection_loss(attention_maps: torch.Tensor, ground_truth_masks: torch.Tensor, vehicle_states: List[Dict], gt_states: List[Dict]) -> torch.Tensor: """ Combined loss function for training the differentiable contour detection model. Args: attention_maps: Predicted attention maps [B, 2, H, W] ground_truth_masks: GT segmentation masks [B, 2, H, W] vehicle_states: Predicted vehicle states gt_states: Ground truth vehicle states Returns: total_loss: Combined loss value """ # Segmentation loss seg_loss = F.binary_cross_entropy(attention_maps, ground_truth_masks) # State estimation loss (if GT states available) state_loss = torch.tensor(0.0, device=attention_maps.device) if vehicle_states and gt_states: # This would need more sophisticated matching between predicted and GT states # For now, placeholder pass total_loss = seg_loss + 0.1 * state_loss return total_loss # Compatibility wrapper to match original interface class DifferentiableContourDetectionModel: """Wrapper to provide same interface as original ContourDetectionModel.""" def __init__(self, config: DifferentiableDetectionConfig = None): self.model = DifferentiableContourDetection(config) self.model.eval() # Set to eval mode by default def detect_vehicles(self, image_path: str) -> Tuple[Optional[np.ndarray], List[Dict]]: """ Detect vehicles using the differentiable model. Compatible with original interface. """ # Load and preprocess image img = cv2.imread(image_path) if img is None: return None, [] # Convert to tensor img_tensor = torch.from_numpy(img).float() / 255.0 img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W] with torch.no_grad(): attention_maps, vehicle_states = self.model(img_tensor) # Create visualization annotated_img = self._create_visualization(img, attention_maps, vehicle_states) return annotated_img, vehicle_states def _create_visualization(self, original_img: np.ndarray, attention_maps: torch.Tensor, vehicle_states: List[Dict]) -> np.ndarray: """Create visualization of detection results.""" annotated_img = original_img.copy() # Overlay attention maps attention_np = attention_maps[0].cpu().numpy() # [2, H, W] for i, (attention_map, color) in enumerate(zip(attention_np, [(0, 255, 0), (255, 0, 0)])): # Convert attention to color overlay overlay = np.zeros_like(original_img) overlay[:, :] = color # Apply attention as alpha alpha = (attention_map * 0.3).clip(0, 1) for c in range(3): annotated_img[:, :, c] = (1 - alpha) * annotated_img[:, :, c] + alpha * overlay[:, :, c] # Draw vehicle states for state in vehicle_states: pos_x, pos_y = int(state['position_x']), int(state['position_y']) heading = state['heading'] # Draw center point cv2.circle(annotated_img, (pos_x, pos_y), 5, (0, 0, 255), -1) # Draw heading vector length = 40 angle_rad = np.deg2rad(heading) end_x = int(pos_x + length * np.cos(angle_rad)) end_y = int(pos_y + length * np.sin(angle_rad)) cv2.line(annotated_img, (pos_x, pos_y), (end_x, end_y), (255, 0, 0), 2) # Add label label = f"H: {heading:.1f}" cv2.putText(annotated_img, label, (pos_x + 10, pos_y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 2) return annotated_img if __name__ == '__main__': # Example usage config = DifferentiableDetectionConfig( learn_color_ranges=True, color_softness=10.0 ) model = DifferentiableContourDetectionModel(config) # Test with an image input_image_path = '/home/alienware3/Documents/diamond/frames/frame_0.png' annotated_image, states = model.detect_vehicles(input_image_path) if annotated_image is not None: cv2.imwrite('differentiable_detection_output.png', annotated_image) print(f"Detected {len(states)} vehicles") for i, state in enumerate(states): print(f"Vehicle {i+1}: {state}")