"""
Single Detection Feature Extractor

Extract high-dimensional visual features for individual detection boxes.
This provides the unique visual representation for each specific detection.
"""

import json
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F


@dataclass
class SingleDetectionFeature:
    """
    Represents the visual features for a single, specific detection box.
    """

    # The unique visual signature of this detection: one column of x_cat,
    # shape [feature_dim] (e.g., [56]).
    visual_features: torch.Tensor

    # Detection information
    bbox: torch.Tensor  # [x1, y1, x2, y2, score, class_id]
    center_point: Tuple[float, float]  # (center_x, center_y) in image pixels

    # Feature mapping info
    feature_position: int  # Index along the last (HW) axis of x_cat
    feature_coordinates: Tuple[int, int]  # (x, y) cell in the grid used for mapping

    # Metadata
    confidence: float
    class_id: int
    image_idx: int  # Index of the image within the batch

    # Feature quality metrics
    feature_norm: float  # L2 norm of the feature vector
    feature_activation: float  # Mean absolute value of the feature vector


class SingleDetectionFeatureExtractor:
    """
    Extracts unique visual features for individual detection boxes from YOLO output.

    Each detection's box center is mapped to one spatial position of ``x_cat``
    and the feature column at that position is used as the detection's visual
    signature. The result of the most recent extraction is kept in
    ``current_detections``; every extraction is also appended to
    ``detection_history``, which is never trimmed.
    """

    def __init__(self):
        # Detections produced by the most recent extraction call.
        self.current_detections: List[SingleDetectionFeature] = []
        # One list per extraction call, in call order (grows without bound).
        self.detection_history: List[List[SingleDetectionFeature]] = []

    def extract_single_detection_features(
        self,
        x_cat: torch.Tensor,  # [B, feature_dim, HW] - the core visual features
        detection_bboxes: torch.Tensor,  # [B, N, 6] - final detections after NMS
        feature_maps: List[torch.Tensor],  # Multi-scale feature maps
        image_size: Tuple[int, int] = (544, 544),
        confidence_threshold: float = 0.25,
        verbose: bool = True,
    ) -> List[SingleDetectionFeature]:
        """
        Extract visual features for individual detection boxes.

        Args:
            x_cat: Raw concatenated features from the YOLO head, [B, feature_dim, HW].
            detection_bboxes: Final detection boxes after NMS, [B, N, 6], each row
                being [x1, y1, x2, y2, score, class_id].
            feature_maps: Multi-scale feature maps ([B, C, H, W] each). Only the
                first map is used, to size the mapping grid; may be empty.
            image_size: Input image size (H, W) in pixels.
            confidence_threshold: Detections whose score is strictly greater
                than this value are kept.
            verbose: When True (default), print per-image and per-detection
                diagnostics to stdout.

        Returns:
            List of SingleDetectionFeature objects, one per kept detection,
            ordered by image and then by detection row. The list also becomes
            ``current_detections`` and a copy of it is appended to
            ``detection_history``.
        """
        batch_size, _, hw = x_cat.shape
        if verbose:
            print(f"Extracting features from x_cat.shape: {x_cat.shape}")
            print(f"Processing {batch_size} images with {hw} total positions per image")

        detections: List[SingleDetectionFeature] = []

        for batch_idx in range(batch_size):
            # Keep only this image's detections above the score threshold.
            img_detections = detection_bboxes[batch_idx]
            valid_detections = img_detections[img_detections[:, 4] > confidence_threshold]
            if verbose:
                print(f"Image {batch_idx}: Found {len(valid_detections)} valid detections")

            for det_idx, detection in enumerate(valid_detections):
                bbox = detection  # [x1, y1, x2, y2, score, class_id]
                confidence = float(detection[4])
                class_id = int(detection[5])

                # Box center in image pixels (0-d tensors).
                center_x = (bbox[0] + bbox[2]) / 2
                center_y = (bbox[1] + bbox[3]) / 2

                feature_pos, feat_coords = self._map_detection_to_feature_position(
                    center_x, center_y, x_cat, feature_maps, image_size, batch_idx
                )

                # Basic indexing returns a view that keeps the whole x_cat
                # storage alive (and reflects later in-place edits to it).
                # Clone so each stored detection owns only feature_dim values;
                # clone() is differentiable, so gradients still flow.
                visual_features = x_cat[batch_idx, :, feature_pos].clone()  # [feature_dim]

                feature_norm = torch.norm(visual_features).item()
                feature_activation = torch.mean(torch.abs(visual_features)).item()

                single_feature = SingleDetectionFeature(
                    visual_features=visual_features,
                    bbox=bbox,
                    center_point=(float(center_x), float(center_y)),
                    feature_position=feature_pos,
                    feature_coordinates=feat_coords,
                    confidence=confidence,
                    class_id=class_id,
                    image_idx=batch_idx,
                    feature_norm=feature_norm,
                    feature_activation=feature_activation,
                )
                detections.append(single_feature)

                if verbose:
                    print(f" Detection {det_idx+1}:")
                    print(f" BBox: [{bbox[0]:.1f}, {bbox[1]:.1f}, {bbox[2]:.1f}, {bbox[3]:.1f}]")
                    print(f" Center: ({center_x:.1f}, {center_y:.1f})")
                    print(f" Feature position: {feature_pos} -> coordinates {feat_coords}")
                    print(f" Feature shape: {list(visual_features.shape)}")
                    print(f" Feature norm: {feature_norm:.4f}")
                    print(f" Feature activation: {feature_activation:.4f}")

        self.current_detections = detections
        self.detection_history.append(detections.copy())
        return detections

    def _map_detection_to_feature_position(
        self,
        center_x: Union[float, torch.Tensor],
        center_y: Union[float, torch.Tensor],
        x_cat: torch.Tensor,
        feature_maps: List[torch.Tensor],
        image_size: Tuple[int, int],
        batch_idx: int,
    ) -> Tuple[int, Tuple[int, int]]:
        """
        Map a detection center to a position index along x_cat's last axis.

        Strategy 1 (when ``feature_maps`` is non-empty and its first map has at
        most HW cells): scale the center onto the first map's H x W grid and
        return the row-major index ``feat_y * feat_w + feat_x``. This assumes
        the first map's cells occupy the leading positions of x_cat in
        row-major order — NOTE(review): confirm against how x_cat is built.

        Strategy 2 (fallback): treat the leading ``floor(sqrt(HW))**2``
        positions of x_cat as a square grid and map the normalized center onto
        it. When x_cat concatenates several scales this is only a heuristic.

        Args:
            center_x: Detection center x in image pixels.
            center_y: Detection center y in image pixels.
            x_cat: Feature tensor [B, feature_dim, HW]; only HW is read.
            feature_maps: Multi-scale feature maps [B, C, H, W]; may be empty.
            image_size: Input image size (H, W) in pixels.
            batch_idx: Currently unused; kept for interface compatibility.

        Returns:
            ``(feature_pos, (x, y))``: the index into x_cat's last axis and the
            grid cell it was derived from.
        """
        hw = x_cat.shape[2]
        img_h, img_w = image_size

        # Strategy 1: use the first (expected highest-resolution) feature map.
        if feature_maps:
            first_map = feature_maps[0]  # [B, channels, H, W]
            feat_h, feat_w = first_map.shape[2], first_map.shape[3]

            if feat_h * feat_w <= hw:
                feat_x = max(0, min(int(center_x * feat_w / img_w), feat_w - 1))
                feat_y = max(0, min(int(center_y * feat_h / img_h), feat_h - 1))
                feature_pos = feat_y * feat_w + feat_x
                if feature_pos < hw:
                    return feature_pos, (feat_x, feat_y)

        # Strategy 2: direct mapping onto a square grid laid over x_cat.
        norm_x = center_x / img_w
        norm_y = center_y / img_h
        spatial_size = int(np.sqrt(hw))

        grid_x = max(0, min(int(norm_x * spatial_size), spatial_size - 1))
        grid_y = max(0, min(int(norm_y * spatial_size), spatial_size - 1))

        # spatial_size**2 <= hw, so the clamp is purely defensive.
        feature_pos = min(grid_y * spatial_size + grid_x, hw - 1)
        return feature_pos, (grid_x, grid_y)

    def get_detection_feature_by_index(self, detection_idx: int) -> Optional[SingleDetectionFeature]:
        """Return the current detection at ``detection_idx``, or None if out of range."""
        if 0 <= detection_idx < len(self.current_detections):
            return self.current_detections[detection_idx]
        return None

    def get_features_by_class(self, class_id: int) -> List[SingleDetectionFeature]:
        """Return all current detections whose class_id equals ``class_id``."""
        return [det for det in self.current_detections if det.class_id == class_id]

    def find_most_similar_detections(
        self,
        target_feature: torch.Tensor,
        top_k: int = 5,
    ) -> List[Tuple[SingleDetectionFeature, float]]:
        """
        Find current detections whose visual features are most similar to a target.

        Args:
            target_feature: 1-D feature vector comparable to ``visual_features``.
            top_k: Maximum number of results; values <= 0 yield an empty list.

        Returns:
            ``(detection, cosine_similarity)`` tuples sorted by similarity,
            highest first.
        """
        if not self.current_detections or top_k <= 0:
            return []

        target_norm = F.normalize(target_feature.unsqueeze(0), dim=1)

        similarities = []
        for detection in self.current_detections:
            det_norm = F.normalize(detection.visual_features.unsqueeze(0), dim=1)
            similarity = torch.cosine_similarity(target_norm, det_norm).item()
            similarities.append((detection, similarity))

        # Stable sort: ties keep detection order.
        similarities.sort(key=lambda pair: pair[1], reverse=True)
        return similarities[:top_k]

    def analyze_feature_diversity(self) -> Dict[str, Union[int, float, str]]:
        """
        Summarize pairwise cosine similarity between current detection features.

        Returns:
            A dict with count, mean/std/min/max off-diagonal similarity and the
            feature dimension, or ``{"error": ...}`` when fewer than two
            detections are available.
        """
        if len(self.current_detections) < 2:
            return {"error": "Need at least 2 detections for diversity analysis"}

        all_features = torch.stack([det.visual_features for det in self.current_detections])

        norm_features = F.normalize(all_features, dim=1)
        similarity_matrix = torch.mm(norm_features, norm_features.t())

        # Exclude self-similarities on the diagonal.
        mask = ~torch.eye(similarity_matrix.shape[0], dtype=torch.bool, device=similarity_matrix.device)
        similarities = similarity_matrix[mask]

        return {
            "num_detections": len(self.current_detections),
            "mean_similarity": float(torch.mean(similarities)),
            "std_similarity": float(torch.std(similarities)),
            "min_similarity": float(torch.min(similarities)),
            "max_similarity": float(torch.max(similarities)),
            "feature_dimension": all_features.shape[1],
        }

    def save_single_features(self, filepath: str) -> None:
        """
        Save the current detection features to ``filepath`` as indented JSON.

        Tensors are detached and moved to CPU first, so this also works for
        features that still require grad (``Tensor.numpy()`` would raise).
        """
        save_data = {
            "num_detections": len(self.current_detections),
            "detections": [
                {
                    "visual_features": det.visual_features.detach().cpu().tolist(),
                    "bbox": det.bbox.detach().cpu().tolist(),
                    "center_point": det.center_point,
                    "feature_position": det.feature_position,
                    "feature_coordinates": det.feature_coordinates,
                    "confidence": det.confidence,
                    "class_id": det.class_id,
                    "image_idx": det.image_idx,
                    "feature_norm": det.feature_norm,
                    "feature_activation": det.feature_activation,
                }
                for det in self.current_detections
            ],
        }

        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(save_data, f, indent=2)


# Usage example
def demo_single_detection_extraction():
    """Demonstrate single detection feature extraction (placeholder)."""
    # This would be integrated into the main network
    pass