"""
Differentiable Vehicle Detection Model using Soft Segmentation and Geometric Learning
This module provides a differentiable alternative to the traditional contour detection approach.
It uses soft attention mechanisms, differentiable color space operations, and learned geometric
primitives to enable end-to-end gradient-based optimization.
Key differences from traditional approach:
1. Soft thresholding instead of hard color masking (see the sketch after the imports)
2. Attention-based "soft contours" instead of discrete contours
3. Differentiable geometric operations using PyTorch
4. Learnable color ranges and geometric parameters
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import math
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
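
# Illustrative sketch of the core idea from the module docstring: a hard mask
# (x > t) has zero gradient almost everywhere, while the sigmoid relaxation
# sigmoid(k * (x - t)) passes useful gradients for any softness k. The function
# name and the toy inputs below are ours, not part of the original pipeline.
def _demo_soft_threshold(k: float = 10.0, t: float = 0.5) -> None:
    x = torch.linspace(0, 1, 5, requires_grad=True)
    soft = torch.sigmoid(k * (x - t))  # soft membership in "x above t"
    soft.sum().backward()
    print("soft mask:", soft.detach())
    print("gradient :", x.grad)  # non-zero, largest near the threshold t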
@dataclass
class DifferentiableDetectionConfig:
"""Configuration for differentiable vehicle detection."""
# Learnable color parameters
learn_color_ranges: bool = True
    color_softness: float = 10.0  # Sharpness of soft thresholding (higher = closer to a hard mask)
# Attention mechanism
attention_hidden_dim: int = 32
num_attention_heads: int = 4
# Geometric estimation
min_blob_size: float = 0.01 # Minimum relative area for valid detection
    position_smoothing: float = 0.1  # Smoothing factor for position estimation (currently unused)
# Training parameters
    temperature: float = 1.0  # Temperature for soft operations (currently unused)
class DifferentiableContourDetection(nn.Module):
"""
A differentiable vehicle detection model that replaces traditional CV operations
with learnable, gradient-friendly alternatives.
"""
    def __init__(self, config: Optional[DifferentiableDetectionConfig] = None):
super().__init__()
self.config = config or DifferentiableDetectionConfig()
        # Learnable color ranges (HSV format), normalized per channel:
        # H by 180 (OpenCV hue range), S and V by 255
        hsv_scale = torch.tensor([180., 255., 255.])
        if self.config.learn_color_ranges:
            # Initialize with reasonable defaults, then make them learnable
            self.green_hsv_lower = nn.Parameter(torch.tensor([50., 100., 100.]) / hsv_scale)
            self.green_hsv_upper = nn.Parameter(torch.tensor([70., 255., 255.]) / hsv_scale)
            self.blue_hsv_lower = nn.Parameter(torch.tensor([80., 80., 80.]) / hsv_scale)
            self.blue_hsv_upper = nn.Parameter(torch.tensor([115., 255., 255.]) / hsv_scale)
        else:
            # Fixed color ranges
            self.register_buffer('green_hsv_lower', torch.tensor([50., 100., 100.]) / hsv_scale)
            self.register_buffer('green_hsv_upper', torch.tensor([70., 255., 255.]) / hsv_scale)
            self.register_buffer('blue_hsv_lower', torch.tensor([80., 80., 80.]) / hsv_scale)
            self.register_buffer('blue_hsv_upper', torch.tensor([115., 255., 255.]) / hsv_scale)
# Attention mechanism for spatial reasoning
self.spatial_attention = SpatialAttentionModule(
hidden_dim=self.config.attention_hidden_dim,
num_heads=self.config.num_attention_heads
)
        # Geometric parameter estimator (placeholder hook, not yet used in forward())
        self.geometry_estimator = GeometricParameterEstimator()
def forward(self, image: torch.Tensor) -> Tuple[torch.Tensor, List[Dict]]:
"""
Forward pass for differentiable vehicle detection.
Args:
image: Input image tensor [B, C, H, W] in BGR format (0-1 range)
Returns:
- attention_maps: Soft segmentation maps [B, 2, H, W] for [green, blue]
- vehicle_states: List of estimated vehicle states
"""
batch_size, channels, height, width = image.shape
# 1. Convert BGR to HSV (differentiable)
hsv_image = self.bgr_to_hsv_differentiable(image)
# 2. Create soft color masks
green_mask = self.soft_color_mask(hsv_image, self.green_hsv_lower, self.green_hsv_upper)
blue_mask = self.soft_color_mask(hsv_image, self.blue_hsv_lower, self.blue_hsv_upper)
# 3. Apply spatial attention to refine masks
color_masks = torch.stack([green_mask, blue_mask], dim=1) # [B, 2, H, W]
attention_maps = self.spatial_attention(color_masks, image)
# 4. Extract vehicle states from attention maps
vehicle_states = self.extract_states_from_attention(attention_maps)
return attention_maps, vehicle_states
def bgr_to_hsv_differentiable(self, bgr: torch.Tensor) -> torch.Tensor:
"""
        Differentiable BGR to HSV conversion (sanity-checked by _demo_hsv_conversion below).
Args:
bgr: Input tensor [B, 3, H, W] in BGR format
Returns:
hsv: Output tensor [B, 3, H, W] in HSV format
"""
# Assume input is already BGR ordered
b, g, r = bgr[:, 0:1], bgr[:, 1:2], bgr[:, 2:3]
max_rgb, _ = torch.max(torch.cat([r, g, b], dim=1), dim=1, keepdim=True)
min_rgb, _ = torch.min(torch.cat([r, g, b], dim=1), dim=1, keepdim=True)
diff = max_rgb - min_rgb
# Value
v = max_rgb
        # Saturation (epsilon avoids a 0/0 NaN that would leak NaN gradients through torch.where)
        s = diff / (max_rgb + 1e-8)
        # Hue (simplified, piecewise differentiable version). The branch masks are made
        # mutually exclusive so ties (e.g. r == g) are not double-counted.
        red_max = (max_rgb == r).float()
        green_max = (max_rgb == g).float() * (1 - red_max)
        blue_max = (max_rgb == b).float() * (1 - red_max) * (1 - green_max)
        h = red_max * (((g - b) / (diff + 1e-8)) % 6)
        h = h + green_max * (((b - r) / (diff + 1e-8)) + 2)
        h = h + blue_max * (((r - g) / (diff + 1e-8)) + 4)
        h = h / 6.0  # Sextants [0, 6) -> [0, 1], where 1.0 corresponds to 360 degrees
return torch.cat([h, s, v], dim=1)
def soft_color_mask(self, hsv: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
"""
Create soft color mask using sigmoid functions instead of hard thresholding.
Args:
hsv: HSV image [B, 3, H, W]
lower: Lower HSV bounds [3]
upper: Upper HSV bounds [3]
Returns:
mask: Soft mask [B, H, W]
"""
h, s, v = hsv[:, 0], hsv[:, 1], hsv[:, 2]
        # Soft thresholding using sigmoid. Note: hue is treated as a linear axis, so
        # ranges that wrap around 0/1 (e.g. red) would need special handling.
        softness = self.config.color_softness
h_mask = torch.sigmoid(softness * (h - lower[0])) * torch.sigmoid(softness * (upper[0] - h))
s_mask = torch.sigmoid(softness * (s - lower[1])) * torch.sigmoid(softness * (upper[1] - s))
v_mask = torch.sigmoid(softness * (v - lower[2])) * torch.sigmoid(softness * (upper[2] - v))
return h_mask * s_mask * v_mask
def extract_states_from_attention(self, attention_maps: torch.Tensor) -> List[Dict]:
"""
Extract vehicle states from soft attention maps.
Args:
attention_maps: Attention maps [B, 2, H, W] for [green, blue]
Returns:
            List of vehicle state dictionaries. Values are detached Python floats
            (extracted via .item()), so call estimate_position_differentiable /
            estimate_heading_differentiable directly when gradients are required.
"""
states = []
batch_size = attention_maps.shape[0]
for b in range(batch_size):
green_map = attention_maps[b, 0] # [H, W]
blue_map = attention_maps[b, 1] # [H, W]
            # Process each color channel
            for color_map, class_name in ((green_map, "ego_vehicle"), (blue_map, "other_vehicle")):
                # Gate on blob size with a hard threshold (used only to reject weak
                # detections; position/heading below use the soft map and stay differentiable)
                threshold = 0.5
                binary_mask = (color_map > threshold).float()
# Skip if blob is too small
blob_size = binary_mask.sum() / (color_map.shape[0] * color_map.shape[1])
if blob_size < self.config.min_blob_size:
continue
# Estimate position using weighted centroid
pos_y, pos_x = self.estimate_position_differentiable(color_map)
# Estimate heading using spatial gradients
heading = self.estimate_heading_differentiable(color_map, pos_x, pos_y)
states.append({
"class": class_name,
"position_x": pos_x.item(),
"position_y": pos_y.item(),
"heading": heading.item(),
"speed": 0.0, # Placeholder
"confidence": blob_size.item()
})
return states
def estimate_position_differentiable(self, attention_map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Estimate position using differentiable weighted centroid.
"""
h, w = attention_map.shape
# Create coordinate grids
y_coords = torch.arange(h, dtype=attention_map.dtype, device=attention_map.device).view(-1, 1)
x_coords = torch.arange(w, dtype=attention_map.dtype, device=attention_map.device).view(1, -1)
# Weighted centroid
total_weight = attention_map.sum() + 1e-8
pos_y = (attention_map * y_coords).sum() / total_weight
pos_x = (attention_map * x_coords).sum() / total_weight
return pos_y, pos_x
def estimate_heading_differentiable(self, attention_map: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor) -> torch.Tensor:
"""
Estimate heading using spatial moment analysis.
"""
h, w = attention_map.shape
# Create coordinate grids relative to centroid
y_coords = torch.arange(h, dtype=attention_map.dtype, device=attention_map.device).view(-1, 1) - pos_y
x_coords = torch.arange(w, dtype=attention_map.dtype, device=attention_map.device).view(1, -1) - pos_x
# Second moments
total_weight = attention_map.sum() + 1e-8
mu_20 = (attention_map * x_coords.pow(2)).sum() / total_weight
mu_02 = (attention_map * y_coords.pow(2)).sum() / total_weight
mu_11 = (attention_map * x_coords * y_coords).sum() / total_weight
        # Principal axis angle; note the axis is only defined modulo 180 degrees,
        # so headings theta and theta + 180 are indistinguishable here
        theta = 0.5 * torch.atan2(2 * mu_11, mu_20 - mu_02)
        heading_degrees = theta * 180.0 / math.pi
        return heading_degrees
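
# Hedged sanity check for the moment-based heading above: for a synthetic blob
# elongated along the image diagonal y = x, theta = 0.5 * atan2(2*mu_11, mu_20 - mu_02)
# should come out near 45 degrees. The helper name and test pattern are illustrative only.
def _demo_heading_from_moments() -> None:
    ys, xs = torch.meshgrid(torch.arange(64.), torch.arange(64.), indexing="ij")
    blob = torch.exp(-((ys - xs) ** 2) / 20.0)  # soft diagonal stripe
    model = DifferentiableContourDetection()
    pos_y, pos_x = model.estimate_position_differentiable(blob)
    heading = model.estimate_heading_differentiable(blob, pos_x, pos_y)
    print(f"estimated heading: {heading.item():.1f} deg")  # expect roughly 45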
class SpatialAttentionModule(nn.Module):
"""Spatial attention module for refining color-based segmentation."""
    def __init__(self, hidden_dim: int = 64, num_heads: int = 4):
        super().__init__()
        self.hidden_dim = hidden_dim
        # num_heads is kept for interface compatibility; full multi-head attention was
        # replaced by the cheaper channel + spatial attention defined below
# Feature extraction
self.feature_conv = nn.Sequential(
            nn.Conv2d(5, hidden_dim, 3, padding=1),  # 2 color masks + 3 image channels (BGR)
nn.ReLU(),
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
nn.ReLU()
)
        # Channel attention (squeeze-and-excitation style), cheaper than full MultiheadAttention
self.channel_attention = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(hidden_dim, hidden_dim // 4, 1),
nn.ReLU(),
nn.Conv2d(hidden_dim // 4, hidden_dim, 1),
nn.Sigmoid()
)
# Spatial attention using depth-wise separable convolutions
self.spatial_attention = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim, 7, padding=3, groups=hidden_dim), # Depth-wise
nn.Conv2d(hidden_dim, 1, 1), # Point-wise
nn.Sigmoid()
)
# Output projection
self.output_conv = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim // 2, 3, padding=1),
nn.ReLU(),
nn.Conv2d(hidden_dim // 2, 2, 3, padding=1),
nn.Sigmoid()
)
def forward(self, color_masks: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
"""
Args:
color_masks: [B, 2, H, W] soft color masks
image: [B, 3, H, W] original image
Returns:
attention_maps: [B, 2, H, W] refined attention maps
"""
# Combine inputs
features = torch.cat([color_masks, image], dim=1) # [B, 5, H, W]
# Extract features
features = self.feature_conv(features) # [B, hidden_dim, H, W]
# Apply channel attention
channel_weights = self.channel_attention(features)
features = features * channel_weights
# Apply spatial attention
spatial_weights = self.spatial_attention(features)
features = features * spatial_weights
# Generate refined attention maps
attention_maps = self.output_conv(features)
return attention_maps
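
# Quick shape check (illustrative, not a unit test): the attention module takes
# [B, 2, H, W] soft masks plus the [B, 3, H, W] image and returns refined maps at
# the same spatial resolution, since every conv uses size-preserving padding.
def _demo_attention_shapes() -> None:
    module = SpatialAttentionModule(hidden_dim=32, num_heads=4)
    masks = torch.rand(1, 2, 48, 48)
    image = torch.rand(1, 3, 48, 48)
    print(module(masks, image).shape)  # torch.Size([1, 2, 48, 48])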
class GeometricParameterEstimator(nn.Module):
"""Network for estimating geometric parameters from attention maps."""
def __init__(self):
super().__init__()
# This could be expanded for more sophisticated geometric estimation
pass
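
# Hedged sanity check for bgr_to_hsv_differentiable (defined in
# DifferentiableContourDetection above): pure red in BGR should map to hue ~0 with
# full saturation and value. Helper name and test pixel are illustrative only.
def _demo_hsv_conversion() -> None:
    model = DifferentiableContourDetection()
    red_bgr = torch.tensor([0., 0., 1.]).view(1, 3, 1, 1)  # B=0, G=0, R=1
    hsv = model.bgr_to_hsv_differentiable(red_bgr)
    print(hsv.view(3))  # expect approximately [0.0, 1.0, 1.0]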
# Training utilities
def contour_detection_loss(attention_maps: torch.Tensor,
ground_truth_masks: torch.Tensor,
vehicle_states: List[Dict],
gt_states: List[Dict]) -> torch.Tensor:
"""
Combined loss function for training the differentiable contour detection model.
Args:
attention_maps: Predicted attention maps [B, 2, H, W]
ground_truth_masks: GT segmentation masks [B, 2, H, W]
vehicle_states: Predicted vehicle states
gt_states: Ground truth vehicle states
Returns:
total_loss: Combined loss value
"""
# Segmentation loss
seg_loss = F.binary_cross_entropy(attention_maps, ground_truth_masks)
# State estimation loss (if GT states available)
state_loss = torch.tensor(0.0, device=attention_maps.device)
if vehicle_states and gt_states:
# This would need more sophisticated matching between predicted and GT states
# For now, placeholder
pass
total_loss = seg_loss + 0.1 * state_loss
return total_loss
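
# Minimal single-step training sketch around contour_detection_loss. The random
# image and all-zero target masks are placeholder data; a real setup would load
# labeled frames instead. Illustrative only.
def _demo_training_step() -> None:
    model = DifferentiableContourDetection()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    image = torch.rand(1, 3, 64, 64)  # stand-in BGR frame in [0, 1]
    gt_masks = torch.zeros(1, 2, 64, 64)  # stand-in GT segmentation
    attention_maps, states = model(image)
    loss = contour_detection_loss(attention_maps, gt_masks, states, gt_states=[])
    optimizer.zero_grad()
    loss.backward()  # gradients reach the learnable HSV ranges and conv weights
    optimizer.step()
    print(f"training loss: {loss.item():.4f}")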
# Compatibility wrapper to match original interface
class DifferentiableContourDetectionModel:
"""Wrapper to provide same interface as original ContourDetectionModel."""
    def __init__(self, config: Optional[DifferentiableDetectionConfig] = None):
self.model = DifferentiableContourDetection(config)
self.model.eval() # Set to eval mode by default
def detect_vehicles(self, image_path: str) -> Tuple[Optional[np.ndarray], List[Dict]]:
"""
Detect vehicles using the differentiable model.
Compatible with original interface.
"""
# Load and preprocess image
img = cv2.imread(image_path)
if img is None:
return None, []
# Convert to tensor
img_tensor = torch.from_numpy(img).float() / 255.0
img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
with torch.no_grad():
attention_maps, vehicle_states = self.model(img_tensor)
# Create visualization
annotated_img = self._create_visualization(img, attention_maps, vehicle_states)
return annotated_img, vehicle_states
def _create_visualization(self, original_img: np.ndarray,
attention_maps: torch.Tensor,
vehicle_states: List[Dict]) -> np.ndarray:
"""Create visualization of detection results."""
        # Blend in float to avoid per-channel uint8 truncation artifacts
        annotated_img = original_img.astype(np.float32)
        # Overlay attention maps (BGR colors: green for ego, blue for other vehicle)
        attention_np = attention_maps[0].detach().cpu().numpy()  # [2, H, W]
        for attention_map, color in zip(attention_np, [(0, 255, 0), (255, 0, 0)]):
            overlay = np.zeros_like(annotated_img)
            overlay[:, :] = color
            # Use attention as a per-pixel alpha (scaled down so the scene stays visible)
            alpha = (attention_map * 0.3).clip(0, 1)[..., None]
            annotated_img = (1 - alpha) * annotated_img + alpha * overlay
        annotated_img = annotated_img.clip(0, 255).astype(np.uint8)
# Draw vehicle states
for state in vehicle_states:
pos_x, pos_y = int(state['position_x']), int(state['position_y'])
heading = state['heading']
# Draw center point
cv2.circle(annotated_img, (pos_x, pos_y), 5, (0, 0, 255), -1)
# Draw heading vector
length = 40
angle_rad = np.deg2rad(heading)
end_x = int(pos_x + length * np.cos(angle_rad))
end_y = int(pos_y + length * np.sin(angle_rad))
cv2.line(annotated_img, (pos_x, pos_y), (end_x, end_y), (255, 0, 0), 2)
# Add label
label = f"H: {heading:.1f}"
cv2.putText(annotated_img, label, (pos_x + 10, pos_y - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 2)
return annotated_img
if __name__ == '__main__':
# Example usage
config = DifferentiableDetectionConfig(
learn_color_ranges=True,
color_softness=10.0
)
model = DifferentiableContourDetectionModel(config)
# Test with an image
input_image_path = '/home/alienware3/Documents/diamond/frames/frame_0.png'
annotated_image, states = model.detect_vehicles(input_image_path)
if annotated_image is not None:
cv2.imwrite('differentiable_detection_output.png', annotated_image)
print(f"Detected {len(states)} vehicles")
for i, state in enumerate(states):
print(f"Vehicle {i+1}: {state}")