| """ |
| Single Detection Feature Extractor |
| |
| Extract high-dimensional visual features for individual detection boxes. |
| This provides the unique visual representation for each specific detection. |
| """ |
|
|
| import torch |
| import numpy as np |
| from typing import Dict, List, Tuple, Optional |
| import torch.nn.functional as F |
| from dataclasses import dataclass |
|
|
|
|
@dataclass
class SingleDetectionFeature:
    """
    Record holding the visual feature vector and metadata of one detection box.

    Field order is part of the generated constructor signature, so it is kept
    exactly as declared.
    """

    # Feature vector selected for this detection.
    visual_features: torch.Tensor

    # Box geometry: the box tensor and its center as (x, y).
    bbox: torch.Tensor
    center_point: Tuple[float, float]

    # Where the feature vector was read from: a flat position index and the
    # corresponding (x, y) grid cell.
    feature_position: int
    feature_coordinates: Tuple[int, int]

    # Detection metadata.
    confidence: float
    class_id: int
    image_idx: int

    # Summary statistics of ``visual_features`` (the extractor in this file
    # fills them with the vector norm and the mean absolute value).
    feature_norm: float
    feature_activation: float
|
|
|
|
class SingleDetectionFeatureExtractor:
    """
    Extracts unique visual features for individual detection boxes from YOLO output.

    For every detection above a confidence threshold, the box center is mapped
    to one position along the last axis of ``x_cat`` and the feature column at
    that position is used as the detection's visual representation.

    Attributes:
        current_detections: Features produced by the most recent call to
            ``extract_single_detection_features``.
        detection_history: One list per extraction call, in call order. This
            class never trims it, so it grows with every call.
    """

    def __init__(self):
        self.current_detections: List[SingleDetectionFeature] = []
        self.detection_history: List[List[SingleDetectionFeature]] = []

    def extract_single_detection_features(
        self,
        x_cat: torch.Tensor,
        detection_bboxes: torch.Tensor,
        feature_maps: List[torch.Tensor],
        image_size: Tuple[int, int] = (544, 544),
        confidence_threshold: float = 0.25,
        *,
        verbose: bool = True,
    ) -> List[SingleDetectionFeature]:
        """
        Extract visual features for individual detection boxes.

        Args:
            x_cat: Raw concatenated features from YOLO head [B, feature_dim, HW]
            detection_bboxes: Final detection boxes after NMS [B, N, 6]; columns
                0-3 are used as x1, y1, x2, y2, column 4 as the confidence and
                column 5 as the class id.
            feature_maps: List of multi-scale feature maps for precise mapping
                (only the first one is consulted; may be empty).
            image_size: Input image size (H, W)
            confidence_threshold: Minimum confidence to consider a detection
                (strictly greater-than).
            verbose: When True (default) progress is printed to stdout.

        Returns:
            List of SingleDetectionFeature objects, one per valid detection.
            The same list is stored in ``current_detections`` and a shallow
            copy is appended to ``detection_history``.
        """
        batch_size, _, hw = x_cat.shape

        if verbose:
            print(f"Extracting features from x_cat.shape: {x_cat.shape}")
            print(f"Processing {batch_size} images with {hw} total positions per image")

        detections: List[SingleDetectionFeature] = []

        for batch_idx in range(batch_size):
            img_detections = detection_bboxes[batch_idx]
            keep = img_detections[:, 4] > confidence_threshold
            valid_detections = img_detections[keep]

            if verbose:
                print(f"Image {batch_idx}: Found {len(valid_detections)} valid detections")

            for det_idx, detection in enumerate(valid_detections):
                # NOTE(review): the full detection row (box, confidence, class)
                # is stored as ``bbox``; kept as-is because saved files and
                # callers may rely on that layout.
                bbox = detection
                confidence = float(detection[4])
                class_id = int(detection[5])

                # Convert the center to plain floats once so the mapping and
                # the logging below do not repeatedly unwrap 0-dim tensors.
                center_x = float((bbox[0] + bbox[2]) / 2)
                center_y = float((bbox[1] + bbox[3]) / 2)

                feature_pos, feat_coords = self._map_detection_to_feature_position(
                    center_x, center_y, x_cat, feature_maps, image_size, batch_idx
                )

                # Basic indexing: this is a view into ``x_cat``, not a copy, so
                # keeping it (e.g. in ``detection_history``) keeps ``x_cat`` alive.
                visual_features = x_cat[batch_idx, :, feature_pos]

                feature_norm = torch.norm(visual_features).item()
                feature_activation = torch.mean(torch.abs(visual_features)).item()

                detections.append(
                    SingleDetectionFeature(
                        visual_features=visual_features,
                        bbox=bbox,
                        center_point=(center_x, center_y),
                        feature_position=feature_pos,
                        feature_coordinates=feat_coords,
                        confidence=confidence,
                        class_id=class_id,
                        image_idx=batch_idx,
                        feature_norm=feature_norm,
                        feature_activation=feature_activation,
                    )
                )

                if verbose:
                    print(f" Detection {det_idx+1}:")
                    print(f" BBox: [{bbox[0]:.1f}, {bbox[1]:.1f}, {bbox[2]:.1f}, {bbox[3]:.1f}]")
                    print(f" Center: ({center_x:.1f}, {center_y:.1f})")
                    print(f" Feature position: {feature_pos} -> coordinates {feat_coords}")
                    print(f" Feature shape: {list(visual_features.shape)}")
                    print(f" Feature norm: {feature_norm:.4f}")
                    print(f" Feature activation: {feature_activation:.4f}")

        self.current_detections = detections
        self.detection_history.append(list(detections))

        return detections

    def _map_detection_to_feature_position(
        self,
        center_x: float,
        center_y: float,
        x_cat: torch.Tensor,
        feature_maps: List[torch.Tensor],
        image_size: Tuple[int, int],
        batch_idx: int
    ) -> Tuple[int, Tuple[int, int]]:
        """
        Map a detection center to a position along the last axis of ``x_cat``.

        Two strategies are tried in order:

        1. If ``feature_maps`` is non-empty and its first map has no more
           spatial cells than ``x_cat`` has positions, the center is scaled
           onto that map's grid and the row-major cell index is returned.
        2. Otherwise the positions are treated as one square grid of side
           ``int(sqrt(HW))`` and the center is scaled onto it.
           NOTE(review): for a multi-scale concatenation this is only an
           approximation - confirm it matches how ``x_cat`` is laid out.

        Args:
            center_x: Box center x in image pixels.
            center_y: Box center y in image pixels.
            x_cat: Feature tensor; only its shape [B, feature_dim, HW] is read.
            feature_maps: Feature maps of shape [B, C, H, W]; only index 0 is used.
            image_size: Input image size (H, W).
            batch_idx: Unused; kept for interface compatibility.

        Returns:
            ``(position, (grid_x, grid_y))``.
        """
        _, _, hw = x_cat.shape
        img_h, img_w = image_size

        if feature_maps:
            first_map = feature_maps[0]
            feat_h, feat_w = first_map.shape[2], first_map.shape[3]

            # Scale to grid cells and clamp so centers outside the image
            # still land on a valid cell.
            feat_x = max(0, min(int(center_x * feat_w / img_w), feat_w - 1))
            feat_y = max(0, min(int(center_y * feat_h / img_h), feat_h - 1))

            if feat_h * feat_w <= hw:
                feature_pos = feat_y * feat_w + feat_x
                # Only false for degenerate (empty) shapes.
                if feature_pos < hw:
                    return feature_pos, (feat_x, feat_y)

        # Fallback: single square grid over all positions.
        norm_x = center_x / img_w
        norm_y = center_y / img_h

        spatial_size = int(np.sqrt(hw))

        grid_x = max(0, min(int(norm_x * spatial_size), spatial_size - 1))
        grid_y = max(0, min(int(norm_y * spatial_size), spatial_size - 1))

        # Never index past the last available position.
        feature_pos = min(grid_y * spatial_size + grid_x, hw - 1)

        return feature_pos, (grid_x, grid_y)

    def get_detection_feature_by_index(self, detection_idx: int) -> Optional[SingleDetectionFeature]:
        """Return the current detection at ``detection_idx``, or None if out of range."""
        if 0 <= detection_idx < len(self.current_detections):
            return self.current_detections[detection_idx]
        return None

    def get_features_by_class(self, class_id: int) -> List[SingleDetectionFeature]:
        """Return all current detections whose ``class_id`` equals ``class_id``."""
        return [det for det in self.current_detections if det.class_id == class_id]

    def find_most_similar_detections(
        self,
        target_feature: torch.Tensor,
        top_k: int = 5
    ) -> List[Tuple[SingleDetectionFeature, float]]:
        """
        Find current detections whose features are most similar to a query.

        Similarity is the cosine similarity between ``target_feature`` (a 1-D
        vector) and each stored ``visual_features`` vector.

        Args:
            target_feature: Query feature vector; it is moved to the device and
                dtype of the stored features before comparison.
            top_k: Maximum number of results; values <= 0 yield an empty list.

        Returns:
            List of (detection, similarity_score) tuples sorted by descending
            similarity; ties keep their order in ``current_detections``.
        """
        if not self.current_detections or top_k <= 0:
            return []

        stacked = torch.stack([det.visual_features for det in self.current_detections])
        query = target_feature.to(device=stacked.device, dtype=stacked.dtype).unsqueeze(0)

        # One batched call; cosine_similarity normalises internally, so no
        # separate F.normalize pass is needed.
        scores = F.cosine_similarity(query, stacked, dim=1).detach().tolist()

        ranked = sorted(
            zip(self.current_detections, scores),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return ranked[:top_k]

    def analyze_feature_diversity(self) -> Dict[str, float]:
        """
        Analyze the diversity of current detection features.

        Computes pairwise cosine similarities between all current feature
        vectors (self-similarities excluded) and summarises them.

        Returns:
            Dict with ``num_detections``, ``mean_similarity``,
            ``std_similarity``, ``min_similarity``, ``max_similarity`` and
            ``feature_dimension``. With fewer than two detections a dict with
            a single ``error`` message string is returned instead.
        """
        if len(self.current_detections) < 2:
            return {"error": "Need at least 2 detections for diversity analysis"}

        all_features = torch.stack([det.visual_features for det in self.current_detections])

        # Rows are unit-normalised, so the Gram matrix holds cosine similarities.
        norm_features = F.normalize(all_features, dim=1)
        similarity_matrix = torch.mm(norm_features, norm_features.t())

        # Drop the diagonal (each vector compared with itself).
        off_diagonal = ~torch.eye(
            similarity_matrix.shape[0],
            dtype=torch.bool,
            device=similarity_matrix.device,
        )
        similarities = similarity_matrix[off_diagonal]

        return {
            "num_detections": len(self.current_detections),
            "mean_similarity": float(torch.mean(similarities)),
            "std_similarity": float(torch.std(similarities)),
            "min_similarity": float(torch.min(similarities)),
            "max_similarity": float(torch.max(similarities)),
            "feature_dimension": all_features.shape[1]
        }

    def save_single_features(self, filepath: str):
        """
        Save individual detection features to a JSON file.

        Tensors are detached and moved to the CPU before conversion, so
        features that require grad (or use a dtype NumPy cannot represent)
        are serialised as well.

        Args:
            filepath: Destination path; the file is overwritten.
        """
        import json

        save_data = {
            "num_detections": len(self.current_detections),
            "detections": [
                {
                    "visual_features": det.visual_features.detach().cpu().tolist(),
                    "bbox": det.bbox.detach().cpu().tolist(),
                    "center_point": det.center_point,
                    "feature_position": det.feature_position,
                    "feature_coordinates": det.feature_coordinates,
                    "confidence": det.confidence,
                    "class_id": det.class_id,
                    "image_idx": det.image_idx,
                    "feature_norm": det.feature_norm,
                    "feature_activation": det.feature_activation
                }
                for det in self.current_detections
            ]
        }

        with open(filepath, 'w', encoding="utf-8") as f:
            json.dump(save_data, f, indent=2)
|
|
|
|
| |
def demo_single_detection_extraction():
    """Placeholder for a single detection feature extraction demo; does nothing."""
    return None