File size: 26,618 Bytes
e168a4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
"""
Multi-Scale Feature Extractor for YOLO Detection

This module extracts raw 116-dimensional features from all 4 scales (P2, P3, P4, P5)
for each detection box, providing the most comprehensive visual representation.

Key features:
- Extracts features before DFL processing (116-dim vs 56-dim)
- Maintains exact correspondence with original spatial layout
- Preserves scale ordering (P2 -> P3 -> P4 -> P5)
- Handles exact 2D reconstruction without rotation/flipping
"""

import torch
import numpy as np
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass


@dataclass
class MultiScaleDetectionFeatures:
    """
    Per-detection record bundling one feature vector from each of the four
    scales (P2, P3, P4, P5) together with the detection's box and metadata.

    Field order is part of the constructor interface and must not change.
    """
    # --- Detection geometry ---
    bbox: torch.Tensor  # [x1, y1, x2, y2, score, class_id]
    center_point: Tuple[float, float]  # (center_x, center_y)

    # --- One feature vector per scale (sizes quoted for a 544x544 input) ---
    p2_features: torch.Tensor  # [116] - P2 scale (136x136, stride=4)
    p3_features: torch.Tensor  # [116] - P3 scale (68x68, stride=8)
    p4_features: torch.Tensor  # [116] - P4 scale (34x34, stride=16)
    p5_features: torch.Tensor  # [116] - P5 scale (17x17, stride=32)

    # --- Where each vector was sampled (kept for verification/debugging) ---
    p2_position: Tuple[int, int]  # (x, y) in the P2 feature map
    p3_position: Tuple[int, int]  # (x, y) in the P3 feature map
    p4_position: Tuple[int, int]  # (x, y) in the P4 feature map
    p5_position: Tuple[int, int]  # (x, y) in the P5 feature map

    # --- Metadata ---
    confidence: float
    class_id: int
    image_idx: int

    # --- Optional multimodal features ---
    text_features: Optional[torch.Tensor] = None  # [116] - ClinicalBERT text features
    multimodal_features: Optional[torch.Tensor] = None  # [580] - fused visual + text features

    def _per_scale(self) -> List[torch.Tensor]:
        """Return the four scale vectors in the fixed order P2, P3, P4, P5."""
        return [self.p2_features, self.p3_features, self.p4_features, self.p5_features]

    @property
    def concatenated_features(self) -> torch.Tensor:
        """All scale vectors joined end to end along dim 0: 4 x 116 -> [464]."""
        return torch.cat(self._per_scale(), dim=0)

    @property
    def feature_matrix(self) -> torch.Tensor:
        """All scale vectors stacked as rows: [4, 116]."""
        return torch.stack(self._per_scale(), dim=0)


class MultiScaleFeatureExtractor:
    """
    Extracts multi-scale raw features (116-dim) for detection boxes.
    
    This extractor works with the pre-DFL features to capture the richest
    visual information across all 4 scales of YOLOv8-p2.
    """
    
    def __init__(self, input_size: Tuple[int, int] = (544, 544)):
        self.input_size = input_size
        self.feature_sizes = self._calculate_feature_sizes()
        
        # Verify total matches expected HW = 24565
        total_positions = sum(size[0] * size[1] for size in self.feature_sizes.values())
        assert total_positions == 24565, f"Total positions {total_positions} != expected 24565"
        
        # Calculate offsets for each scale in the concatenated tensor
        self._calculate_scale_offsets()

    def _adapt_to_hw(self, hw: int):
        """
        动态适应不同的HW特征图尺寸
        
        Args:
            hw: 实际的特征位置总数
        """
        # 常见的配置映射 - 只保留验证过的配置
        common_configs = {
            21760: {
                'feature_sizes': {
                    'P2': (128, 128),  # 16384 positions
                    'P3': (64, 64),    # 4096 positions
                    'P4': (32, 32),    # 1024 positions  
                    'P5': (16, 16)     # 256 positions
                },
                'scale_offsets': {
                    'P2': 0,
                    'P3': 16384,
                    'P4': 16384 + 4096,
                    'P5': 16384 + 4096 + 1024
                }
            }
        }
        
        if hw in common_configs:
            config = common_configs[hw]
            # 验证配置是否匹配
            total_positions = sum(h*w for h, w in config['feature_sizes'].values())
            if total_positions == hw:
                self.feature_sizes = config['feature_sizes']
                self.scale_offsets = config['scale_offsets']
                # 重新计算scale_ranges
                self._update_scale_ranges()
                return
        
        # 改进的动态计算配置,确保每个尺度都有合理的位置数
        # 使用更保守的比例分配,确保P5至少有一些位置
        
        # 为P5预留最小位置数
        min_p5_positions = 64  # 8×8的最小特征图
        
        # P2获取大约75%的位置
        p2_positions = hw * 3 // 4
        
        # 为P5预留位置
        p5_positions = max(min_p5_positions, hw // 32)  # 至少1/32的位置给P5
        
        remaining_for_p3_p4 = hw - p2_positions - p5_positions
        
        # P3和P4按3:1的比例分配剩余位置
        p3_positions = remaining_for_p3_p4 * 3 // 4
        p4_positions = remaining_for_p3_p4 - p3_positions
        
        # 调整确保总和等于hw
        total = p2_positions + p3_positions + p4_positions + p5_positions
        if total != hw:
            # 微调P2以匹配总数
            p2_positions += (hw - total)
        
        assert p2_positions > 0 and p3_positions > 0 and p4_positions > 0 and p5_positions > 0, "每个尺度都需要有正数位置"
        
        def positions_to_size(positions):
            import math
            side = int(math.sqrt(positions))
            return side, side
        
        p2_size = positions_to_size(p2_positions)
        p3_size = positions_to_size(p3_positions) 
        p4_size = positions_to_size(p4_positions)
        p5_size = positions_to_size(p5_positions)
        
        actual_p2 = p2_size[0] * p2_size[1]
        actual_p3 = p3_size[0] * p3_size[1]
        actual_p4 = p4_size[0] * p4_size[1]
        
        self.feature_sizes = {
            'P2': p2_size,
            'P3': p3_size, 
            'P4': p4_size,
            'P5': p5_size
        }
        
        self.scale_offsets = {
            'P2': 0,
            'P3': actual_p2,
            'P4': actual_p2 + actual_p3,
            'P5': actual_p2 + actual_p3 + actual_p4
        }
        
        # 重新计算scale_ranges
        self._update_scale_ranges()
        
        print(f"[DYNAMIC] 动态计算配置完成 HW={hw}")

    def _update_scale_ranges(self):
        """重新计算scale_ranges以匹配当前配置"""
        self.scale_ranges = {}
        scales = ['P2', 'P3', 'P4', 'P5']
        
        for scale in scales:
            h, w = self.feature_sizes[scale]
            positions = h * w
            offset = self.scale_offsets[scale]
            self.scale_ranges[scale] = (offset, offset + positions)
        
        # Update scale ranges
    
    def _calculate_feature_sizes(self) -> Dict[str, Tuple[int, int]]:
        """Calculate feature map sizes for all scales."""
        h, w = self.input_size
        
        return {
            'P2': (h // 4, w // 4),    # stride=4  -> 136×136 = 18,496
            'P3': (h // 8, w // 8),    # stride=8  -> 68×68   =  4,624  
            'P4': (h // 16, w // 16),  # stride=16 -> 34×34   =  1,156
            'P5': (h // 32, w // 32),  # stride=32 -> 17×17   =    289
        }
    
    def _calculate_scale_offsets(self):
        """Calculate offset positions for each scale in concatenated tensor."""
        self.scale_offsets = {}
        self.scale_ranges = {}
        
        scales = ['P2', 'P3', 'P4', 'P5']
        current_offset = 0
        
        for scale in scales:
            h, w = self.feature_sizes[scale]
            positions = h * w
            self.scale_offsets[scale] = current_offset
            self.scale_ranges[scale] = (current_offset, current_offset + positions)
            current_offset += positions
        
        # Scale offsets calculated
    
    def extract_multi_scale_features(
        self,
        raw_116_features: torch.Tensor,  # [B, 116, HW] - Raw features before DFL
        detection_bboxes: torch.Tensor,  # [B, N, 6] - Detection boxes [x1,y1,x2,y2,score,cls]
        image_size: Optional[Tuple[int, int]] = None,
        confidence_threshold: float = 0.0  # 添加置信度阈值参数,默认为0(不过滤)
    ) -> List[MultiScaleDetectionFeatures]:
        """Extract multi-scale 116-dimensional features for detection boxes."""
        if image_size is None:
            image_size = self.input_size
        
        B, feature_dim, HW = raw_116_features.shape
        assert feature_dim == 116, f"Expected feature_dim=116, got {feature_dim}"
        
        # 动态适应不同的特征图尺寸
        if HW != 24565:
            # 验证当前配置是否匹配HW
            current_positions = sum(h*w for h, w in self.feature_sizes.values())
            if current_positions != HW:
                print(f"[DYNAMIC] 适配特征尺寸: HW={HW}")
                # 尝试基于常见配置重新初始化
                self._adapt_to_hw(HW)
        
        img_h, img_w = image_size
        detections = []
        
        # Extract multi-scale features from raw features
        
        # Process each image in the batch
        for batch_idx in range(B):
            # Get detections for this image
            img_detections = detection_bboxes[batch_idx]
            valid_detections = img_detections[img_detections[:, 4] > confidence_threshold]  # 使用传入的置信度阈值
            
            # Skip images with no valid detections
            
            # Process each detection
            for det_idx, detection in enumerate(valid_detections):
                bbox = detection  # [x1, y1, x2, y2, score, class_id]
                confidence = float(detection[4])
                class_id = int(detection[5])
                
                # Calculate center point
                center_x = (bbox[0] + bbox[2]) / 2
                center_y = (bbox[1] + bbox[3]) / 2
                
                # 减少详细输出以提高性能
                # print(f"    Detection {det_idx+1}: center=({center_x:.1f}, {center_y:.1f}), conf={confidence:.3f}")
                
                # Extract features for all 4 scales
                scale_features = {}
                scale_positions = {}
                
                for scale in ['P2', 'P3', 'P4', 'P5']:
                    # Calculate feature map position for this scale
                    feat_pos, feat_coords = self._map_center_to_scale_position(
                        center_x, center_y, scale, img_w, img_h
                    )
                    
                    # Extract 116-dimensional feature from raw features
                    feature = self._extract_raw_feature_at_position(
                        raw_116_features[batch_idx], feat_pos, scale
                    )
                    
                    scale_features[scale] = feature
                    scale_positions[scale] = feat_coords
                
                # Create multi-scale detection features
                multi_features = MultiScaleDetectionFeatures(
                    bbox=bbox,
                    center_point=(float(center_x), float(center_y)),
                    p2_features=scale_features['P2'],
                    p3_features=scale_features['P3'], 
                    p4_features=scale_features['P4'],
                    p5_features=scale_features['P5'],
                    p2_position=scale_positions['P2'],
                    p3_position=scale_positions['P3'],
                    p4_position=scale_positions['P4'],
                    p5_position=scale_positions['P5'],
                    confidence=confidence,
                    class_id=class_id,
                    image_idx=batch_idx
                )
                
                detections.append(multi_features)
        
        return detections
    
    def _map_center_to_scale_position(
        self, 
        center_x: float, 
        center_y: float,
        scale: str,
        img_w: int,
        img_h: int
    ) -> Tuple[int, Tuple[int, int]]:
        """
        Map image center coordinates to feature map position for a specific scale.
        
        Args:
            center_x, center_y: Center coordinates in original image
            scale: Target scale ('P2', 'P3', 'P4', 'P5')
            img_w, img_h: Image dimensions
            
        Returns:
            (flat_position, (feat_x, feat_y)) in feature map
        """
        # Get stride for this scale
        stride = {'P2': 4, 'P3': 8, 'P4': 16, 'P5': 32}[scale]
        
        # Get feature map dimensions
        feat_h, feat_w = self.feature_sizes[scale]
        
        # Map image coordinates to feature coordinates
        # This maintains exact spatial correspondence without rotation/flipping
        feat_x = int(center_x / stride)
        feat_y = int(center_y / stride)
        
        # Clamp to valid range
        feat_x = max(0, min(feat_x, feat_w - 1))
        feat_y = max(0, min(feat_y, feat_h - 1))
        
        # Convert to flat position in this scale's feature map
        flat_position = feat_y * feat_w + feat_x
        
        # Convert to flat position in concatenated tensor
        concat_position = self.scale_offsets[scale] + flat_position
        
        return concat_position, (feat_x, feat_y)
    
    def _extract_raw_feature_at_position(
        self,
        batch_features: torch.Tensor,  # [116, 24565] for one batch
        position: int,
        scale: str
    ) -> torch.Tensor:
        """
        Extract 116-dimensional raw feature at a specific position.
        
        Args:
            batch_features: Features for one image [116, 24565]
            position: Flat position in concatenated tensor
            scale: Scale name (for verification)
            
        Returns:
            Feature tensor [116]
        """
        # Verify position is within expected range for this scale
        start, end = self.scale_ranges[scale]
        assert start <= position < end, f"Position {position} outside {scale} range [{start}:{end})"
        
        # Additional safety check: ensure position is within tensor bounds
        tensor_size = batch_features.shape[1]
        if position >= tensor_size:
            raise AssertionError(f"Position {position} exceeds tensor size {tensor_size} for scale {scale}")
        
        # Extract the 116-dimensional feature
        feature = batch_features[:, position]  # [116]
        
        return feature
    
    def create_synthetic_test_data(self, batch_size: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Create synthetic test data for validation.
        
        Returns:
            (raw_features, detection_bboxes)
            raw_features: [B, 116, 24565]
            detection_bboxes: [B, N, 6]
        """
        print("Creating synthetic test data...")
        
        # Create random raw features (116-dimensional)
        raw_features = torch.randn(batch_size, 116, 24565)
        
        # Create synthetic detections at known positions
        detections_list = []
        
        for batch_idx in range(batch_size):
            batch_detections = []
            
            # Create test detections at strategic positions
            test_positions = [
                (100, 100, 'P2'),  # Should map to P2 position
                (200, 200, 'P2'),  # Should map to P2 position  
                (300, 300, 'P3'),  # Should map to P3 position
                (400, 400, 'P4'),  # Should map to P4 position
                (500, 500, 'P5'),  # Should map to P5 position
            ]
            
            for center_x, center_y, target_scale in test_positions:
                # Create a detection box around the center
                box_size = {'P2': 20, 'P3': 30, 'P4': 50, 'P5': 80}[target_scale]
                
                x1 = center_x - box_size // 2
                y1 = center_y - box_size // 2
                x2 = center_x + box_size // 2
                y2 = center_y + box_size // 2
                
                detection = torch.tensor([
                    x1, y1, x2, y2,  # bbox coordinates
                    0.8 + 0.1 * torch.rand(1).item(),  # confidence (0.8-0.9)
                    torch.randint(0, 52, (1,)).item()   # class_id (0-51)
                ], dtype=torch.float32)
                
                batch_detections.append(detection)
                
                print(f"  Synthetic detection: center=({center_x},{center_y}), scale={target_scale}")
            
            detections_list.append(torch.stack(batch_detections))
        
        # Pad to same number of detections per batch
        max_detections = max(len(dets) for dets in detections_list)
        detection_bboxes = torch.zeros(batch_size, max_detections, 6)
        
        for batch_idx, dets in enumerate(detections_list):
            detection_bboxes[batch_idx, :len(dets)] = dets
        
        print(f"Created synthetic data: raw_features={raw_features.shape}, detections={detection_bboxes.shape}")
        
        return raw_features, detection_bboxes

    def extract_features_from_bbox(
        self,
        raw_116_features: Dict[str, torch.Tensor],  # Dictionary with scale keys {'P2': tensor, 'P3': tensor, ...}
        bbox: torch.Tensor,  # [4] - Single bbox [x1, y1, x2, y2]
        center: Tuple[float, float],  # [2] - Center point [center_x, center_y]
        confidence: float,  # Confidence score
        class_id: int,  # Class ID
        image_idx: int,  # Image index in batch
        image_size: Tuple[int, int]  # [H, W]
    ) -> MultiScaleDetectionFeatures:
        """
        Extract multi-scale features for one explicitly given box, mainly for
        ground-truth boxes during training.

        Unlike ``extract_multi_scale_features``, this does not depend on
        detector output: the caller supplies the box, its centre and its
        class.

        Two input layouts are accepted for ``raw_116_features``:

        * a dict mapping scale names ('P2'..'P5') to [B, C, H, W] maps. A
          missing scale yields a zero vector at position (0, 0); a vector
          with C != 116 is truncated, or tiled and cut, to 116 values;
        * a single concatenated [B, 116, HW] tensor, indexed through
          ``self.feature_sizes`` and ``self.scale_offsets``.

        Args:
            raw_116_features: Per-scale dict or concatenated tensor (above).
            bbox: [4] tensor [x1, y1, x2, y2].
            center: (center_x, center_y) as a tuple/list or a 2-element
                tensor.
            confidence: Score stored with the result.
            class_id: Class stored with the result.
            image_idx: Index of the image inside the batch dimension.
            image_size: (H, W); accepted for interface compatibility and not
                used by the extraction itself.

        Returns:
            A ``MultiScaleDetectionFeatures`` whose ``bbox`` is the 6-element
            [x1, y1, x2, y2, confidence, class_id] tensor.
        """
        is_dict = isinstance(raw_116_features, dict)

        # Device for tensors created here (zero vectors, score/class tail).
        if is_dict:
            if 'P2' in raw_116_features:
                device = raw_116_features['P2'].device
            elif raw_116_features:
                device = next(iter(raw_116_features.values())).device
            else:
                # Empty dict: every scale falls back to zeros, so follow the box.
                device = bbox.device
        else:
            device = raw_116_features.device

        # Normalise the centre to plain Python numbers.
        if isinstance(center, (tuple, list)):
            center_x, center_y = float(center[0]), float(center[1])
        else:
            center_x, center_y = center[0].item(), center[1].item()

        scale_features = {}
        scale_positions = {}

        for scale, stride in [('P2', 4), ('P3', 8), ('P4', 16), ('P5', 32)]:
            # Cell containing the centre at this scale (before clamping).
            feat_x = int(center_x / stride)
            feat_y = int(center_y / stride)

            if is_dict:
                if scale not in raw_116_features:
                    # Scale not provided: use a zero vector as placeholder.
                    scale_features[scale] = torch.zeros(116, device=device)
                    scale_positions[scale] = (0, 0)
                    continue

                scale_feat = raw_116_features[scale]  # [B, C, H, W]

                # Clamp the cell into the map.
                _, _, feat_h, feat_w = scale_feat.shape
                feat_x = max(0, min(feat_x, feat_w - 1))
                feat_y = max(0, min(feat_y, feat_h - 1))

                feat_vector = scale_feat[image_idx, :, feat_y, feat_x]  # [C]

                # Bring the vector to exactly 116 values.
                channels = feat_vector.shape[0]
                if channels > 116:
                    feat_vector = feat_vector[:116]
                elif channels < 116:
                    # Tile the vector until it is long enough, then cut.
                    repeat_times = (116 + channels - 1) // channels
                    feat_vector = feat_vector.repeat(repeat_times)[:116]
            else:
                # Clamp the cell into the configured map for this scale.
                feat_h, feat_w = self.feature_sizes[scale]
                feat_x = max(0, min(feat_x, feat_w - 1))
                feat_y = max(0, min(feat_y, feat_h - 1))

                # Row-major index inside the scale, shifted by the scale's
                # offset - the same layout _map_center_to_scale_position uses.
                linear_idx = self.scale_offsets[scale] + feat_y * feat_w + feat_x

                feat_vector = raw_116_features[image_idx, :, linear_idx]  # [116]

            scale_features[scale] = feat_vector
            scale_positions[scale] = (feat_x, feat_y)

        result = MultiScaleDetectionFeatures(
            bbox=torch.cat([bbox, torch.tensor([confidence, float(class_id)], device=device)]),  # [6]
            center_point=(center_x, center_y),
            p2_features=scale_features['P2'],
            p3_features=scale_features['P3'],
            p4_features=scale_features['P4'],
            p5_features=scale_features['P5'],
            p2_position=scale_positions['P2'],
            p3_position=scale_positions['P3'],
            p4_position=scale_positions['P4'],
            p5_position=scale_positions['P5'],
            confidence=confidence,
            class_id=class_id,
            image_idx=image_idx
        )

        return result


def test_multi_scale_extractor():
    """
    Smoke-test the extractor end to end on synthetic data.

    Builds a default extractor, generates synthetic features and boxes,
    extracts the multi-scale features, prints per-detection statistics and
    finally compares each stored feature-map position with the value
    obtained by dividing the centre by the stride.

    Returns:
        (extractor, multi_features)
    """
    rule = "=" * 80
    scale_names = ['P2', 'P3', 'P4', 'P5']
    stride_of = {'P2': 4, 'P3': 8, 'P4': 16, 'P5': 32}

    def banner(title, leading_newline=True):
        # Section header framed by two rules.
        print("\n" + rule if leading_newline else rule)
        print(title)
        print(rule)

    banner("TESTING MULTI-SCALE FEATURE EXTRACTOR", leading_newline=False)

    extractor = MultiScaleFeatureExtractor(input_size=(544, 544))
    raw_features, detection_bboxes = extractor.create_synthetic_test_data(batch_size=1)

    banner("EXTRACTING MULTI-SCALE FEATURES")

    multi_features = extractor.extract_multi_scale_features(
        raw_116_features=raw_features,
        detection_bboxes=detection_bboxes
    )

    banner("ANALYZING RESULTS")

    print(f"Total detections processed: {len(multi_features)}")

    for number, det in enumerate(multi_features, start=1):
        box = det.bbox
        print(f"\n--- Detection {number} ---")
        print(f"BBox: [{box[0]:.1f}, {box[1]:.1f}, "
              f"{box[2]:.1f}, {box[3]:.1f}]")
        print(f"Center: {det.center_point}")
        print(f"Confidence: {det.confidence:.3f}")
        print(f"Class ID: {det.class_id}")

        # Shape and sampling position of each scale's vector.
        print("Feature dimensions:")
        for scale in scale_names:
            key = scale.lower()
            vector = getattr(det, f"{key}_features")
            where = getattr(det, f"{key}_position")
            print(f"  {scale}: {vector.shape} (pos: {where})")

        # Combined views of the four vectors.
        concatenated = det.concatenated_features
        feature_matrix = det.feature_matrix

        print("Combined features:")
        print(f"  Concatenated: {concatenated.shape} (4×116=464)")
        print(f"  Matrix:       {feature_matrix.shape} (4×116)")
        print(f"  Feature norm: {torch.norm(concatenated):.4f}")

        # Per-scale statistics, one matrix row per scale.
        for scale, row in zip(scale_names, feature_matrix):
            print(f"  {scale} norm: {torch.norm(row):.4f}, mean: {torch.mean(row):.4f}")

    banner("VALIDATING POSITION MAPPING")

    # Only the first two detections are checked.
    for number, det in enumerate(multi_features[:2], start=1):
        center_x, center_y = det.center_point

        print(f"\nDetection {number} position mapping validation:")
        print(f"Image center: ({center_x}, {center_y})")

        for scale in scale_names:
            stride = stride_of[scale]
            expected = (int(center_x / stride), int(center_y / stride))
            actual = tuple(getattr(det, f"{scale.lower()}_position"))
            mark = '✓' if actual == expected else '✗'

            print(f"  {scale} (stride={stride}):")
            print(f"    Expected: ({expected[0]}, {expected[1]})")
            print(f"    Actual:   ({actual[0]}, {actual[1]})")
            print(f"    Match:    {mark}")

    print("\n✅ Multi-scale feature extraction test completed!")
    print("✅ Each detection now has 4×116 = 464 dimensional multi-scale features")

    return extractor, multi_features


# Script entry point: run the synthetic-data self-test when this module is
# executed directly; nothing runs on import.
if __name__ == "__main__":
    extractor, features = test_multi_scale_extractor()