""" Multi-Scale Feature Extractor for YOLO Detection This module extracts raw 116-dimensional features from all 4 scales (P2, P3, P4, P5) for each detection box, providing the most comprehensive visual representation. Key features: - Extracts features before DFL processing (116-dim vs 56-dim) - Maintains exact correspondence with original spatial layout - Preserves scale ordering (P2 -> P3 -> P4 -> P5) - Handles exact 2D reconstruction without rotation/flipping """ import torch import numpy as np from typing import Dict, List, Tuple, Optional from dataclasses import dataclass @dataclass class MultiScaleDetectionFeatures: """ Comprehensive features for a single detection across all 4 scales. """ # Detection information bbox: torch.Tensor # [x1, y1, x2, y2, score, class_id] center_point: Tuple[float, float] # (center_x, center_y) # Multi-scale features (4 scales × 116 dimensions) p2_features: torch.Tensor # [116] - P2 scale (136×136, stride=4) p3_features: torch.Tensor # [116] - P3 scale (68×68, stride=8) p4_features: torch.Tensor # [116] - P4 scale (34×34, stride=16) p5_features: torch.Tensor # [116] - P5 scale (17×17, stride=32) # Feature positions (for verification/debugging) p2_position: Tuple[int, int] # (x, y) in P2 feature map p3_position: Tuple[int, int] # (x, y) in P3 feature map p4_position: Tuple[int, int] # (x, y) in P4 feature map p5_position: Tuple[int, int] # (x, y) in P5 feature map # Metadata confidence: float class_id: int image_idx: int # 多模态特征 (可选) text_features: Optional[torch.Tensor] = None # [116] - ClinicalBERT文本特征 multimodal_features: Optional[torch.Tensor] = None # [580] - 视觉+文本融合特征 @property def concatenated_features(self) -> torch.Tensor: """Get all scale features concatenated: [4×116] = [464]""" return torch.cat([self.p2_features, self.p3_features, self.p4_features, self.p5_features], dim=0) @property def feature_matrix(self) -> torch.Tensor: """Get features as matrix: [4, 116]""" return torch.stack([self.p2_features, self.p3_features, self.p4_features, self.p5_features], dim=0) class MultiScaleFeatureExtractor: """ Extracts multi-scale raw features (116-dim) for detection boxes. This extractor works with the pre-DFL features to capture the richest visual information across all 4 scales of YOLOv8-p2. """ def __init__(self, input_size: Tuple[int, int] = (544, 544)): self.input_size = input_size self.feature_sizes = self._calculate_feature_sizes() # Verify total matches expected HW = 24565 total_positions = sum(size[0] * size[1] for size in self.feature_sizes.values()) assert total_positions == 24565, f"Total positions {total_positions} != expected 24565" # Calculate offsets for each scale in the concatenated tensor self._calculate_scale_offsets() def _adapt_to_hw(self, hw: int): """ 动态适应不同的HW特征图尺寸 Args: hw: 实际的特征位置总数 """ # 常见的配置映射 - 只保留验证过的配置 common_configs = { 21760: { 'feature_sizes': { 'P2': (128, 128), # 16384 positions 'P3': (64, 64), # 4096 positions 'P4': (32, 32), # 1024 positions 'P5': (16, 16) # 256 positions }, 'scale_offsets': { 'P2': 0, 'P3': 16384, 'P4': 16384 + 4096, 'P5': 16384 + 4096 + 1024 } } } if hw in common_configs: config = common_configs[hw] # 验证配置是否匹配 total_positions = sum(h*w for h, w in config['feature_sizes'].values()) if total_positions == hw: self.feature_sizes = config['feature_sizes'] self.scale_offsets = config['scale_offsets'] # 重新计算scale_ranges self._update_scale_ranges() return # 改进的动态计算配置,确保每个尺度都有合理的位置数 # 使用更保守的比例分配,确保P5至少有一些位置 # 为P5预留最小位置数 min_p5_positions = 64 # 8×8的最小特征图 # P2获取大约75%的位置 p2_positions = hw * 3 // 4 # 为P5预留位置 p5_positions = max(min_p5_positions, hw // 32) # 至少1/32的位置给P5 remaining_for_p3_p4 = hw - p2_positions - p5_positions # P3和P4按3:1的比例分配剩余位置 p3_positions = remaining_for_p3_p4 * 3 // 4 p4_positions = remaining_for_p3_p4 - p3_positions # 调整确保总和等于hw total = p2_positions + p3_positions + p4_positions + p5_positions if total != hw: # 微调P2以匹配总数 p2_positions += (hw - total) assert p2_positions > 0 and p3_positions > 0 and p4_positions > 0 and p5_positions > 0, "每个尺度都需要有正数位置" def positions_to_size(positions): import math side = int(math.sqrt(positions)) return side, side p2_size = positions_to_size(p2_positions) p3_size = positions_to_size(p3_positions) p4_size = positions_to_size(p4_positions) p5_size = positions_to_size(p5_positions) actual_p2 = p2_size[0] * p2_size[1] actual_p3 = p3_size[0] * p3_size[1] actual_p4 = p4_size[0] * p4_size[1] self.feature_sizes = { 'P2': p2_size, 'P3': p3_size, 'P4': p4_size, 'P5': p5_size } self.scale_offsets = { 'P2': 0, 'P3': actual_p2, 'P4': actual_p2 + actual_p3, 'P5': actual_p2 + actual_p3 + actual_p4 } # 重新计算scale_ranges self._update_scale_ranges() print(f"[DYNAMIC] 动态计算配置完成 HW={hw}") def _update_scale_ranges(self): """重新计算scale_ranges以匹配当前配置""" self.scale_ranges = {} scales = ['P2', 'P3', 'P4', 'P5'] for scale in scales: h, w = self.feature_sizes[scale] positions = h * w offset = self.scale_offsets[scale] self.scale_ranges[scale] = (offset, offset + positions) # Update scale ranges def _calculate_feature_sizes(self) -> Dict[str, Tuple[int, int]]: """Calculate feature map sizes for all scales.""" h, w = self.input_size return { 'P2': (h // 4, w // 4), # stride=4 -> 136×136 = 18,496 'P3': (h // 8, w // 8), # stride=8 -> 68×68 = 4,624 'P4': (h // 16, w // 16), # stride=16 -> 34×34 = 1,156 'P5': (h // 32, w // 32), # stride=32 -> 17×17 = 289 } def _calculate_scale_offsets(self): """Calculate offset positions for each scale in concatenated tensor.""" self.scale_offsets = {} self.scale_ranges = {} scales = ['P2', 'P3', 'P4', 'P5'] current_offset = 0 for scale in scales: h, w = self.feature_sizes[scale] positions = h * w self.scale_offsets[scale] = current_offset self.scale_ranges[scale] = (current_offset, current_offset + positions) current_offset += positions # Scale offsets calculated def extract_multi_scale_features( self, raw_116_features: torch.Tensor, # [B, 116, HW] - Raw features before DFL detection_bboxes: torch.Tensor, # [B, N, 6] - Detection boxes [x1,y1,x2,y2,score,cls] image_size: Optional[Tuple[int, int]] = None, confidence_threshold: float = 0.0 # 添加置信度阈值参数,默认为0(不过滤) ) -> List[MultiScaleDetectionFeatures]: """Extract multi-scale 116-dimensional features for detection boxes.""" if image_size is None: image_size = self.input_size B, feature_dim, HW = raw_116_features.shape assert feature_dim == 116, f"Expected feature_dim=116, got {feature_dim}" # 动态适应不同的特征图尺寸 if HW != 24565: # 验证当前配置是否匹配HW current_positions = sum(h*w for h, w in self.feature_sizes.values()) if current_positions != HW: print(f"[DYNAMIC] 适配特征尺寸: HW={HW}") # 尝试基于常见配置重新初始化 self._adapt_to_hw(HW) img_h, img_w = image_size detections = [] # Extract multi-scale features from raw features # Process each image in the batch for batch_idx in range(B): # Get detections for this image img_detections = detection_bboxes[batch_idx] valid_detections = img_detections[img_detections[:, 4] > confidence_threshold] # 使用传入的置信度阈值 # Skip images with no valid detections # Process each detection for det_idx, detection in enumerate(valid_detections): bbox = detection # [x1, y1, x2, y2, score, class_id] confidence = float(detection[4]) class_id = int(detection[5]) # Calculate center point center_x = (bbox[0] + bbox[2]) / 2 center_y = (bbox[1] + bbox[3]) / 2 # 减少详细输出以提高性能 # print(f" Detection {det_idx+1}: center=({center_x:.1f}, {center_y:.1f}), conf={confidence:.3f}") # Extract features for all 4 scales scale_features = {} scale_positions = {} for scale in ['P2', 'P3', 'P4', 'P5']: # Calculate feature map position for this scale feat_pos, feat_coords = self._map_center_to_scale_position( center_x, center_y, scale, img_w, img_h ) # Extract 116-dimensional feature from raw features feature = self._extract_raw_feature_at_position( raw_116_features[batch_idx], feat_pos, scale ) scale_features[scale] = feature scale_positions[scale] = feat_coords # Create multi-scale detection features multi_features = MultiScaleDetectionFeatures( bbox=bbox, center_point=(float(center_x), float(center_y)), p2_features=scale_features['P2'], p3_features=scale_features['P3'], p4_features=scale_features['P4'], p5_features=scale_features['P5'], p2_position=scale_positions['P2'], p3_position=scale_positions['P3'], p4_position=scale_positions['P4'], p5_position=scale_positions['P5'], confidence=confidence, class_id=class_id, image_idx=batch_idx ) detections.append(multi_features) return detections def _map_center_to_scale_position( self, center_x: float, center_y: float, scale: str, img_w: int, img_h: int ) -> Tuple[int, Tuple[int, int]]: """ Map image center coordinates to feature map position for a specific scale. Args: center_x, center_y: Center coordinates in original image scale: Target scale ('P2', 'P3', 'P4', 'P5') img_w, img_h: Image dimensions Returns: (flat_position, (feat_x, feat_y)) in feature map """ # Get stride for this scale stride = {'P2': 4, 'P3': 8, 'P4': 16, 'P5': 32}[scale] # Get feature map dimensions feat_h, feat_w = self.feature_sizes[scale] # Map image coordinates to feature coordinates # This maintains exact spatial correspondence without rotation/flipping feat_x = int(center_x / stride) feat_y = int(center_y / stride) # Clamp to valid range feat_x = max(0, min(feat_x, feat_w - 1)) feat_y = max(0, min(feat_y, feat_h - 1)) # Convert to flat position in this scale's feature map flat_position = feat_y * feat_w + feat_x # Convert to flat position in concatenated tensor concat_position = self.scale_offsets[scale] + flat_position return concat_position, (feat_x, feat_y) def _extract_raw_feature_at_position( self, batch_features: torch.Tensor, # [116, 24565] for one batch position: int, scale: str ) -> torch.Tensor: """ Extract 116-dimensional raw feature at a specific position. Args: batch_features: Features for one image [116, 24565] position: Flat position in concatenated tensor scale: Scale name (for verification) Returns: Feature tensor [116] """ # Verify position is within expected range for this scale start, end = self.scale_ranges[scale] assert start <= position < end, f"Position {position} outside {scale} range [{start}:{end})" # Additional safety check: ensure position is within tensor bounds tensor_size = batch_features.shape[1] if position >= tensor_size: raise AssertionError(f"Position {position} exceeds tensor size {tensor_size} for scale {scale}") # Extract the 116-dimensional feature feature = batch_features[:, position] # [116] return feature def create_synthetic_test_data(self, batch_size: int = 1) -> Tuple[torch.Tensor, torch.Tensor]: """ Create synthetic test data for validation. Returns: (raw_features, detection_bboxes) raw_features: [B, 116, 24565] detection_bboxes: [B, N, 6] """ print("Creating synthetic test data...") # Create random raw features (116-dimensional) raw_features = torch.randn(batch_size, 116, 24565) # Create synthetic detections at known positions detections_list = [] for batch_idx in range(batch_size): batch_detections = [] # Create test detections at strategic positions test_positions = [ (100, 100, 'P2'), # Should map to P2 position (200, 200, 'P2'), # Should map to P2 position (300, 300, 'P3'), # Should map to P3 position (400, 400, 'P4'), # Should map to P4 position (500, 500, 'P5'), # Should map to P5 position ] for center_x, center_y, target_scale in test_positions: # Create a detection box around the center box_size = {'P2': 20, 'P3': 30, 'P4': 50, 'P5': 80}[target_scale] x1 = center_x - box_size // 2 y1 = center_y - box_size // 2 x2 = center_x + box_size // 2 y2 = center_y + box_size // 2 detection = torch.tensor([ x1, y1, x2, y2, # bbox coordinates 0.8 + 0.1 * torch.rand(1).item(), # confidence (0.8-0.9) torch.randint(0, 52, (1,)).item() # class_id (0-51) ], dtype=torch.float32) batch_detections.append(detection) print(f" Synthetic detection: center=({center_x},{center_y}), scale={target_scale}") detections_list.append(torch.stack(batch_detections)) # Pad to same number of detections per batch max_detections = max(len(dets) for dets in detections_list) detection_bboxes = torch.zeros(batch_size, max_detections, 6) for batch_idx, dets in enumerate(detections_list): detection_bboxes[batch_idx, :len(dets)] = dets print(f"Created synthetic data: raw_features={raw_features.shape}, detections={detection_bboxes.shape}") return raw_features, detection_bboxes def extract_features_from_bbox( self, raw_116_features: Dict[str, torch.Tensor], # Dictionary with scale keys {'P2': tensor, 'P3': tensor, ...} bbox: torch.Tensor, # [4] - Single bbox [x1, y1, x2, y2] center: Tuple[float, float], # [2] - Center point [center_x, center_y] confidence: float, # Confidence score class_id: int, # Class ID image_idx: int, # Image index in batch image_size: Tuple[int, int] # [H, W] ) -> MultiScaleDetectionFeatures: """ 为单个指定的边界框提取多尺度特征,主要用于训练时的GT检测框特征提取 这个方法允许直接从任意指定的边界框位置提取特征,而不依赖于YOLO检测结果。 在训练阶段,可以使用GT检测框的位置和类别信息来提取完全对齐的特征。 """ # 从字典中获取设备信息 if isinstance(raw_116_features, dict): device = raw_116_features['P2'].device if 'P2' in raw_116_features else next(iter(raw_116_features.values())).device else: device = raw_116_features.device # 处理center坐标 if isinstance(center, (tuple, list)): center_x, center_y = float(center[0]), float(center[1]) else: center_x, center_y = center[0].item(), center[1].item() # 为每个尺度提取特征 scale_features = {} scale_positions = {} if isinstance(raw_116_features, dict): # 处理字典格式的多尺度特征 for scale, stride in [('P2', 4), ('P3', 8), ('P4', 16), ('P5', 32)]: # 计算在当前尺度下的特征图坐标 feat_x = int(center_x / stride) feat_y = int(center_y / stride) if scale in raw_116_features: # 获取对应尺度的特征图 scale_feat = raw_116_features[scale] # [B, C, H, W] # 边界检查,确保坐标在有效范围内 _, _, feat_h, feat_w = scale_feat.shape feat_x = max(0, min(feat_x, feat_w - 1)) feat_y = max(0, min(feat_y, feat_h - 1)) # 提取特征向量 [C] feat_vector = scale_feat[image_idx, :, feat_y, feat_x] # [C] # 标准化为116维(如果需要) if feat_vector.shape[0] != 116: # 如果不是116维,使用线性变换或重复/裁剪 if feat_vector.shape[0] > 116: feat_vector = feat_vector[:116] # 截断 else: # 重复填充到116维 repeat_times = (116 + feat_vector.shape[0] - 1) // feat_vector.shape[0] feat_vector = feat_vector.repeat(repeat_times)[:116] scale_features[scale.lower()] = feat_vector scale_positions[f"{scale.lower()}_position"] = (feat_x, feat_y) else: # 如果该尺度不存在,创建零向量 scale_features[scale.lower()] = torch.zeros(116, device=device) scale_positions[f"{scale.lower()}_position"] = (0, 0) else: # 处理张量格式的原始116维特征(保持原有逻辑) for scale, stride in [('P2', 4), ('P3', 8), ('P4', 16), ('P5', 32)]: # 计算在当前尺度下的特征图坐标 feat_x = int(center_x / stride) feat_y = int(center_y / stride) # 边界检查,确保坐标在有效范围内 h, w = self.feature_sizes[scale] feat_x = max(0, min(feat_x, w - 1)) feat_y = max(0, min(feat_y, h - 1)) # 计算在拼接特征图中的线性索引 linear_idx = self.get_feature_index(scale, feat_x, feat_y) # 提取对应的特征向量 [116] feat_vector = raw_116_features[image_idx, :, linear_idx] # [116] scale_features[scale.lower()] = feat_vector scale_positions[f"{scale.lower()}_position"] = (feat_x, feat_y) # 构建MultiScaleDetectionFeatures对象 result = MultiScaleDetectionFeatures( bbox=torch.cat([bbox, torch.tensor([confidence, float(class_id)], device=device)]), # [6] center_point=(center_x, center_y), p2_features=scale_features['p2'], p3_features=scale_features['p3'], p4_features=scale_features['p4'], p5_features=scale_features['p5'], p2_position=scale_positions['p2_position'], p3_position=scale_positions['p3_position'], p4_position=scale_positions['p4_position'], p5_position=scale_positions['p5_position'], confidence=confidence, class_id=class_id, image_idx=image_idx ) return result def test_multi_scale_extractor(): """Test the multi-scale feature extractor.""" print("=" * 80) print("TESTING MULTI-SCALE FEATURE EXTRACTOR") print("=" * 80) # Initialize extractor extractor = MultiScaleFeatureExtractor(input_size=(544, 544)) # Create synthetic test data raw_features, detection_bboxes = extractor.create_synthetic_test_data(batch_size=1) print("\n" + "=" * 80) print("EXTRACTING MULTI-SCALE FEATURES") print("=" * 80) # Extract features multi_features = extractor.extract_multi_scale_features( raw_116_features=raw_features, detection_bboxes=detection_bboxes ) print(f"\n" + "=" * 80) print("ANALYZING RESULTS") print("=" * 80) print(f"Total detections processed: {len(multi_features)}") for i, det_features in enumerate(multi_features): print(f"\n--- Detection {i+1} ---") print(f"BBox: [{det_features.bbox[0]:.1f}, {det_features.bbox[1]:.1f}, " f"{det_features.bbox[2]:.1f}, {det_features.bbox[3]:.1f}]") print(f"Center: {det_features.center_point}") print(f"Confidence: {det_features.confidence:.3f}") print(f"Class ID: {det_features.class_id}") # Verify feature dimensions print(f"Feature dimensions:") print(f" P2: {det_features.p2_features.shape} (pos: {det_features.p2_position})") print(f" P3: {det_features.p3_features.shape} (pos: {det_features.p3_position})") print(f" P4: {det_features.p4_features.shape} (pos: {det_features.p4_position})") print(f" P5: {det_features.p5_features.shape} (pos: {det_features.p5_position})") # Feature statistics concatenated = det_features.concatenated_features feature_matrix = det_features.feature_matrix print(f"Combined features:") print(f" Concatenated: {concatenated.shape} (4×116=464)") print(f" Matrix: {feature_matrix.shape} (4×116)") print(f" Feature norm: {torch.norm(concatenated):.4f}") # Scale-specific statistics scales = ['P2', 'P3', 'P4', 'P5'] for j, scale in enumerate(scales): scale_feat = feature_matrix[j] print(f" {scale} norm: {torch.norm(scale_feat):.4f}, mean: {torch.mean(scale_feat):.4f}") # Test reconstruction print(f"\n" + "=" * 80) print("VALIDATING POSITION MAPPING") print("=" * 80) for i, det_features in enumerate(multi_features[:2]): # Test first 2 detections center_x, center_y = det_features.center_point print(f"\nDetection {i+1} position mapping validation:") print(f"Image center: ({center_x}, {center_y})") for scale in ['P2', 'P3', 'P4', 'P5']: # Get expected position stride = {'P2': 4, 'P3': 8, 'P4': 16, 'P5': 32}[scale] expected_feat_x = int(center_x / stride) expected_feat_y = int(center_y / stride) # Get actual position if scale == 'P2': actual_pos = det_features.p2_position elif scale == 'P3': actual_pos = det_features.p3_position elif scale == 'P4': actual_pos = det_features.p4_position else: # P5 actual_pos = det_features.p5_position actual_feat_x, actual_feat_y = actual_pos print(f" {scale} (stride={stride}):") print(f" Expected: ({expected_feat_x}, {expected_feat_y})") print(f" Actual: ({actual_feat_x}, {actual_feat_y})") print(f" Match: {'✓' if (actual_feat_x, actual_feat_y) == (expected_feat_x, expected_feat_y) else '✗'}") print(f"\n✅ Multi-scale feature extraction test completed!") print(f"✅ Each detection now has 4×116 = 464 dimensional multi-scale features") return extractor, multi_features if __name__ == "__main__": extractor, features = test_multi_scale_extractor()