# TEAMS/lib/utils/single_detection_feature_extractor.py
# Richard-ZZZZZ's picture
# Upload folder using huggingface_hub
# e168a4d verified
"""
Single Detection Feature Extractor
Extract high-dimensional visual features for individual detection boxes.
This provides the unique visual representation for each specific detection.
"""
import torch
import numpy as np
from typing import Dict, List, Tuple, Optional
import torch.nn.functional as F
from dataclasses import dataclass
# eq=False: the auto-generated dataclass __eq__ compares the tensor fields with
# ``==`` and then takes the truth value of the result, which raises for tensors
# with more than one element. Identity-based equality keeps ``in``,
# ``list.index`` and set/dict membership usable for detection records.
@dataclass(eq=False)
class SingleDetectionFeature:
    """
    Represents the visual features for a single, specific detection box.

    Instances compare and hash by identity (two records are equal only if they
    are the same object).
    """
    # The unique visual signature of this detection
    visual_features: torch.Tensor  # Shape: [feature_dim], e.g., [56]

    # Detection information
    bbox: torch.Tensor  # [x1, y1, x2, y2, score, class_id]
    center_point: Tuple[float, float]  # (center_x, center_y)

    # Feature mapping info
    feature_position: int  # Position index in x_cat [0, HW-1]
    feature_coordinates: Tuple[int, int]  # (feat_x, feat_y) in feature grid

    # Metadata
    confidence: float  # Detection score (bbox[4])
    class_id: int  # Predicted class (bbox[5])
    image_idx: int  # Index of the image within the batch

    # Feature quality metrics
    feature_norm: float  # L2 norm of the feature vector
    feature_activation: float  # Overall activation strength (mean absolute value)
class SingleDetectionFeatureExtractor:
    """
    Extracts unique visual features for individual detection boxes from YOLO output.

    For every detection whose score exceeds a confidence threshold, the centre
    of the box is mapped to one position of the flattened feature tensor
    ``x_cat`` and the feature column at that position becomes the detection's
    descriptor.

    Attributes:
        current_detections: Results of the most recent extraction call.
        detection_history: One list of results per extraction call, in call order.
    """

    def __init__(self):
        self.current_detections: List["SingleDetectionFeature"] = []
        # NOTE: this list is never trimmed, and the stored feature vectors are
        # slices of the x_cat tensors they were read from, so long-running
        # callers may want to clear it periodically.
        self.detection_history: List[List["SingleDetectionFeature"]] = []

    def extract_single_detection_features(
        self,
        x_cat: torch.Tensor,  # [B, feature_dim, HW] - the core visual features
        detection_bboxes: torch.Tensor,  # [B, N, 6] - final detections after NMS
        feature_maps: List[torch.Tensor],  # Multi-scale feature maps
        image_size: Tuple[int, int] = (544, 544),
        confidence_threshold: float = 0.25
    ) -> List["SingleDetectionFeature"]:
        """
        Extract visual features for individual detection boxes.

        Progress information is printed to stdout for every image and detection.
        The result replaces ``self.current_detections`` and a shallow copy of it
        is appended to ``self.detection_history``.

        Args:
            x_cat: Raw concatenated features from YOLO head [B, feature_dim, HW]
            detection_bboxes: Final detection boxes after NMS [B, N, 6], each row
                being [x1, y1, x2, y2, score, class_id]
            feature_maps: List of multi-scale feature maps for precise mapping;
                only the first entry is consulted
            image_size: Input image size (H, W)
            confidence_threshold: Minimum confidence to consider a detection
                (strictly greater-than comparison)

        Returns:
            List of SingleDetectionFeature objects, one per valid detection
        """
        batch_size, _, hw = x_cat.shape
        print(f"Extracting features from x_cat.shape: {x_cat.shape}")
        print(f"Processing {batch_size} images with {hw} total positions per image")

        detections = []

        # Process each image in the batch
        for batch_idx in range(batch_size):
            # Keep only the rows whose score column passes the threshold
            img_detections = detection_bboxes[batch_idx]
            valid_detections = img_detections[img_detections[:, 4] > confidence_threshold]
            print(f"Image {batch_idx}: Found {len(valid_detections)} valid detections")

            # Process each individual detection
            for det_idx, detection in enumerate(valid_detections):
                bbox = detection  # [x1, y1, x2, y2, score, class_id]
                confidence = float(detection[4])
                class_id = int(detection[5])

                # Centre of the box, converted to plain floats once so the
                # mapping helper receives the types its signature declares.
                center_x = float((bbox[0] + bbox[2]) / 2)
                center_y = float((bbox[1] + bbox[3]) / 2)

                # Find the feature position for this detection
                feature_pos, feat_coords = self._map_detection_to_feature_position(
                    center_x, center_y, x_cat, feature_maps, image_size, batch_idx
                )

                # Feature column at that position: [feature_dim]
                visual_features = x_cat[batch_idx, :, feature_pos]

                # Feature quality metrics
                feature_norm = torch.norm(visual_features).item()
                feature_activation = torch.mean(torch.abs(visual_features)).item()

                single_feature = SingleDetectionFeature(
                    visual_features=visual_features,
                    bbox=bbox,
                    center_point=(center_x, center_y),
                    feature_position=feature_pos,
                    feature_coordinates=feat_coords,
                    confidence=confidence,
                    class_id=class_id,
                    image_idx=batch_idx,
                    feature_norm=feature_norm,
                    feature_activation=feature_activation
                )
                detections.append(single_feature)

                print(f" Detection {det_idx+1}:")
                print(f" BBox: [{bbox[0]:.1f}, {bbox[1]:.1f}, {bbox[2]:.1f}, {bbox[3]:.1f}]")
                print(f" Center: ({center_x:.1f}, {center_y:.1f})")
                print(f" Feature position: {feature_pos} -> coordinates {feat_coords}")
                print(f" Feature shape: {list(visual_features.shape)}")
                print(f" Feature norm: {feature_norm:.4f}")
                print(f" Feature activation: {feature_activation:.4f}")

        self.current_detections = detections
        self.detection_history.append(detections.copy())
        return detections

    def _map_detection_to_feature_position(
        self,
        center_x: float,
        center_y: float,
        x_cat: torch.Tensor,
        feature_maps: List[torch.Tensor],
        image_size: Tuple[int, int],
        batch_idx: int
    ) -> Tuple[int, Tuple[int, int]]:
        """
        Map detection center to a position in the flattened x_cat feature tensor.

        Two strategies are tried in order:

        1. If a first feature map is supplied and its H*W grid fits inside the
           HW axis of ``x_cat``, the centre is scaled onto that grid and the
           row-major index within it is returned. This assumes the first
           feature map occupies the leading positions of ``x_cat`` (the
           original code states this for YOLOv8-p2); verify for other heads.
        2. Otherwise HW is treated as a square grid of side floor(sqrt(HW)) and
           the centre is scaled onto that grid.

        Args:
            center_x: Detection centre x in image pixels (float or 0-dim tensor)
            center_y: Detection centre y in image pixels (float or 0-dim tensor)
            x_cat: Feature tensor of shape [B, feature_dim, HW]
            feature_maps: Multi-scale feature maps; only ``feature_maps[0]``
                (shape [B, C, H, W]) is used
            image_size: Input image size (H, W)
            batch_idx: Index of the image in the batch (currently unused)

        Returns:
            ``(feature_pos, (x, y))``: the flat index into the HW axis and the
            grid coordinates it was derived from.
        """
        _, _, hw = x_cat.shape
        img_h, img_w = image_size
        center_x = float(center_x)
        center_y = float(center_y)

        # Strategy 1: use the first (largest) feature map for precise mapping
        if feature_maps is not None and len(feature_maps) > 0:
            first_map = feature_maps[0]  # [B, channels, H, W]
            feat_h, feat_w = first_map.shape[2], first_map.shape[3]

            # Only usable when the grid is non-empty and fits into x_cat; the
            # resulting index is then always < feat_h * feat_w <= hw.
            if 0 < feat_h * feat_w <= hw:
                # Scale image coordinates to grid cells, then clamp so boxes
                # touching or exceeding the image border stay inside the grid.
                feat_x = max(0, min(int(center_x * feat_w / img_w), feat_w - 1))
                feat_y = max(0, min(int(center_y * feat_h / img_h), feat_h - 1))
                return feat_y * feat_w + feat_x, (feat_x, feat_y)

        # Strategy 2: direct spatial mapping onto a square grid estimated from hw
        norm_x = center_x / img_w
        norm_y = center_y / img_h
        spatial_size = int(np.sqrt(hw))

        grid_x = max(0, min(int(norm_x * spatial_size), spatial_size - 1))
        grid_y = max(0, min(int(norm_y * spatial_size), spatial_size - 1))

        # Row-major flat index, capped at the last valid position
        feature_pos = min(grid_y * spatial_size + grid_x, hw - 1)
        return feature_pos, (grid_x, grid_y)

    def get_detection_feature_by_index(self, detection_idx: int) -> Optional["SingleDetectionFeature"]:
        """Get a specific detection feature by index.

        Returns None when ``detection_idx`` is negative or past the end of
        ``current_detections``.
        """
        if 0 <= detection_idx < len(self.current_detections):
            return self.current_detections[detection_idx]
        return None

    def get_features_by_class(self, class_id: int) -> List["SingleDetectionFeature"]:
        """Get all detection features for a specific class."""
        return [det for det in self.current_detections if det.class_id == class_id]

    def find_most_similar_detections(
        self,
        target_feature: torch.Tensor,
        top_k: int = 5
    ) -> List[Tuple["SingleDetectionFeature", float]]:
        """
        Find detections with most similar visual features.

        Args:
            target_feature: 1-D feature vector to compare against
            top_k: Maximum number of results; values below 1 yield an empty list

        Returns:
            List of (detection, cosine_similarity) tuples, most similar first.
        """
        if not self.current_detections:
            return []

        target_norm = F.normalize(target_feature.unsqueeze(0), dim=1)

        similarities = []
        for detection in self.current_detections:
            det_norm = F.normalize(detection.visual_features.unsqueeze(0), dim=1)
            similarity = torch.cosine_similarity(target_norm, det_norm).item()
            similarities.append((detection, similarity))

        # Sort by similarity (descending)
        similarities.sort(key=lambda pair: pair[1], reverse=True)

        # Clamp so a negative top_k cannot turn into an "all but the last k" slice
        return similarities[:max(top_k, 0)]

    def analyze_feature_diversity(self) -> Dict[str, float]:
        """Analyze the diversity of current detection features.

        Computes pairwise cosine similarities between all current feature
        vectors (self-similarities excluded) and summarises them.

        Returns:
            Dict with ``num_detections``, ``mean_similarity``,
            ``std_similarity``, ``min_similarity``, ``max_similarity`` and
            ``feature_dimension``. With fewer than two detections a dict with
            a single ``"error"`` key holding a message string is returned
            instead.
        """
        if len(self.current_detections) < 2:
            return {"error": "Need at least 2 detections for diversity analysis"}

        # Stack all features: [num_detections, feature_dim]
        all_features = torch.stack([det.visual_features for det in self.current_detections])

        # Pairwise cosine similarities via normalised dot products
        norm_features = F.normalize(all_features, dim=1)
        similarity_matrix = torch.mm(norm_features, norm_features.t())

        # Remove diagonal (self-similarities)
        mask = ~torch.eye(similarity_matrix.shape[0], dtype=torch.bool, device=similarity_matrix.device)
        similarities = similarity_matrix[mask]

        return {
            "num_detections": len(self.current_detections),
            "mean_similarity": float(torch.mean(similarities)),
            "std_similarity": float(torch.std(similarities)),
            "min_similarity": float(torch.min(similarities)),
            "max_similarity": float(torch.max(similarities)),
            "feature_dimension": all_features.shape[1]
        }

    def save_single_features(self, filepath: str):
        """Save individual detection features to file.

        Writes ``current_detections`` as indented JSON with the keys
        ``num_detections`` and ``detections`` (one object per detection).

        Args:
            filepath: Destination path; an existing file is overwritten.
        """
        import json

        save_data = {
            "num_detections": len(self.current_detections),
            "detections": [
                {
                    # detach() first: feature vectors are slices of x_cat and
                    # may still be attached to the autograd graph.
                    "visual_features": det.visual_features.detach().cpu().tolist(),
                    "bbox": det.bbox.detach().cpu().tolist(),
                    "center_point": det.center_point,
                    "feature_position": det.feature_position,
                    "feature_coordinates": det.feature_coordinates,
                    "confidence": det.confidence,
                    "class_id": det.class_id,
                    "image_idx": det.image_idx,
                    "feature_norm": det.feature_norm,
                    "feature_activation": det.feature_activation
                }
                for det in self.current_detections
            ]
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(save_data, f, indent=2)
# Example entry point
def demo_single_detection_extraction():
    """Placeholder for a single detection feature extraction demo; performs no work."""
    # Meant to be wired into the main network rather than run standalone.
    return None