"""
Detection Vector Tracker for EnergySnake

This module provides utilities to track the high-dimensional vector behind each
detection box, taken from the raw YOLO output before the decoder, while
maintaining the correspondence between vectors and detection boxes.
"""


import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch


@dataclass
class DetectionVector:
    """
    A single detection box paired with the high-dimensional vector chosen for it.

    Fields are declared in constructor order and none has a default, so all
    seven values must be supplied when creating an instance.
    """

    # Vector associated with the detection. The tracker in this module stores
    # one position's slice of the raw YOLO output here.
    vector: torch.Tensor

    # Detection box row. The tracker reads index 4 as the confidence and
    # index 5 as the class id; indices 0-3 look like [x1, y1, x2, y2]
    # (TODO confirm the box convention against the caller).
    bbox: torch.Tensor

    # Grid cell assigned to the detection, ordered (row, column), i.e. (y, x).
    grid_pos: Tuple[int, int]

    # Index of the selected position in the flattened YOLO output.
    feature_map_idx: int

    # Detection score, class index, and index of the source image in the batch.
    confidence: float
    class_id: int
    image_idx: int


@dataclass
class DetectionVectorBatch:
    """
    Container for all detection vectors in a batch.

    Detections are collected one by one with ``add_detection``; ``finalize``
    then stacks their vectors and boxes into ``raw_vectors`` / ``raw_bboxes``.
    The three ``raw_*`` tensors stay ``None`` until they are populated.
    """
    detections: List["DetectionVector"] = field(default_factory=list)

    # Stacked per-detection tensors (set by ``finalize``) and the raw YOLO
    # output the batch was built from (set by whoever creates the batch).
    raw_vectors: Optional[torch.Tensor] = None
    raw_bboxes: Optional[torch.Tensor] = None
    raw_yolo_output: Optional[torch.Tensor] = None

    # Grid dimensions and number of images recorded for the batch.
    feature_h: int = 0
    feature_w: int = 0
    batch_size: int = 0

    # Length of one detection vector; used to shape empty query results.
    vector_dim: int = 0

    def add_detection(self, detection: "DetectionVector") -> None:
        """Add a detection to the batch."""
        self.detections.append(detection)

    def finalize(self) -> None:
        """
        Convert the list of detections to batch tensors.

        Does nothing when the batch holds no detections, so ``raw_vectors``
        and ``raw_bboxes`` keep their previous value (``None`` by default).
        """
        if not self.detections:
            return

        self.raw_vectors = torch.stack([det.vector for det in self.detections])
        self.raw_bboxes = torch.stack([det.bbox for det in self.detections])

    def _empty_result(self, attr: str, shape: Tuple[int, int]) -> torch.Tensor:
        """
        Build an empty tensor of ``shape`` for a query without matches.

        When detections exist, the dtype and device of the first detection's
        ``attr`` tensor are reused so the result can be combined with
        non-empty results; otherwise torch's defaults are used.
        """
        if self.detections:
            reference = getattr(self.detections[0], attr)
            return reference.new_empty(shape)
        return torch.empty(shape)

    def get_vectors_by_class(self, class_id: int) -> torch.Tensor:
        """
        Get all vectors for a specific class.

        Returns:
            The matching vectors stacked along a new first dimension, or an
            empty ``(0, vector_dim)`` tensor when no detection matches.
        """
        class_vectors = [det.vector for det in self.detections if det.class_id == class_id]
        if class_vectors:
            return torch.stack(class_vectors)
        return self._empty_result("vector", (0, self.vector_dim))

    def get_bboxes_by_class(self, class_id: int) -> torch.Tensor:
        """
        Get all bboxes for a specific class.

        Returns:
            The matching boxes stacked along a new first dimension, or an
            empty ``(0, 6)`` tensor when no detection matches.
        """
        class_bboxes = [det.bbox for det in self.detections if det.class_id == class_id]
        if class_bboxes:
            return torch.stack(class_bboxes)
        return self._empty_result("bbox", (0, 6))


class DetectionVectorTracker:
    """
    Tracks high-dimensional vectors from the YOLO detection head before the decoder.

    For each final detection box, one position of the raw YOLO output is
    selected and that position's channel vector is stored next to the box.
    The latest processed batch is available as ``current_batch``; every
    processed batch is also appended to ``history``.
    """

    def __init__(self, reg_max: int = 16, num_classes: int = 52):
        """
        Create an empty tracker.

        Args:
            reg_max: Multiplied by 4 to obtain the number of box-regression
                channels reported by ``get_vector_shape_info``.
            num_classes: Number of classification channels reported by
                ``get_vector_shape_info``.
        """
        self.current_batch: Optional["DetectionVectorBatch"] = None
        self.history: List["DetectionVectorBatch"] = []

        self.reg_max = reg_max
        self.num_classes = num_classes

    def extract_vectors_from_yolo_output(self,
                                         yolo_output: torch.Tensor,
                                         feature_maps: List[torch.Tensor],
                                         detection_bboxes: torch.Tensor,
                                         image_size: Tuple[int, int] = (544, 544)) -> "DetectionVectorBatch":
        """
        Extract high-dimensional vectors from YOLO output before decoder.

        Each detection with a confidence above zero is mapped to one position
        of the flattened YOLO output (see ``_find_best_yolo_position``) and the
        vector at that position is recorded with the detection. The finished
        batch becomes ``current_batch`` and is appended to ``history``.

        Args:
            yolo_output: Raw YOLO output tensor [B, no, HW]
            feature_maps: List of feature maps from YOLO backbone; may be
                empty, in which case a square grid is assumed
            detection_bboxes: Final detection boxes after NMS [B, M, 6]
            image_size: Input image size (H, W)

        Returns:
            DetectionVectorBatch containing extracted vectors and correspondence
        """
        batch_size, output_dim, hw = yolo_output.shape

        # The same grid is used to pick a position and to decode grid_pos,
        # so the stored (row, column) always matches the stored flat index.
        grid_h, grid_w = self._grid_shape(hw, feature_maps)

        batch = DetectionVectorBatch(
            feature_h=grid_h,
            feature_w=grid_w,
            batch_size=batch_size,
            vector_dim=output_dim,
            raw_yolo_output=yolo_output.clone(),
        )

        # [B, no, HW] -> [B, HW, no]: one row per output position.
        yolo_reshaped = yolo_output.permute(0, 2, 1).contiguous()

        for batch_idx in range(batch_size):
            img_detections = detection_bboxes[batch_idx]
            if img_detections.numel() == 0:
                continue

            # Rows with a non-positive confidence are treated as padding.
            valid_detections = img_detections[img_detections[:, 4] > 0]

            for detection in valid_detections:
                best_match_idx = self._find_best_yolo_position(
                    detection, yolo_reshaped[batch_idx], image_size, feature_maps
                )
                if best_match_idx is None:
                    continue

                batch.add_detection(DetectionVector(
                    vector=yolo_reshaped[batch_idx, best_match_idx],
                    bbox=detection,
                    grid_pos=(best_match_idx // grid_w, best_match_idx % grid_w),
                    feature_map_idx=best_match_idx,
                    confidence=float(detection[4]),
                    class_id=int(detection[5]),
                    image_idx=batch_idx,
                ))

        batch.finalize()

        self.current_batch = batch
        self.history.append(batch)

        return batch

    @staticmethod
    def _grid_shape(num_positions, feature_maps):
        """
        Return the (rows, cols) grid used to index the flattened YOLO output.

        The spatial size of the first feature map is used when it is non-empty
        and fits inside ``num_positions``; otherwise the grid falls back to a
        square of side ``int(sqrt(num_positions))``.
        """
        if feature_maps:
            rows = int(feature_maps[0].shape[-2])
            cols = int(feature_maps[0].shape[-1])
            if 0 < rows * cols <= num_positions:
                return rows, cols

        side = int(np.sqrt(num_positions))
        return side, side

    def _find_best_yolo_position(self, bbox, yolo_flat_output, image_size, feature_maps):
        """
        Find the best matching position in YOLO output for a given detection.

        For multi-scale YOLO, this is an approximation that finds the closest
        spatial match: the box centre is scaled from image coordinates to the
        grid returned by ``_grid_shape`` and converted to a row-major index.

        Args:
            bbox: Detection row whose first four values are read as
                x1, y1, x2, y2
            yolo_flat_output: Per-image output with one row per position
            image_size: Input image size (H, W)
            feature_maps: Feature maps forwarded to ``_grid_shape``

        Returns:
            Index into ``yolo_flat_output``, or None when it has no positions.
        """
        flat_positions = yolo_flat_output.shape[0]
        if flat_positions == 0:
            return None

        grid_h, grid_w = self._grid_shape(flat_positions, feature_maps)
        h, w = image_size

        center_x = (float(bbox[0]) + float(bbox[2])) / 2
        center_y = (float(bbox[1]) + float(bbox[3])) / 2

        # Clamp so that boxes reaching outside the image still map to a cell.
        grid_x = max(0, min(int(center_x * grid_w / w), grid_w - 1))
        grid_y = max(0, min(int(center_y * grid_h / h), grid_h - 1))

        return min(grid_y * grid_w + grid_x, flat_positions - 1)

    def get_vector_shape_info(self) -> Dict[str, Any]:
        """
        Get information about the vector shapes and dimensions.

        The channel counts are derived from ``reg_max`` and ``num_classes``,
        not measured from the stored output.

        Returns:
            Dictionary containing shape information, or an ``"error"`` entry
            when no batch has been processed yet
        """
        if self.current_batch is None:
            return {"error": "No batch processed yet"}

        batch = self.current_batch
        yolo_output = batch.raw_yolo_output

        regression_channels = self.reg_max * 4
        output_channels = regression_channels + self.num_classes

        return {
            "yolo_output_shape": list(yolo_output.shape) if yolo_output is not None else None,
            "output_channels": output_channels,
            "regression_channels": regression_channels,
            "classification_channels": self.num_classes,
            "feature_map_size": (batch.feature_h, batch.feature_w),
            "total_positions": batch.feature_h * batch.feature_w,
            "high_dim_vector_shape": [output_channels],
            "vector_breakdown": {
                "bbox_regression": [regression_channels],
                "class_logits": [self.num_classes],
            },
        }

    def get_batch_summary(self) -> Dict[str, Any]:
        """
        Get summary statistics for the current batch.

        Returns:
            Dictionary with the detection count, per-class counts, confidence
            statistics and vector statistics, or an ``"error"`` entry when no
            batch has been processed yet
        """
        if self.current_batch is None:
            return {"error": "No batch processed yet"}

        batch = self.current_batch

        class_counts: Dict[int, int] = {}
        for det in batch.detections:
            class_counts[det.class_id] = class_counts.get(det.class_id, 0) + 1

        confidences = [det.confidence for det in batch.detections]
        if confidences:
            confidence_stats = {
                "mean": float(np.mean(confidences)),
                "min": float(np.min(confidences)),
                "max": float(np.max(confidences)),
            }
        else:
            confidence_stats = {"mean": 0, "min": 0, "max": 0}

        # raw_vectors stays None when the batch contains no detections.
        if batch.raw_vectors is not None:
            vector_stats = {
                "shape": list(batch.raw_vectors.shape),
                "mean_norm": float(torch.mean(torch.norm(batch.raw_vectors, dim=1))),
            }
        else:
            vector_stats = {"shape": None, "mean_norm": 0}

        return {
            "total_detections": len(batch.detections),
            "detections_by_class": class_counts,
            "confidence_stats": confidence_stats,
            "vector_stats": vector_stats,
        }

    def clear_history(self):
        """Clear processing history; ``current_batch`` is left untouched."""
        self.history.clear()

    def save_vectors_to_file(self, filepath: str, batch_idx: int = -1):
        """
        Save detection vectors to file as JSON.

        Args:
            filepath: Path to save the vectors
            batch_idx: Index of batch to save (-1 for current/latest)

        Raises:
            ValueError: If ``batch_idx`` is -1 but no batch has been processed,
                or if it is not a valid index into ``history``
        """
        if batch_idx == -1 and self.current_batch is not None:
            batch = self.current_batch
        elif 0 <= batch_idx < len(self.history):
            batch = self.history[batch_idx]
        else:
            raise ValueError(f"Invalid batch_idx: {batch_idx}")

        save_data = {
            "batch_info": {
                "batch_size": batch.batch_size,
                "feature_map_size": (batch.feature_h, batch.feature_w),
                "vector_dim": batch.vector_dim,
                "total_detections": len(batch.detections)
            },
            "detections": [
                {
                    # detach() first: tensors that require grad cannot be
                    # converted to plain Python values otherwise.
                    "vector": det.vector.detach().cpu().tolist(),
                    "bbox": det.bbox.detach().cpu().tolist(),
                    "grid_pos": det.grid_pos,
                    "confidence": det.confidence,
                    "class_id": det.class_id,
                    "image_idx": det.image_idx
                }
                for det in batch.detections
            ]
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(save_data, f, indent=2)