"""
Differentiable Vehicle Detection Model using Soft Segmentation and Geometric Learning

This module provides a differentiable alternative to the traditional contour detection approach.
It uses soft attention mechanisms, differentiable color space operations, and learned geometric
primitives to enable end-to-end gradient-based optimization.

Key differences from traditional approach:
1. Soft thresholding instead of hard color masking
2. Attention-based "soft contours" instead of discrete contours  
3. Differentiable geometric operations using PyTorch
4. Learnable color ranges and geometric parameters
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import math
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass


@dataclass
class DifferentiableDetectionConfig:
    """Configuration for differentiable vehicle detection."""
    # Learnable color parameters
    learn_color_ranges: bool = True
    color_softness: float = 10.0  # Controls sharpness of soft thresholding
    
    # Attention mechanism
    attention_hidden_dim: int = 32
    num_attention_heads: int = 4
    
    # Geometric estimation
    min_blob_size: float = 0.01  # Minimum relative area for valid detection
    position_smoothing: float = 0.1  # Smoothing factor for position estimation (reserved; not yet used)
    
    # Training parameters
    temperature: float = 1.0  # Temperature for soft operations (reserved; not yet used)


class DifferentiableContourDetection(nn.Module):
    """
    A differentiable vehicle detection model that replaces traditional CV operations
    with learnable, gradient-friendly alternatives.
    """
    
    def __init__(self, config: Optional[DifferentiableDetectionConfig] = None):
        super().__init__()
        self.config = config or DifferentiableDetectionConfig()
        
        # Learnable color ranges (OpenCV-style HSV: H in [0, 180], S and V in [0, 255]).
        # Each channel is normalized to [0, 1] so the bounds match the output of
        # bgr_to_hsv_differentiable below. (Scaling H and S/V by the same constant,
        # as before, left the bounds inconsistent with each other.)
        hsv_scale = torch.tensor([180., 255., 255.])
        if self.config.learn_color_ranges:
            # Initialize with reasonable defaults, then make them learnable
            self.green_hsv_lower = nn.Parameter(torch.tensor([50., 100., 100.]) / hsv_scale)
            self.green_hsv_upper = nn.Parameter(torch.tensor([70., 255., 255.]) / hsv_scale)
            self.blue_hsv_lower = nn.Parameter(torch.tensor([80., 80., 80.]) / hsv_scale)
            self.blue_hsv_upper = nn.Parameter(torch.tensor([115., 255., 255.]) / hsv_scale)
        else:
            # Fixed color ranges
            self.register_buffer('green_hsv_lower', torch.tensor([50., 100., 100.]) / hsv_scale)
            self.register_buffer('green_hsv_upper', torch.tensor([70., 255., 255.]) / hsv_scale)
            self.register_buffer('blue_hsv_lower', torch.tensor([80., 80., 80.]) / hsv_scale)
            self.register_buffer('blue_hsv_upper', torch.tensor([115., 255., 255.]) / hsv_scale)
        
        # Attention mechanism for spatial reasoning
        self.spatial_attention = SpatialAttentionModule(
            hidden_dim=self.config.attention_hidden_dim,
            num_heads=self.config.num_attention_heads
        )
        
        # Geometric parameter estimator
        self.geometry_estimator = GeometricParameterEstimator()
        
    def forward(self, image: torch.Tensor) -> Tuple[torch.Tensor, List[Dict]]:
        """
        Forward pass for differentiable vehicle detection.
        
        Args:
            image: Input image tensor [B, C, H, W] in BGR format (0-1 range)
            
        Returns:
            - attention_maps: Soft segmentation maps [B, 2, H, W] for [green, blue]
            - vehicle_states: List of estimated vehicle states
        """
        batch_size, channels, height, width = image.shape
        
        # 1. Convert BGR to HSV (differentiable)
        hsv_image = self.bgr_to_hsv_differentiable(image)
        
        # 2. Create soft color masks
        green_mask = self.soft_color_mask(hsv_image, self.green_hsv_lower, self.green_hsv_upper)
        blue_mask = self.soft_color_mask(hsv_image, self.blue_hsv_lower, self.blue_hsv_upper)
        
        # 3. Apply spatial attention to refine masks
        color_masks = torch.stack([green_mask, blue_mask], dim=1)  # [B, 2, H, W]
        attention_maps = self.spatial_attention(color_masks, image)
        
        # 4. Extract vehicle states from attention maps
        vehicle_states = self.extract_states_from_attention(attention_maps)
        
        return attention_maps, vehicle_states
    
    def bgr_to_hsv_differentiable(self, bgr: torch.Tensor) -> torch.Tensor:
        """
        Differentiable BGR to HSV conversion.
        
        Args:
            bgr: Input tensor [B, 3, H, W] in BGR format
            
        Returns:
            hsv: Output tensor [B, 3, H, W] in HSV format, all channels normalized to [0, 1]
        """
        # Assume input is already BGR ordered
        b, g, r = bgr[:, 0:1], bgr[:, 1:2], bgr[:, 2:3]
        
        max_rgb, _ = torch.max(torch.cat([r, g, b], dim=1), dim=1, keepdim=True)
        min_rgb, _ = torch.min(torch.cat([r, g, b], dim=1), dim=1, keepdim=True)
        
        diff = max_rgb - min_rgb
        
        # Value
        v = max_rgb
        
        # Saturation
        s = torch.where(max_rgb > 0, diff / max_rgb, torch.zeros_like(max_rgb))
        
        # Hue (simplified, approximate differentiable version). The branch
        # masks below are made mutually exclusive so that ties (e.g. gray
        # pixels, where r == g == b) do not sum several branches.
        h = torch.zeros_like(max_rgb)
        
        red_max = (max_rgb == r).float()
        green_max = (max_rgb == g).float() * (1 - red_max)
        blue_max = (max_rgb == b).float() * (1 - red_max) * (1 - green_max)
        
        h = h + red_max * (((g - b) / (diff + 1e-8)) % 6)
        h = h + green_max * (((b - r) / (diff + 1e-8)) + 2)
        h = h + blue_max * (((r - g) / (diff + 1e-8)) + 4)
        
        h = h * 60 / 360  # Sextant [0, 6) -> degrees [0, 360) -> normalized [0, 1)
        
        return torch.cat([h, s, v], dim=1)
    
    def soft_color_mask(self, hsv: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
        """
        Create soft color mask using sigmoid functions instead of hard thresholding.
        
        Args:
            hsv: HSV image [B, 3, H, W]
            lower: Lower HSV bounds [3]
            upper: Upper HSV bounds [3]
            
        Returns:
            mask: Soft mask [B, H, W]
        """
        h, s, v = hsv[:, 0], hsv[:, 1], hsv[:, 2]
        
        # Soft thresholding using sigmoid
        softness = self.config.color_softness
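        
        # Each channel's mask is a soft "box" indicator: the product of two
        # sigmoids is close to 1 between lower and upper and decays smoothly
        # outside. For example, with softness = 10 and bounds (0.3, 0.7), a
        # value of 0.5 yields sigmoid(2) * sigmoid(2) ~= 0.78, while 0.1
        # yields sigmoid(-2) * sigmoid(6) ~= 0.12.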
        
        h_mask = torch.sigmoid(softness * (h - lower[0])) * torch.sigmoid(softness * (upper[0] - h))
        s_mask = torch.sigmoid(softness * (s - lower[1])) * torch.sigmoid(softness * (upper[1] - s))
        v_mask = torch.sigmoid(softness * (v - lower[2])) * torch.sigmoid(softness * (upper[2] - v))
        
        return h_mask * s_mask * v_mask
    
    def extract_states_from_attention(self, attention_maps: torch.Tensor) -> List[Dict]:
        """
        Extract vehicle states from soft attention maps.
        
        Args:
            attention_maps: Attention maps [B, 2, H, W] for [green, blue]
            
        Returns:
            List of vehicle state dictionaries
        """
        states = []
        batch_size = attention_maps.shape[0]
        
        for b in range(batch_size):
            green_map = attention_maps[b, 0]  # [H, W]
            blue_map = attention_maps[b, 1]   # [H, W]
            
            # Process each color channel. The states below are reported as
            # detached Python floats; gradients flow through the attention
            # maps, not through this extraction step.
            for color_map, class_name in [(green_map, "ego_vehicle"), (blue_map, "other_vehicle")]:
                # Gate out insignificant blobs with a hard threshold
                threshold = 0.5
                binary_mask = (color_map > threshold).float()
                
                # Skip if blob is too small
                blob_size = binary_mask.sum() / (color_map.shape[0] * color_map.shape[1])
                if blob_size < self.config.min_blob_size:
                    continue
                
                # Estimate position using weighted centroid
                pos_y, pos_x = self.estimate_position_differentiable(color_map)
                
                # Estimate heading using spatial gradients
                heading = self.estimate_heading_differentiable(color_map, pos_x, pos_y)
                
                states.append({
                    "class": class_name,
                    "position_x": pos_x.item(),
                    "position_y": pos_y.item(), 
                    "heading": heading.item(),
                    "speed": 0.0,  # Placeholder
                    "confidence": blob_size.item()
                })
        
        return states
    
    def estimate_position_differentiable(self, attention_map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Estimate position using differentiable weighted centroid.
        """
        h, w = attention_map.shape
        
        # Create coordinate grids
        y_coords = torch.arange(h, dtype=attention_map.dtype, device=attention_map.device).view(-1, 1)
        x_coords = torch.arange(w, dtype=attention_map.dtype, device=attention_map.device).view(1, -1)
        
        # Weighted centroid (a soft-argmax): differentiable in the attention
        # weights, unlike a hard argmax over the map
        total_weight = attention_map.sum() + 1e-8
        pos_y = (attention_map * y_coords).sum() / total_weight
        pos_x = (attention_map * x_coords).sum() / total_weight
        
        return pos_y, pos_x
    
    def estimate_heading_differentiable(self, attention_map: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor) -> torch.Tensor:
        """
        Estimate heading using spatial moment analysis.
        """
        h, w = attention_map.shape
        
        # Create coordinate grids relative to centroid
        y_coords = torch.arange(h, dtype=attention_map.dtype, device=attention_map.device).view(-1, 1) - pos_y
        x_coords = torch.arange(w, dtype=attention_map.dtype, device=attention_map.device).view(1, -1) - pos_x
        
        # Second moments
        total_weight = attention_map.sum() + 1e-8
        mu_20 = (attention_map * x_coords.pow(2)).sum() / total_weight
        mu_02 = (attention_map * y_coords.pow(2)).sum() / total_weight  
        mu_11 = (attention_map * x_coords * y_coords).sum() / total_weight
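        # Example: a wide, short blob has mu_20 > mu_02 and mu_11 ~ 0, giving
        # theta ~ 0 (major axis along x); a diagonal blob has mu_11 != 0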
        
        # Principal axis angle of the blob. Second moments determine the
        # orientation only modulo 180 degrees, and the angle is expressed in
        # image coordinates (y axis pointing down).
        theta = 0.5 * torch.atan2(2 * mu_11, mu_20 - mu_02)
        heading_degrees = theta * 180 / math.pi
        
        return heading_degrees
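

# A minimal, illustrative sanity check (not part of the model or the original
# pipeline): it compares the differentiable BGR->HSV conversion against OpenCV
# on a random image. The comparison is an assumption for quick debugging, not
# a formal test; hue may differ near channel ties where the branch masks switch.
def _sanity_check_hsv_conversion() -> None:
    model = DifferentiableContourDetection()
    bgr_np = np.random.rand(32, 32, 3).astype(np.float32)
    bgr_t = torch.from_numpy(bgr_np).permute(2, 0, 1).unsqueeze(0)

    hsv_pred = model.bgr_to_hsv_differentiable(bgr_t)[0].permute(1, 2, 0).numpy()

    # OpenCV on float32 input returns H in [0, 360) and S, V in [0, 1]
    hsv_cv = cv2.cvtColor(bgr_np, cv2.COLOR_BGR2HSV)
    hsv_cv[..., 0] /= 360.0  # normalize hue to [0, 1] to match the model

    print("max abs diff per channel:", np.abs(hsv_pred - hsv_cv).max(axis=(0, 1)))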


class SpatialAttentionModule(nn.Module):
    """Spatial attention module for refining color-based segmentation."""
    
    def __init__(self, hidden_dim: int = 64, num_heads: int = 4):
        super().__init__()
        self.hidden_dim = hidden_dim
        # num_heads is kept for configuration compatibility; it is unused now
        # that multi-head attention has been replaced by the cheaper channel
        # and spatial attention below.
        
        # Feature extraction
        self.feature_conv = nn.Sequential(
            nn.Conv2d(5, hidden_dim, 3, padding=1),  # 2 color masks + 3 image channels (BGR)
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.ReLU()
        )
        
        # Replace expensive MultiheadAttention with efficient channel attention
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(hidden_dim, hidden_dim // 4, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim // 4, hidden_dim, 1),
            nn.Sigmoid()
        )
        
        # Spatial attention using depth-wise separable convolutions
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 7, padding=3, groups=hidden_dim),  # Depth-wise
            nn.Conv2d(hidden_dim, 1, 1),  # Point-wise
            nn.Sigmoid()
        )
        
        # Output projection
        self.output_conv = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim // 2, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim // 2, 2, 3, padding=1),
            nn.Sigmoid()
        )
    
    def forward(self, color_masks: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
        """
        Args:
            color_masks: [B, 2, H, W] soft color masks
            image: [B, 3, H, W] original image
            
        Returns:
            attention_maps: [B, 2, H, W] refined attention maps
        """
        # Combine inputs
        features = torch.cat([color_masks, image], dim=1)  # [B, 5, H, W]
        
        # Extract features
        features = self.feature_conv(features)  # [B, hidden_dim, H, W]
        
        # Apply channel attention
        channel_weights = self.channel_attention(features)
        features = features * channel_weights
        
        # Apply spatial attention
        spatial_weights = self.spatial_attention(features)
        features = features * spatial_weights
        
        # Generate refined attention maps
        attention_maps = self.output_conv(features)
        
        return attention_maps
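

# Illustrative shape check (assumed sizes, not part of the pipeline): the
# module preserves spatial resolution and returns one refined map per color.
# Design note: the channel branch is a squeeze-and-excitation-style bottleneck
# and the depth-wise 7x7 gate is a lightweight spatial attention, both far
# cheaper than multi-head attention over H*W tokens.
def _check_attention_shapes() -> None:
    module = SpatialAttentionModule(hidden_dim=32)
    maps = module(torch.rand(1, 2, 48, 48), torch.rand(1, 3, 48, 48))
    assert maps.shape == (1, 2, 48, 48)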


class GeometricParameterEstimator(nn.Module):
    """Network for estimating geometric parameters from attention maps."""
    
    def __init__(self):
        super().__init__()
        # Stub: instantiated by DifferentiableContourDetection but not yet
        # called in its forward pass; could be expanded for more sophisticated
        # geometric estimation.


# Training utilities

def contour_detection_loss(attention_maps: torch.Tensor, 
                          ground_truth_masks: torch.Tensor,
                          vehicle_states: List[Dict],
                          gt_states: List[Dict]) -> torch.Tensor:
    """
    Combined loss function for training the differentiable contour detection model.
    
    Args:
        attention_maps: Predicted attention maps [B, 2, H, W]
        ground_truth_masks: GT segmentation masks [B, 2, H, W] 
        vehicle_states: Predicted vehicle states
        gt_states: Ground truth vehicle states
        
    Returns:
        total_loss: Combined loss value
    """
    # Segmentation loss
    seg_loss = F.binary_cross_entropy(attention_maps, ground_truth_masks)
    
    # State estimation loss (if GT states available)
    state_loss = torch.tensor(0.0, device=attention_maps.device)
    if vehicle_states and gt_states:
        # This would need more sophisticated matching between predicted and GT states
        # For now, placeholder
        pass
    
    total_loss = seg_loss + 0.1 * state_loss
    
    return total_loss
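

# A minimal training-step sketch (assumed setup, not part of the original
# pipeline): the random images and ground-truth masks below are hypothetical
# placeholders, and only the segmentation term of the loss is exercised.
def _example_training_step() -> None:
    config = DifferentiableDetectionConfig(learn_color_ranges=True)
    model = DifferentiableContourDetection(config)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Hypothetical batch: two 64x64 BGR images with binary GT masks per color
    images = torch.rand(2, 3, 64, 64)
    gt_masks = (torch.rand(2, 2, 64, 64) > 0.5).float()

    attention_maps, states = model(images)
    loss = contour_detection_loss(attention_maps, gt_masks, states, gt_states=[])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # updates both the CNN weights and the learned HSV bounds
    print(f"loss: {loss.item():.4f}")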


# Compatibility wrapper to match original interface
class DifferentiableContourDetectionModel:
    """Wrapper to provide same interface as original ContourDetectionModel."""
    
    def __init__(self, config: Optional[DifferentiableDetectionConfig] = None):
        self.model = DifferentiableContourDetection(config)
        self.model.eval()  # Set to eval mode by default
    
    def detect_vehicles(self, image_path: str) -> Tuple[Optional[np.ndarray], List[Dict]]:
        """
        Detect vehicles using the differentiable model.
        Compatible with original interface.
        """
        # Load and preprocess image
        img = cv2.imread(image_path)
        if img is None:
            return None, []
        
        # Convert to tensor
        img_tensor = torch.from_numpy(img).float() / 255.0
        img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]
        
        with torch.no_grad():
            attention_maps, vehicle_states = self.model(img_tensor)
        
        # Create visualization
        annotated_img = self._create_visualization(img, attention_maps, vehicle_states)
        
        return annotated_img, vehicle_states
    
    def _create_visualization(self, original_img: np.ndarray, 
                            attention_maps: torch.Tensor, 
                            vehicle_states: List[Dict]) -> np.ndarray:
        """Create visualization of detection results."""
        annotated_img = original_img.copy()
        
        # Overlay attention maps
        attention_np = attention_maps[0].cpu().numpy()  # [2, H, W]
        
        for attention_map, color in zip(attention_np, [(0, 255, 0), (255, 0, 0)]):  # BGR: green, blue
            # Convert attention to a solid color overlay
            overlay = np.zeros_like(original_img)
            overlay[:, :] = color
            
            # Alpha-blend the overlay into the image, using attention as alpha
            alpha = (attention_map * 0.3).clip(0, 1)
            for c in range(3):
                annotated_img[:, :, c] = ((1 - alpha) * annotated_img[:, :, c]
                                          + alpha * overlay[:, :, c]).astype(np.uint8)
        
        # Draw vehicle states
        for state in vehicle_states:
            pos_x, pos_y = int(state['position_x']), int(state['position_y'])
            heading = state['heading']
            
            # Draw center point
            cv2.circle(annotated_img, (pos_x, pos_y), 5, (0, 0, 255), -1)
            
            # Draw heading vector
            length = 40
            angle_rad = np.deg2rad(heading)
            end_x = int(pos_x + length * np.cos(angle_rad))
            end_y = int(pos_y + length * np.sin(angle_rad))
            cv2.line(annotated_img, (pos_x, pos_y), (end_x, end_y), (255, 0, 0), 2)
            
            # Add label
            label = f"H: {heading:.1f}"
            cv2.putText(annotated_img, label, (pos_x + 10, pos_y - 10), 
                       cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 2)
        
        return annotated_img


if __name__ == '__main__':
    # Example usage
    config = DifferentiableDetectionConfig(
        learn_color_ranges=True,
        color_softness=10.0
    )
    
    model = DifferentiableContourDetectionModel(config)
    
    # Test with an image
    input_image_path = '/home/alienware3/Documents/diamond/frames/frame_0.png'
    annotated_image, states = model.detect_vehicles(input_image_path)
    
    if annotated_image is not None:
        cv2.imwrite('differentiable_detection_output.png', annotated_image)
        print(f"Detected {len(states)} vehicles")
        for i, state in enumerate(states):
            print(f"Vehicle {i+1}: {state}")