| """ |
| Multi-Scale Feature Extractor for YOLO Detection |
| |
| This module extracts raw 116-dimensional features from all 4 scales (P2, P3, P4, P5) |
| for each detection box, providing the most comprehensive visual representation. |
| |
| Key features: |
| - Extracts features before DFL processing (116-dim vs 56-dim) |
| - Maintains exact correspondence with original spatial layout |
| - Preserves scale ordering (P2 -> P3 -> P4 -> P5) |
| - Handles exact 2D reconstruction without rotation/flipping |
| """ |
|
|
| import torch |
| import numpy as np |
| from typing import Dict, List, Tuple, Optional |
| from dataclasses import dataclass |
|
|
|
|
@dataclass
class MultiScaleDetectionFeatures:
    """
    Per-detection record holding one raw feature vector from each of the
    four pyramid levels (P2, P3, P4, P5) plus the detection's metadata.
    """

    # --- detection geometry -------------------------------------------------
    # Box tensor for the detection. The extractor in this file stores the box
    # corners followed by confidence and class id in this field.
    bbox: torch.Tensor
    # (x, y) centre of the box in image coordinates.
    center_point: Tuple[float, float]

    # --- one feature vector per pyramid level, finest to coarsest ------------
    p2_features: torch.Tensor
    p3_features: torch.Tensor
    p4_features: torch.Tensor
    p5_features: torch.Tensor

    # --- (feat_x, feat_y) cell sampled on each level's feature map -----------
    p2_position: Tuple[int, int]
    p3_position: Tuple[int, int]
    p4_position: Tuple[int, int]
    p5_position: Tuple[int, int]

    # --- detection metadata ---------------------------------------------------
    confidence: float
    class_id: int
    image_idx: int

    # --- optional extras; left as None by the extractor in this file ---------
    text_features: Optional[torch.Tensor] = None
    multimodal_features: Optional[torch.Tensor] = None

    def _scale_vectors(self) -> Tuple[torch.Tensor, ...]:
        """Return the per-level vectors in fixed P2 -> P3 -> P4 -> P5 order."""
        return (
            self.p2_features,
            self.p3_features,
            self.p4_features,
            self.p5_features,
        )

    @property
    def concatenated_features(self) -> torch.Tensor:
        """All four level vectors joined end to end: 4 x 116 -> [464]."""
        per_level = list(self._scale_vectors())
        return torch.cat(per_level, dim=0)

    @property
    def feature_matrix(self) -> torch.Tensor:
        """All four level vectors stacked as rows: [4, 116]."""
        per_level = list(self._scale_vectors())
        return torch.stack(per_level, dim=0)
|
|
|
|
class MultiScaleFeatureExtractor:
    """
    Extract multi-scale raw features (116-dim) for detection boxes.

    The extractor works on pre-DFL features of a YOLOv8-p2 style head. The
    four pyramid levels are expected to be flattened row-major and
    concatenated along the last axis in the order P2 -> P3 -> P4 -> P5. For
    every detection the feature vector at the cell containing the box centre
    is read from each level.

    Attributes:
        input_size: (height, width) used to derive the per-level grid sizes.
        feature_sizes: level name -> (rows, cols) of that level's grid.
        scale_offsets: level name -> start index in the concatenated axis.
        scale_ranges: level name -> half-open (start, end) index range.
    """

    # Pyramid levels in concatenation order, each with its downsampling stride.
    _SCALE_STRIDES = (('P2', 4), ('P3', 8), ('P4', 16), ('P5', 32))
    # Channel count of one raw (pre-DFL) feature vector.
    _FEATURE_DIM = 116

    def __init__(self, input_size: Tuple[int, int] = (544, 544)):
        """
        Build the index layout for ``input_size``.

        Args:
            input_size: (height, width) of the network input. The default
                (544, 544) yields 24565 concatenated positions.

        Raises:
            AssertionError: if ``input_size`` is so small that some level
                would have an empty feature map.
        """
        self.input_size = input_size
        self.feature_sizes = self._calculate_feature_sizes()

        # Only require a usable layout; a different position count is handled
        # later by adapting to the tensor actually received.
        if any(h <= 0 or w <= 0 for h, w in self.feature_sizes.values()):
            raise AssertionError(
                f"input_size {input_size} is too small: every scale needs a non-empty feature map"
            )

        self._calculate_scale_offsets()

    def _adapt_to_hw(self, hw: int):
        """
        Re-derive the per-level layout for a concatenated axis of length ``hw``.

        For a square input whose side is a multiple of 32, the level grids have
        sides 8s, 4s, 2s and s, so ``hw == 85 * s * s``. When ``hw`` has that
        form the exact layout is restored. Otherwise a rough proportional split
        is used; that fallback is only a guess and its ranges need not cover
        all ``hw`` positions.

        Args:
            hw: actual number of concatenated feature positions.

        Raises:
            AssertionError: if the fallback split leaves a level without any
                position.
        """
        import math  # function-scope import: the module header does not import math

        coarsest_stride = self._SCALE_STRIDES[-1][1]
        # Grid side of each level relative to P5: 8, 4, 2, 1.
        ratios = {name: coarsest_stride // stride for name, stride in self._SCALE_STRIDES}
        cells_per_p5_cell = sum(r * r for r in ratios.values())  # 64 + 16 + 4 + 1 = 85

        # Exact path: hw = 85 * side^2 for an integer P5 side.
        if hw > 0 and hw % cells_per_p5_cell == 0:
            p5_cells = hw // cells_per_p5_cell
            side = int(round(math.sqrt(p5_cells)))
            if side * side == p5_cells:
                self.feature_sizes = {
                    name: (side * r, side * r) for name, r in ratios.items()
                }
                self._calculate_scale_offsets()
                return

        # Fallback heuristic: P2 takes ~3/4 of all positions, P5 takes hw/32
        # (at least 64), and P3/P4 split the remainder 3:1.
        min_p5_positions = 64
        p2_positions = hw * 3 // 4
        p5_positions = max(min_p5_positions, hw // 32)
        remaining_for_p3_p4 = hw - p2_positions - p5_positions
        p3_positions = remaining_for_p3_p4 * 3 // 4
        p4_positions = remaining_for_p3_p4 - p3_positions

        budgets = {
            'P2': p2_positions,
            'P3': p3_positions,
            'P4': p4_positions,
            'P5': p5_positions,
        }
        if not all(count > 0 for count in budgets.values()):
            raise AssertionError("每个尺度都需要有正数位置")

        # Each level becomes the largest square grid that fits its budget.
        self.feature_sizes = {}
        for name, count in budgets.items():
            grid_side = int(math.sqrt(count))
            self.feature_sizes[name] = (grid_side, grid_side)

        # Offsets are cumulative over the grids actually chosen.
        self._calculate_scale_offsets()

        print(f"[DYNAMIC] 动态计算配置完成 HW={hw}")

    def _update_scale_ranges(self):
        """Recompute ``scale_ranges`` from the current sizes and offsets."""
        self.scale_ranges = {}
        for scale, _ in self._SCALE_STRIDES:
            h, w = self.feature_sizes[scale]
            offset = self.scale_offsets[scale]
            self.scale_ranges[scale] = (offset, offset + h * w)

    def _calculate_feature_sizes(self) -> Dict[str, Tuple[int, int]]:
        """Return level name -> (rows, cols), i.e. ``input_size`` floor-divided by each stride."""
        h, w = self.input_size
        return {scale: (h // stride, w // stride) for scale, stride in self._SCALE_STRIDES}

    def _calculate_scale_offsets(self):
        """Set ``scale_offsets`` and ``scale_ranges`` as running sums over the level sizes."""
        self.scale_offsets = {}
        self.scale_ranges = {}

        current_offset = 0
        for scale, _ in self._SCALE_STRIDES:
            h, w = self.feature_sizes[scale]
            positions = h * w
            self.scale_offsets[scale] = current_offset
            self.scale_ranges[scale] = (current_offset, current_offset + positions)
            current_offset += positions

    def _ensure_layout(self, hw: int) -> None:
        """Adapt the layout whenever it does not describe exactly ``hw`` positions."""
        current_positions = sum(h * w for h, w in self.feature_sizes.values())
        if current_positions != hw:
            print(f"[DYNAMIC] 适配特征尺寸: HW={hw}")
            self._adapt_to_hw(hw)

    def get_feature_index(self, scale: str, feat_x: int, feat_y: int) -> int:
        """
        Return the index in the concatenated axis of cell (feat_x, feat_y).

        Args:
            scale: level name ('P2', 'P3', 'P4' or 'P5').
            feat_x: column in that level's grid.
            feat_y: row in that level's grid.

        The coordinates are not clamped here; callers clamp them first.
        """
        _, feat_w = self.feature_sizes[scale]
        return self.scale_offsets[scale] + feat_y * feat_w + feat_x

    def extract_multi_scale_features(
        self,
        raw_116_features: torch.Tensor,
        detection_bboxes: torch.Tensor,
        image_size: Optional[Tuple[int, int]] = None,
        confidence_threshold: float = 0.0
    ) -> "List[MultiScaleDetectionFeatures]":
        """
        Extract multi-scale 116-dimensional features for detection boxes.

        Args:
            raw_116_features: tensor of shape [B, 116, HW].
            detection_bboxes: per-image detection rows, indexed as
                [batch][row]; columns 0-3 are the box corners (x1, y1, x2, y2),
                column 4 the confidence and column 5 the class id.
            image_size: (height, width); defaults to ``self.input_size``. It is
                forwarded to the position mapping, which does not use it.
            confidence_threshold: rows with confidence <= this value are
                skipped (so all-zero padding rows are dropped by default).

        Returns:
            One MultiScaleDetectionFeatures per kept detection, in batch order.

        Raises:
            AssertionError: if the channel dimension is not 116 or a mapped
                position falls outside its level's range.
        """
        if image_size is None:
            image_size = self.input_size

        B, feature_dim, HW = raw_116_features.shape
        if feature_dim != self._FEATURE_DIM:
            raise AssertionError(f"Expected feature_dim=116, got {feature_dim}")

        # Always compare against the current layout: an earlier call may have
        # adapted it to a different HW.
        self._ensure_layout(HW)

        img_h, img_w = image_size
        detections = []

        for batch_idx in range(B):
            img_detections = detection_bboxes[batch_idx]
            valid_detections = img_detections[img_detections[:, 4] > confidence_threshold]

            for detection in valid_detections:
                center_x = float((detection[0] + detection[2]) / 2)
                center_y = float((detection[1] + detection[3]) / 2)

                scale_features, scale_positions = self._gather_scale_features(
                    raw_116_features[batch_idx], center_x, center_y, img_w, img_h
                )

                detections.append(MultiScaleDetectionFeatures(
                    bbox=detection,  # the full detection row is stored
                    center_point=(center_x, center_y),
                    p2_features=scale_features['P2'],
                    p3_features=scale_features['P3'],
                    p4_features=scale_features['P4'],
                    p5_features=scale_features['P5'],
                    p2_position=scale_positions['P2'],
                    p3_position=scale_positions['P3'],
                    p4_position=scale_positions['P4'],
                    p5_position=scale_positions['P5'],
                    confidence=float(detection[4]),
                    class_id=int(detection[5]),
                    image_idx=batch_idx
                ))

        return detections

    def _gather_scale_features(
        self,
        batch_features: torch.Tensor,
        center_x: float,
        center_y: float,
        img_w: int,
        img_h: int
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Tuple[int, int]]]:
        """Read the feature vector and grid cell under a centre point on every level."""
        scale_features = {}
        scale_positions = {}
        for scale, _ in self._SCALE_STRIDES:
            feat_pos, feat_coords = self._map_center_to_scale_position(
                center_x, center_y, scale, img_w, img_h
            )
            scale_features[scale] = self._extract_raw_feature_at_position(
                batch_features, feat_pos, scale
            )
            scale_positions[scale] = feat_coords
        return scale_features, scale_positions

    def _map_center_to_scale_position(
        self,
        center_x: float,
        center_y: float,
        scale: str,
        img_w: int,
        img_h: int
    ) -> Tuple[int, Tuple[int, int]]:
        """
        Map a centre point to a cell of one level's feature map.

        Args:
            center_x, center_y: centre coordinates; they are divided by the
                level's stride directly, with no rescaling by the image size.
            scale: target level ('P2', 'P3', 'P4', 'P5').
            img_w, img_h: accepted for interface compatibility; unused.

        Returns:
            (index in the concatenated axis, (feat_x, feat_y)).
        """
        stride = dict(self._SCALE_STRIDES)[scale]
        feat_h, feat_w = self.feature_sizes[scale]

        # Truncate to a cell index, then clamp into the grid.
        feat_x = max(0, min(int(center_x / stride), feat_w - 1))
        feat_y = max(0, min(int(center_y / stride), feat_h - 1))

        return self.get_feature_index(scale, feat_x, feat_y), (feat_x, feat_y)

    def _extract_raw_feature_at_position(
        self,
        batch_features: torch.Tensor,
        position: int,
        scale: str
    ) -> torch.Tensor:
        """
        Extract the raw feature vector at one concatenated position.

        Args:
            batch_features: features of one image, indexed [channel, position].
            position: index in the concatenated axis.
            scale: level the position is expected to belong to.

        Returns:
            ``batch_features[:, position]``.

        Raises:
            AssertionError: if ``position`` is outside the level's range or
                beyond the tensor's last axis.
        """
        start, end = self.scale_ranges[scale]
        if not (start <= position < end):
            raise AssertionError(f"Position {position} outside {scale} range [{start}:{end})")

        tensor_size = batch_features.shape[1]
        if position >= tensor_size:
            raise AssertionError(f"Position {position} exceeds tensor size {tensor_size} for scale {scale}")

        return batch_features[:, position]

    def create_synthetic_test_data(self, batch_size: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Create synthetic test data for validation.

        Returns:
            (raw_features, detection_bboxes)
            raw_features: random tensor [B, 116, HW], where HW is the position
                count of the current layout (24565 for the default input size).
            detection_bboxes: [B, 5, 6] rows of x1, y1, x2, y2, confidence,
                class id for five fixed box centres.
        """
        print("Creating synthetic test data...")

        total_positions = sum(h * w for h, w in self.feature_sizes.values())
        raw_features = torch.randn(batch_size, self._FEATURE_DIM, total_positions)

        # (center_x, center_y, level whose box size is used)
        test_positions = [
            (100, 100, 'P2'),
            (200, 200, 'P2'),
            (300, 300, 'P3'),
            (400, 400, 'P4'),
            (500, 500, 'P5'),
        ]
        box_sizes = {'P2': 20, 'P3': 30, 'P4': 50, 'P5': 80}

        detections_list = []
        for _ in range(batch_size):
            batch_detections = []
            for center_x, center_y, target_scale in test_positions:
                half = box_sizes[target_scale] // 2
                batch_detections.append(torch.tensor([
                    center_x - half, center_y - half, center_x + half, center_y + half,
                    0.8 + 0.1 * torch.rand(1).item(),   # confidence in [0.8, 0.9)
                    torch.randint(0, 52, (1,)).item()   # class id in [0, 52)
                ], dtype=torch.float32))

                print(f" Synthetic detection: center=({center_x},{center_y}), scale={target_scale}")

            detections_list.append(torch.stack(batch_detections))

        # Every image has the same number of rows, so no padding is needed.
        if detections_list:
            detection_bboxes = torch.stack(detections_list)
        else:
            detection_bboxes = torch.zeros(0, len(test_positions), 6)

        print(f"Created synthetic data: raw_features={raw_features.shape}, detections={detection_bboxes.shape}")

        return raw_features, detection_bboxes

    def _fit_to_feature_dim(self, feat_vector: torch.Tensor) -> torch.Tensor:
        """Trim, or tile and trim, a vector so that it has exactly 116 entries."""
        length = feat_vector.shape[0]
        if length == self._FEATURE_DIM:
            return feat_vector
        if length > self._FEATURE_DIM:
            return feat_vector[:self._FEATURE_DIM]
        repeat_times = (self._FEATURE_DIM + length - 1) // length
        return feat_vector.repeat(repeat_times)[:self._FEATURE_DIM]

    def extract_features_from_bbox(
        self,
        raw_116_features: Dict[str, torch.Tensor],
        bbox: torch.Tensor,
        center: Tuple[float, float],
        confidence: float,
        class_id: int,
        image_idx: int,
        image_size: Tuple[int, int]
    ) -> "MultiScaleDetectionFeatures":
        """
        Extract multi-scale features for one explicitly given box.

        Intended mainly for ground-truth boxes during training: features are
        read at the supplied centre instead of relying on detector output.

        Args:
            raw_116_features: either a dict mapping level names
                ('P2'..'P5') to maps indexed [image, channel, row, col], or a
                single concatenated tensor indexed [image, channel, position].
            bbox: 1-D box tensor; confidence and class id are appended to it
                in the returned record.
            center: (x, y) as a tuple/list of numbers, or an indexable of
                tensors supporting ``.item()``.
            confidence: stored on the result.
            class_id: stored on the result.
            image_idx: index of the image within the batch.
            image_size: accepted for interface compatibility; unused.

        Returns:
            A MultiScaleDetectionFeatures for the box. In dict mode a missing
            level yields a zero vector with position (0, 0), and vectors whose
            length is not 116 are trimmed or tiled to 116.
        """
        per_scale_maps = isinstance(raw_116_features, dict)

        if per_scale_maps:
            if 'P2' in raw_116_features:
                device = raw_116_features['P2'].device
            else:
                device = next(iter(raw_116_features.values())).device
        else:
            device = raw_116_features.device
            # Make sure the index layout matches this tensor before indexing.
            self._ensure_layout(raw_116_features.shape[-1])

        if isinstance(center, (tuple, list)):
            center_x, center_y = float(center[0]), float(center[1])
        else:
            center_x, center_y = center[0].item(), center[1].item()

        scale_features = {}
        scale_positions = {}

        for scale, stride in self._SCALE_STRIDES:
            feat_x = int(center_x / stride)
            feat_y = int(center_y / stride)

            if per_scale_maps:
                if scale not in raw_116_features:
                    scale_features[scale] = torch.zeros(self._FEATURE_DIM, device=device)
                    scale_positions[scale] = (0, 0)
                    continue

                scale_feat = raw_116_features[scale]
                _, _, feat_h, feat_w = scale_feat.shape
                feat_x = max(0, min(feat_x, feat_w - 1))
                feat_y = max(0, min(feat_y, feat_h - 1))
                feat_vector = self._fit_to_feature_dim(
                    scale_feat[image_idx, :, feat_y, feat_x]
                )
            else:
                h, w = self.feature_sizes[scale]
                feat_x = max(0, min(feat_x, w - 1))
                feat_y = max(0, min(feat_y, h - 1))
                linear_idx = self.get_feature_index(scale, feat_x, feat_y)
                feat_vector = raw_116_features[image_idx, :, linear_idx]

            scale_features[scale] = feat_vector
            scale_positions[scale] = (feat_x, feat_y)

        return MultiScaleDetectionFeatures(
            bbox=torch.cat([bbox, torch.tensor([confidence, float(class_id)], device=device)]),
            center_point=(center_x, center_y),
            p2_features=scale_features['P2'],
            p3_features=scale_features['P3'],
            p4_features=scale_features['P4'],
            p5_features=scale_features['P5'],
            p2_position=scale_positions['P2'],
            p3_position=scale_positions['P3'],
            p4_position=scale_positions['P4'],
            p5_position=scale_positions['P5'],
            confidence=confidence,
            class_id=class_id,
            image_idx=image_idx
        )
|
|
|
|
def test_multi_scale_extractor():
    """
    Smoke-test the extractor on synthetic data and print a report.

    Builds a 544x544 extractor, generates synthetic features and detections,
    runs the extraction, prints per-detection statistics, and compares the
    stored grid positions of the first two detections with a plain
    ``int(center / stride)`` computation.

    Returns:
        (extractor, list of extracted per-detection features).
    """
    rule = "=" * 80
    level_strides = {'P2': 4, 'P3': 8, 'P4': 16, 'P5': 32}
    level_names = list(level_strides)

    def banner(title, lead=""):
        # Three-line section header; `lead` is prepended to the first rule.
        print(lead + rule)
        print(title)
        print(rule)

    banner("TESTING MULTI-SCALE FEATURE EXTRACTOR")

    extractor = MultiScaleFeatureExtractor(input_size=(544, 544))
    raw_features, detection_bboxes = extractor.create_synthetic_test_data(batch_size=1)

    banner("EXTRACTING MULTI-SCALE FEATURES", lead="\n")

    multi_features = extractor.extract_multi_scale_features(
        raw_116_features=raw_features,
        detection_bboxes=detection_bboxes
    )

    banner("ANALYZING RESULTS", lead="\n")

    print(f"Total detections processed: {len(multi_features)}")

    for number, det in enumerate(multi_features, start=1):
        box = det.bbox
        print(f"\n--- Detection {number} ---")
        print(f"BBox: [{box[0]:.1f}, {box[1]:.1f}, "
              f"{box[2]:.1f}, {box[3]:.1f}]")
        print(f"Center: {det.center_point}")
        print(f"Confidence: {det.confidence:.3f}")
        print(f"Class ID: {det.class_id}")

        # Shape and sampled grid cell of each level's vector.
        print(f"Feature dimensions:")
        print(f" P2: {det.p2_features.shape} (pos: {det.p2_position})")
        print(f" P3: {det.p3_features.shape} (pos: {det.p3_position})")
        print(f" P4: {det.p4_features.shape} (pos: {det.p4_position})")
        print(f" P5: {det.p5_features.shape} (pos: {det.p5_position})")

        flat = det.concatenated_features
        stacked = det.feature_matrix

        print(f"Combined features:")
        print(f" Concatenated: {flat.shape} (4×116=464)")
        print(f" Matrix: {stacked.shape} (4×116)")
        print(f" Feature norm: {torch.norm(flat):.4f}")

        # Rows of the matrix follow the P2..P5 order.
        for level, row in zip(level_names, stacked):
            print(f" {level} norm: {torch.norm(row):.4f}, mean: {torch.mean(row):.4f}")

    banner("VALIDATING POSITION MAPPING", lead="\n")

    for number, det in enumerate(multi_features[:2], start=1):
        cx, cy = det.center_point

        print(f"\nDetection {number} position mapping validation:")
        print(f"Image center: ({cx}, {cy})")

        for level, stride in level_strides.items():
            expected = (int(cx / stride), int(cy / stride))
            actual = tuple(getattr(det, f"{level.lower()}_position"))
            verdict = '✓' if actual == expected else '✗'

            print(f" {level} (stride={stride}):")
            print(f" Expected: ({expected[0]}, {expected[1]})")
            print(f" Actual: ({actual[0]}, {actual[1]})")
            print(f" Match: {verdict}")

    print(f"\n✅ Multi-scale feature extraction test completed!")
    print(f"✅ Each detection now has 4×116 = 464 dimensional multi-scale features")

    return extractor, multi_features
|
|
|
|
# Script entry point: run the smoke test and keep its results as module-level
# names so they can be inspected afterwards (e.g. with `python -i`).
if __name__ == "__main__":
    extractor, features = test_multi_scale_extractor()