# Source file: TEAMS/lib/utils/multi_scale_feature_extractor.py
# Uploaded by Richard-ZZZZZ using huggingface_hub (commit e168a4d, verified).
"""
Multi-Scale Feature Extractor for YOLO Detection
This module extracts raw 116-dimensional features from all 4 scales (P2, P3, P4, P5)
for each detection box, providing the most comprehensive visual representation.
Key features:
- Extracts features before DFL processing (116-dim vs 56-dim)
- Maintains exact correspondence with original spatial layout
- Preserves scale ordering (P2 -> P3 -> P4 -> P5)
- Handles exact 2D reconstruction without rotation/flipping
"""
import torch
import numpy as np
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
@dataclass
class MultiScaleDetectionFeatures:
    """
    Per-detection feature bundle: one feature vector for each of the four
    scales (P2, P3, P4, P5), plus the detection box and bookkeeping fields.

    The shape/size notes on the fields below describe the default
    544x544 configuration used elsewhere in this module.
    """

    # --- Detection information ---
    bbox: torch.Tensor  # [x1, y1, x2, y2, score, class_id]
    center_point: Tuple[float, float]  # (center_x, center_y)

    # --- Multi-scale features (4 scales x 116 dimensions) ---
    p2_features: torch.Tensor  # [116] - P2 scale (136x136, stride=4)
    p3_features: torch.Tensor  # [116] - P3 scale (68x68, stride=8)
    p4_features: torch.Tensor  # [116] - P4 scale (34x34, stride=16)
    p5_features: torch.Tensor  # [116] - P5 scale (17x17, stride=32)

    # --- Feature-map positions (for verification/debugging) ---
    p2_position: Tuple[int, int]  # (x, y) in the P2 feature map
    p3_position: Tuple[int, int]  # (x, y) in the P3 feature map
    p4_position: Tuple[int, int]  # (x, y) in the P4 feature map
    p5_position: Tuple[int, int]  # (x, y) in the P5 feature map

    # --- Metadata ---
    confidence: float
    class_id: int
    image_idx: int

    # --- Multimodal features (optional) ---
    text_features: Optional[torch.Tensor] = None  # [116] - ClinicalBERT text features
    multimodal_features: Optional[torch.Tensor] = None  # [580] - fused visual + text features

    def _per_scale(self) -> Tuple[torch.Tensor, ...]:
        """Return the four scale vectors in fixed P2 -> P3 -> P4 -> P5 order."""
        return (
            self.p2_features,
            self.p3_features,
            self.p4_features,
            self.p5_features,
        )

    @property
    def concatenated_features(self) -> torch.Tensor:
        """All scale features joined end to end along dim 0: [4x116] = [464]."""
        return torch.cat(self._per_scale(), dim=0)

    @property
    def feature_matrix(self) -> torch.Tensor:
        """All scale features stacked as rows of a matrix: [4, 116]."""
        return torch.stack(self._per_scale(), dim=0)
class MultiScaleFeatureExtractor:
    """
    Extracts multi-scale raw features (116-dim) for detection boxes.

    This extractor works with the pre-DFL features to capture the richest
    visual information across all 4 scales of YOLOv8-p2.

    The raw feature tensor is expected to hold the four scales flattened and
    concatenated along its last axis in the order P2 -> P3 -> P4 -> P5; each
    scale is flattened row-major (index = y * width + x).
    """

    # Scale names in concatenation order, mapped to their strides.
    _SCALE_STRIDES = {'P2': 4, 'P3': 8, 'P4': 16, 'P5': 32}
    # Number of channels of the raw (pre-DFL) feature vectors.
    _FEATURE_DIM = 116

    def __init__(self, input_size: Tuple[int, int] = (544, 544)):
        """
        Args:
            input_size: (height, width) of the network input. Feature-map
                sizes are derived from it by integer division by each stride.

        Raises:
            ValueError: if the input is so small that a scale would have an
                empty feature map.
        """
        self.input_size = input_size
        self.feature_sizes = self._calculate_feature_sizes()
        # Reject degenerate layouts instead of pinning the total to the single
        # 544x544 value (24565), which made every other input_size unusable.
        degenerate = [
            name for name, (h, w) in self.feature_sizes.items() if h <= 0 or w <= 0
        ]
        if degenerate:
            raise ValueError(
                f"input_size {input_size} is too small: empty feature map for {degenerate}"
            )
        # Calculate offsets for each scale in the concatenated tensor
        self._calculate_scale_offsets()

    def _adapt_to_hw(self, hw: int):
        """
        Re-configure the per-scale layout to match a feature tensor with
        ``hw`` total positions.

        For square inputs the four scales have sides 8s, 4s, 2s and s, so
        ``hw == 85 * s * s``; this case is solved exactly. Any other ``hw``
        falls back to the original proportional heuristic, which only
        approximates the layout (square maps whose areas need not sum to
        ``hw``). For non-square inputs, construct the extractor with the
        right ``input_size`` instead of relying on that fallback.

        Args:
            hw: actual total number of feature positions.

        Raises:
            AssertionError: if the heuristic fallback cannot give every scale
                a positive number of positions.
        """
        import math

        # Exact solution for square layouts: hw = (64 + 16 + 4 + 1) * s^2.
        if hw > 0 and hw % 85 == 0:
            quotient = hw // 85
            base = int(round(math.sqrt(quotient)))
            if base > 0 and base * base == quotient:
                self.feature_sizes = {
                    'P2': (8 * base, 8 * base),
                    'P3': (4 * base, 4 * base),
                    'P4': (2 * base, 2 * base),
                    'P5': (base, base),
                }
                self._calculate_scale_offsets()
                return

        # Heuristic fallback (kept for backward compatibility): P2 gets ~75%
        # of the positions, P5 at least an 8x8 map (or 1/32 of the total),
        # and P3/P4 split the remainder 3:1.
        p2_positions = hw * 3 // 4
        p5_positions = max(64, hw // 32)
        remaining_for_p3_p4 = hw - p2_positions - p5_positions
        p3_positions = remaining_for_p3_p4 * 3 // 4
        p4_positions = remaining_for_p3_p4 - p3_positions
        if min(p2_positions, p3_positions, p4_positions, p5_positions) <= 0:
            raise AssertionError("每个尺度都需要有正数位置")

        def positions_to_size(positions):
            # Largest square map that fits into the allotted position count.
            side = int(math.sqrt(positions))
            return side, side

        self.feature_sizes = {
            'P2': positions_to_size(p2_positions),
            'P3': positions_to_size(p3_positions),
            'P4': positions_to_size(p4_positions),
            'P5': positions_to_size(p5_positions),
        }
        # Offsets are cumulative over the *actual* (squared) map areas.
        self._calculate_scale_offsets()
        print(f"[DYNAMIC] 动态计算配置完成 HW={hw}")

    def _update_scale_ranges(self):
        """Recompute ``scale_ranges`` from the current sizes and offsets."""
        self.scale_ranges = {}
        for scale in self._SCALE_STRIDES:
            h, w = self.feature_sizes[scale]
            offset = self.scale_offsets[scale]
            self.scale_ranges[scale] = (offset, offset + h * w)

    def _calculate_feature_sizes(self) -> Dict[str, Tuple[int, int]]:
        """Calculate (height, width) of the feature map of every scale."""
        h, w = self.input_size
        # For the default 544x544 input: 136x136, 68x68, 34x34, 17x17.
        return {
            scale: (h // stride, w // stride)
            for scale, stride in self._SCALE_STRIDES.items()
        }

    def _calculate_scale_offsets(self):
        """
        Calculate, for each scale, its start offset and half-open
        [start, end) range inside the concatenated position axis.
        """
        self.scale_offsets = {}
        self.scale_ranges = {}
        current_offset = 0
        for scale in self._SCALE_STRIDES:
            h, w = self.feature_sizes[scale]
            positions = h * w
            self.scale_offsets[scale] = current_offset
            self.scale_ranges[scale] = (current_offset, current_offset + positions)
            current_offset += positions

    def _ensure_layout_matches(self, hw: int):
        """Re-adapt the layout if it does not cover exactly ``hw`` positions."""
        current_positions = sum(h * w for h, w in self.feature_sizes.values())
        if current_positions != hw:
            print(f"[DYNAMIC] 适配特征尺寸: HW={hw}")
            self._adapt_to_hw(hw)

    def get_feature_index(self, scale: str, feat_x: int, feat_y: int) -> int:
        """
        Return the index of feature-map cell (feat_x, feat_y) of ``scale``
        along the concatenated position axis.

        Raises:
            KeyError: if ``scale`` is not one of 'P2', 'P3', 'P4', 'P5'.
            IndexError: if the cell lies outside the scale's feature map.
        """
        feat_h, feat_w = self.feature_sizes[scale]
        if not (0 <= feat_x < feat_w and 0 <= feat_y < feat_h):
            raise IndexError(
                f"Cell ({feat_x}, {feat_y}) outside {scale} feature map {feat_w}x{feat_h}"
            )
        return self.scale_offsets[scale] + feat_y * feat_w + feat_x

    def extract_multi_scale_features(
        self,
        raw_116_features: torch.Tensor,  # [B, 116, HW] - Raw features before DFL
        detection_bboxes: torch.Tensor,  # [B, N, 6] - Detection boxes [x1,y1,x2,y2,score,cls]
        image_size: Optional[Tuple[int, int]] = None,
        confidence_threshold: float = 0.0  # detections with score <= threshold are skipped
    ) -> List["MultiScaleDetectionFeatures"]:
        """
        Extract multi-scale 116-dimensional features for detection boxes.

        For every detection whose score exceeds ``confidence_threshold`` the
        box centre is mapped onto each scale's feature map and the feature
        column at that cell is collected.

        Args:
            raw_116_features: [B, 116, HW] raw features before DFL.
            detection_bboxes: [B, N, 6] boxes as [x1, y1, x2, y2, score, cls].
            image_size: (height, width); defaults to ``self.input_size``.
            confidence_threshold: strict lower bound on the score; with the
                default 0.0, zero-padded rows are dropped.

        Returns:
            One ``MultiScaleDetectionFeatures`` per kept detection, in batch
            order.

        Raises:
            AssertionError: if the channel dimension is not 116.
        """
        if image_size is None:
            image_size = self.input_size
        B, feature_dim, HW = raw_116_features.shape
        if feature_dim != self._FEATURE_DIM:
            raise AssertionError(f"Expected feature_dim=116, got {feature_dim}")
        # Always compare against the *current* layout: an extractor that was
        # adapted to another HW earlier must switch back when the default
        # size shows up again.
        self._ensure_layout_matches(HW)
        img_h, img_w = image_size
        detections = []
        for batch_idx in range(B):
            img_detections = detection_bboxes[batch_idx]
            valid_detections = img_detections[img_detections[:, 4] > confidence_threshold]
            for detection in valid_detections:
                bbox = detection  # [x1, y1, x2, y2, score, class_id]
                confidence = float(detection[4])
                class_id = int(detection[5])
                # Box centre in image coordinates
                center_x = (bbox[0] + bbox[2]) / 2
                center_y = (bbox[1] + bbox[3]) / 2
                scale_features = {}
                scale_positions = {}
                for scale in self._SCALE_STRIDES:
                    feat_pos, feat_coords = self._map_center_to_scale_position(
                        center_x, center_y, scale, img_w, img_h
                    )
                    scale_features[scale] = self._extract_raw_feature_at_position(
                        raw_116_features[batch_idx], feat_pos, scale
                    )
                    scale_positions[scale] = feat_coords
                detections.append(MultiScaleDetectionFeatures(
                    bbox=bbox,
                    center_point=(float(center_x), float(center_y)),
                    p2_features=scale_features['P2'],
                    p3_features=scale_features['P3'],
                    p4_features=scale_features['P4'],
                    p5_features=scale_features['P5'],
                    p2_position=scale_positions['P2'],
                    p3_position=scale_positions['P3'],
                    p4_position=scale_positions['P4'],
                    p5_position=scale_positions['P5'],
                    confidence=confidence,
                    class_id=class_id,
                    image_idx=batch_idx
                ))
        return detections

    def _map_center_to_scale_position(
        self,
        center_x: float,
        center_y: float,
        scale: str,
        img_w: int,
        img_h: int
    ) -> Tuple[int, Tuple[int, int]]:
        """
        Map image center coordinates to feature map position for a specific scale.

        Args:
            center_x, center_y: Center coordinates in image space
            scale: Target scale ('P2', 'P3', 'P4', 'P5')
            img_w, img_h: Image dimensions. Currently unused: the mapping
                divides by the fixed stride of the scale.

        Returns:
            (position in concatenated tensor, (feat_x, feat_y)) in feature map
        """
        stride = self._SCALE_STRIDES[scale]
        feat_h, feat_w = self.feature_sizes[scale]
        # Straight division keeps the spatial layout (no rotation/flipping);
        # int() truncates toward zero and the clamp keeps the cell in range.
        feat_x = max(0, min(int(center_x / stride), feat_w - 1))
        feat_y = max(0, min(int(center_y / stride), feat_h - 1))
        # Row-major index inside this scale, shifted by the scale's offset.
        concat_position = self.scale_offsets[scale] + feat_y * feat_w + feat_x
        return concat_position, (feat_x, feat_y)

    def _extract_raw_feature_at_position(
        self,
        batch_features: torch.Tensor,  # [116, HW] for one image
        position: int,
        scale: str
    ) -> torch.Tensor:
        """
        Extract 116-dimensional raw feature at a specific position.

        Args:
            batch_features: Features for one image [116, HW]
            position: Flat position in concatenated tensor
            scale: Scale name (for verification)

        Returns:
            Feature tensor [116]

        Raises:
            AssertionError: if ``position`` is outside the range of ``scale``
                or beyond the tensor's last axis.
        """
        start, end = self.scale_ranges[scale]
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if not (start <= position < end):
            raise AssertionError(f"Position {position} outside {scale} range [{start}:{end})")
        tensor_size = batch_features.shape[1]
        if position >= tensor_size:
            raise AssertionError(f"Position {position} exceeds tensor size {tensor_size} for scale {scale}")
        return batch_features[:, position]  # [116]

    def create_synthetic_test_data(self, batch_size: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Create synthetic test data for validation.

        The feature tensor is sized from the extractor's current layout
        (24565 positions for the default 544x544 input).

        Returns:
            (raw_features, detection_bboxes)
            raw_features: [B, 116, HW]
            detection_bboxes: [B, N, 6]
        """
        print("Creating synthetic test data...")
        total_positions = sum(h * w for h, w in self.feature_sizes.values())
        raw_features = torch.randn(batch_size, self._FEATURE_DIM, total_positions)
        # (center_x, center_y, label): the label only selects the box size.
        test_positions = [
            (100, 100, 'P2'),
            (200, 200, 'P2'),
            (300, 300, 'P3'),
            (400, 400, 'P4'),
            (500, 500, 'P5'),
        ]
        box_sizes = {'P2': 20, 'P3': 30, 'P4': 50, 'P5': 80}
        detections_list = []
        for _ in range(batch_size):
            batch_detections = []
            for center_x, center_y, target_scale in test_positions:
                half = box_sizes[target_scale] // 2
                detection = torch.tensor([
                    center_x - half, center_y - half,   # x1, y1
                    center_x + half, center_y + half,   # x2, y2
                    0.8 + 0.1 * torch.rand(1).item(),   # confidence (0.8-0.9)
                    torch.randint(0, 52, (1,)).item()   # class_id (0-51)
                ], dtype=torch.float32)
                batch_detections.append(detection)
                print(f" Synthetic detection: center=({center_x},{center_y}), scale={target_scale}")
            detections_list.append(torch.stack(batch_detections))
        # Pad to same number of detections per batch (0 for an empty batch)
        max_detections = max((len(dets) for dets in detections_list), default=0)
        detection_bboxes = torch.zeros(batch_size, max_detections, 6)
        for batch_idx, dets in enumerate(detections_list):
            detection_bboxes[batch_idx, :len(dets)] = dets
        print(f"Created synthetic data: raw_features={raw_features.shape}, detections={detection_bboxes.shape}")
        return raw_features, detection_bboxes

    def extract_features_from_bbox(
        self,
        raw_116_features: Dict[str, torch.Tensor],  # Dictionary with scale keys {'P2': tensor, 'P3': tensor, ...}
        bbox: torch.Tensor,  # [4] - Single bbox [x1, y1, x2, y2]
        center: Tuple[float, float],  # [2] - Center point [center_x, center_y]
        confidence: float,  # Confidence score
        class_id: int,  # Class ID
        image_idx: int,  # Image index in batch
        image_size: Tuple[int, int]  # [H, W]
    ) -> "MultiScaleDetectionFeatures":
        """
        Extract multi-scale features for one explicitly given bounding box,
        mainly for ground-truth boxes during training.

        The features are read at the given centre, independent of any YOLO
        detection result.

        Args:
            raw_116_features: either a dict mapping 'P2'..'P5' to per-scale
                maps of shape [B, C, H, W], or a concatenated tensor
                [B, 116, HW].
            bbox: [x1, y1, x2, y2] of the box.
            center: (center_x, center_y) as a tuple/list or a 2-element tensor.
            confidence: score stored with the result.
            class_id: class stored with the result.
            image_idx: index of the image inside the batch.
            image_size: (H, W); currently unused.

        Returns:
            ``MultiScaleDetectionFeatures`` whose ``bbox`` is
            [x1, y1, x2, y2, confidence, class_id].
        """
        is_dict = isinstance(raw_116_features, dict)
        if is_dict:
            if 'P2' in raw_116_features:
                device = raw_116_features['P2'].device
            elif raw_116_features:
                device = next(iter(raw_116_features.values())).device
            else:
                # Empty dict: every scale becomes a zero vector on bbox's device.
                device = bbox.device
        else:
            device = raw_116_features.device
            # Keep the layout in sync with the tensor that is actually passed.
            self._ensure_layout_matches(raw_116_features.shape[-1])

        if isinstance(center, (tuple, list)):
            center_x, center_y = float(center[0]), float(center[1])
        else:
            center_x, center_y = center[0].item(), center[1].item()

        scale_features = {}
        scale_positions = {}
        for scale, stride in self._SCALE_STRIDES.items():
            # Cell of the centre on this scale's feature map
            feat_x = int(center_x / stride)
            feat_y = int(center_y / stride)
            if is_dict:
                if scale not in raw_116_features:
                    # Missing scale: zero vector at a dummy position
                    scale_features[scale] = torch.zeros(self._FEATURE_DIM, device=device)
                    scale_positions[scale] = (0, 0)
                    continue
                scale_feat = raw_116_features[scale]  # [B, C, H, W]
                _, _, feat_h, feat_w = scale_feat.shape
                feat_x = max(0, min(feat_x, feat_w - 1))
                feat_y = max(0, min(feat_y, feat_h - 1))
                feat_vector = scale_feat[image_idx, :, feat_y, feat_x]  # [C]
                channels = feat_vector.shape[0]
                if channels > self._FEATURE_DIM:
                    # Too many channels: truncate
                    feat_vector = feat_vector[:self._FEATURE_DIM]
                elif channels < self._FEATURE_DIM:
                    # Too few channels: tile the vector, then cut to 116
                    repeat_times = (self._FEATURE_DIM + channels - 1) // channels
                    feat_vector = feat_vector.repeat(repeat_times)[:self._FEATURE_DIM]
            else:
                feat_h, feat_w = self.feature_sizes[scale]
                feat_x = max(0, min(feat_x, feat_w - 1))
                feat_y = max(0, min(feat_y, feat_h - 1))
                linear_idx = self.get_feature_index(scale, feat_x, feat_y)
                feat_vector = raw_116_features[image_idx, :, linear_idx]  # [116]
            scale_features[scale] = feat_vector
            scale_positions[scale] = (feat_x, feat_y)

        # Append score and class so bbox has the 6-element detection layout.
        extra = torch.tensor([confidence, float(class_id)], device=device)
        return MultiScaleDetectionFeatures(
            bbox=torch.cat([bbox.to(device), extra]),  # [6]
            center_point=(center_x, center_y),
            p2_features=scale_features['P2'],
            p3_features=scale_features['P3'],
            p4_features=scale_features['P4'],
            p5_features=scale_features['P5'],
            p2_position=scale_positions['P2'],
            p3_position=scale_positions['P3'],
            p4_position=scale_positions['P4'],
            p5_position=scale_positions['P5'],
            confidence=confidence,
            class_id=class_id,
            image_idx=image_idx
        )
def test_multi_scale_extractor():
    """
    Smoke-test the multi-scale feature extractor on synthetic data.

    Builds a 544x544 extractor, runs it on synthetic features/detections,
    prints per-detection shapes and statistics, and compares the stored
    feature-map positions of the first two detections against
    ``int(center / stride)``.

    Returns:
        (extractor, multi_features): the extractor instance and the list of
        extracted per-detection features.
    """
    rule = "=" * 80
    strides = {'P2': 4, 'P3': 8, 'P4': 16, 'P5': 32}

    def banner(title, blank_line_before=True):
        # Section header: rule, title, rule.
        print(("\n" + rule) if blank_line_before else rule)
        print(title)
        print(rule)

    banner("TESTING MULTI-SCALE FEATURE EXTRACTOR", blank_line_before=False)
    extractor = MultiScaleFeatureExtractor(input_size=(544, 544))
    raw_features, detection_bboxes = extractor.create_synthetic_test_data(batch_size=1)

    banner("EXTRACTING MULTI-SCALE FEATURES")
    multi_features = extractor.extract_multi_scale_features(
        raw_116_features=raw_features,
        detection_bboxes=detection_bboxes
    )

    banner("ANALYZING RESULTS")
    print(f"Total detections processed: {len(multi_features)}")
    for number, det in enumerate(multi_features, start=1):
        print(f"\n--- Detection {number} ---")
        box = det.bbox
        print(f"BBox: [{box[0]:.1f}, {box[1]:.1f}, "
              f"{box[2]:.1f}, {box[3]:.1f}]")
        print(f"Center: {det.center_point}")
        print(f"Confidence: {det.confidence:.3f}")
        print(f"Class ID: {det.class_id}")

        # Per-scale feature shape and feature-map position
        print("Feature dimensions:")
        for scale in strides:
            key = scale.lower()
            scale_vec = getattr(det, f"{key}_features")
            scale_pos = getattr(det, f"{key}_position")
            print(f" {scale}: {scale_vec.shape} (pos: {scale_pos})")

        # Statistics over the combined features
        joined = det.concatenated_features
        stacked = det.feature_matrix
        print("Combined features:")
        print(f" Concatenated: {joined.shape} (4×116=464)")
        print(f" Matrix: {stacked.shape} (4×116)")
        print(f" Feature norm: {torch.norm(joined):.4f}")

        # Statistics per scale (rows of the matrix are in P2..P5 order)
        for row, scale in zip(stacked, strides):
            print(f" {scale} norm: {torch.norm(row):.4f}, mean: {torch.mean(row):.4f}")

    banner("VALIDATING POSITION MAPPING")
    for number, det in enumerate(multi_features[:2], start=1):  # first 2 detections only
        center_x, center_y = det.center_point
        print(f"\nDetection {number} position mapping validation:")
        print(f"Image center: ({center_x}, {center_y})")
        for scale, stride in strides.items():
            expected = (int(center_x / stride), int(center_y / stride))
            actual_x, actual_y = getattr(det, f"{scale.lower()}_position")
            verdict = '✓' if (actual_x, actual_y) == expected else '✗'
            print(f" {scale} (stride={stride}):")
            print(f" Expected: ({expected[0]}, {expected[1]})")
            print(f" Actual: ({actual_x}, {actual_y})")
            print(f" Match: {verdict}")

    print("\n✅ Multi-scale feature extraction test completed!")
    print("✅ Each detection now has 4×116 = 464 dimensional multi-scale features")
    return extractor, multi_features
# Script entry point: run the self-test only when this file is executed
# directly; the returned extractor and feature list are bound to module-level names.
if __name__ == "__main__":
    extractor, features = test_multi_scale_extractor()