File size: 9,891 Bytes
2214a66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
"""
COCO Patch-based Evaluation
直接输入超分后的patches路径,评估Object Detection和Instance Segmentation指标

Metrics:
    - Object Detection: AP^b, AP^b_50, AP^b_75
    - Instance Segmentation: AP^m, AP^m_50, AP^m_75
"""

import os
import json
from pathlib import Path
from typing import Dict, List, Tuple

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.models.detection import (
    maskrcnn_resnet50_fpn_v2,
    MaskRCNN_ResNet50_FPN_V2_Weights,
)
from torchvision.transforms import functional as TF
from PIL import Image
import numpy as np
from tqdm import tqdm
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from pycocotools import mask as mask_util


# ============================================================================
# Configuration - edit the values below
# ============================================================================
CONFIG = dict(
    # Directory holding the SR patches (output of the super-resolution
    # model, 512x512). File names must match the ones generated by
    # prepare_coco_patches.py.
    sr_dir='/home/wanghongbo06/baipurui/results/COCO/DreamClear/results/output',

    # Patch annotation file produced by prepare_coco_patches.py.
    ann_file='/home/wanghongbo06/baipurui/DATA/COCO_patch/patch_annotations.json',

    # Inference settings.
    device='cuda',
    batch_size=8,
    num_workers=4,

    # Result file.
    # NOTE(review): main() in this file derives its own output path from
    # sr_dir and does not appear to read this key - confirm before relying
    # on it.
    output='./coco_patch_eval_results.json',
)
# ============================================================================


# COCO category id table: every id in 1..90 except the ten gaps listed below,
# which leaves 80 ids in ascending order.
_COCO_ID_GAPS = frozenset({12, 26, 29, 30, 45, 66, 68, 69, 71, 83})
COCO_CATEGORY_IDS = [
    cat_id for cat_id in range(1, 91) if cat_id not in _COCO_ID_GAPS
]


class PatchDataset(Dataset):
    """Dataset over the patch images listed in a COCO-style annotation file.

    Only images whose file exists under ``image_dir`` are kept. The number
    of missing files is reported once, at construction time.
    """

    def __init__(self, image_dir: str, ann_file: str):
        """Load the annotations and keep the ids whose image file exists.

        Args:
            image_dir: Directory that contains the patch images.
            ann_file: Path to the COCO-format annotation JSON.
        """
        self.image_dir = Path(image_dir)
        self.coco = COCO(ann_file)

        known_ids = list(self.coco.imgs.keys())

        print("Checking image files...")
        # Preserve annotation order; drop ids whose file is absent on disk.
        self.image_ids = [
            img_id
            for img_id in known_ids
            if (self.image_dir / self.coco.imgs[img_id]['file_name']).exists()
        ]

        n_missing = len(known_ids) - len(self.image_ids)
        if n_missing > 0:
            print(f"Warning: {n_missing} images not found, skipped.")
            print(f"Valid images: {len(self.image_ids)} / {len(known_ids)}")

    def __len__(self):
        """Return the number of images that were found on disk."""
        return len(self.image_ids)

    def __getitem__(self, idx: int) -> Dict:
        """Return ``{'image': tensor, 'image_id': id}`` for position ``idx``.

        The image is opened with PIL, converted to RGB and turned into a
        tensor with ``TF.to_tensor``.
        """
        image_id = self.image_ids[idx]
        file_name = self.coco.imgs[image_id]['file_name']
        rgb = Image.open(self.image_dir / file_name).convert('RGB')
        return {
            'image': TF.to_tensor(rgb),
            'image_id': image_id,
        }


def collate_fn(batch: List[Dict]) -> List[Dict]:
    """Return the batch unchanged, as the list of per-sample items.

    Passed to ``DataLoader`` in ``PatchEvaluator.evaluate``, which then reads
    ``item['image']`` and ``item['image_id']`` from each element.
    """
    return batch


class PatchEvaluator:
    """Run Mask R-CNN on patches and score the output with COCOeval.

    The ground truth is loaded from ``ann_file``. Detections are produced by
    torchvision's ``maskrcnn_resnet50_fpn_v2`` with its default weights and
    are evaluated for both bounding boxes and instance masks.
    """

    def __init__(self, ann_file: str, device: str = 'cuda'):
        """Load the ground-truth annotations and the detector.

        Args:
            ann_file: Path to the COCO-format annotation JSON.
            device: Torch device string the detector is moved to.
        """
        self.ann_file = ann_file
        self.device = device
        self.coco_gt = COCO(ann_file)

        print("Loading Mask R-CNN ResNet50 FPN v2...")
        weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
        self.detector = maskrcnn_resnet50_fpn_v2(weights=weights)
        self.detector.eval()
        self.detector.to(device)

        # NOTE (from the original author): the labels emitted by torchvision's
        # Mask R-CNN are already COCO category ids, so no extra mapping is
        # applied anywhere in this class.

    @torch.no_grad()
    def predict(self, images: List[torch.Tensor]) -> List[Dict]:
        """Move ``images`` to the configured device and run the detector.

        Returns:
            The detector output, one dict per input image.
        """
        images = [img.to(self.device) for img in images]
        return self.detector(images)

    def convert_to_coco_format(
        self,
        predictions: List[Dict],
        image_ids: List[int],
        score_thresh: float = 0.05,
        mask_thresh: float = 0.5,
    ) -> Tuple:
        """Convert detector output into COCO result records.

        Args:
            predictions: One dict per image with ``'boxes'`` (x1, y1, x2, y2),
                ``'labels'``, ``'scores'`` and ``'masks'`` tensors.
            image_ids: Image id for each entry of ``predictions``.
            score_thresh: Detections scoring below this value are dropped.
            mask_thresh: Mask values above this are treated as foreground.

        Returns:
            ``(bbox_results, segm_results)``: two lists of dicts ready for
            ``COCO.loadRes``. Boxes are ``[x, y, width, height]`` and masks
            are RLE with ``counts`` decoded to ``str`` so they are JSON
            serialisable.
        """
        bbox_results = []
        segm_results = []

        for pred, image_id in zip(predictions, image_ids):
            # Written as "not below" so the comparison matches the original
            # `score < threshold -> skip` rule exactly.
            keep = ~(pred['scores'] < score_thresh)
            if not bool(keep.any()):
                continue

            scores = pred['scores'][keep].cpu().numpy()
            labels = pred['labels'][keep].cpu().numpy()
            boxes = pred['boxes'][keep].cpu().numpy()
            # Binarise where the tensor lives so that only uint8 masks of the
            # kept detections are copied to host memory.
            masks = (
                (pred['masks'][keep][:, 0] > mask_thresh)
                .to(torch.uint8)
                .cpu()
                .numpy()
            )

            for box, label, score, mask in zip(boxes, labels, scores, masks):
                x1, y1, x2, y2 = box
                category_id = int(label)  # used as-is, see note in __init__
                score_value = float(score)

                bbox_results.append({
                    'image_id': int(image_id),
                    'category_id': category_id,
                    # COCO boxes are [x, y, width, height].
                    'bbox': [float(x1), float(y1), float(x2 - x1), float(y2 - y1)],
                    'score': score_value,
                })

                rle = mask_util.encode(np.asfortranarray(mask))
                rle['counts'] = rle['counts'].decode('utf-8')

                segm_results.append({
                    'image_id': int(image_id),
                    'category_id': category_id,
                    'segmentation': rle,
                    'score': score_value,
                })

        return bbox_results, segm_results

    def evaluate(self, dataset: Dataset, batch_size: int = 8, num_workers: int = 4) -> Dict:
        """Run inference over ``dataset`` and compute box and mask AP.

        Args:
            dataset: Dataset whose items are dicts with ``'image'`` (tensor)
                and ``'image_id'`` entries.
            batch_size: Number of images per detector call.
            num_workers: Number of ``DataLoader`` worker processes.

        Returns:
            Dict with ``AP_bbox``, ``AP_bbox_50``, ``AP_bbox_75``,
            ``AP_mask``, ``AP_mask_50`` and ``AP_mask_75``.
        """
        dataloader = DataLoader(
            dataset, batch_size=batch_size, shuffle=False,
            num_workers=num_workers, collate_fn=collate_fn,
        )

        all_bbox_results = []
        all_segm_results = []
        # Ids of the images that actually went through the detector. The
        # evaluation is restricted to these so that ground-truth images the
        # dataset skipped are not scored as complete misses.
        evaluated_ids = []

        print("Running inference...")
        for batch in tqdm(dataloader):
            images = [item['image'] for item in batch]
            image_ids = [item['image_id'] for item in batch]
            evaluated_ids.extend(image_ids)

            predictions = self.predict(images)
            bbox_results, segm_results = self.convert_to_coco_format(predictions, image_ids)

            all_bbox_results.extend(bbox_results)
            all_segm_results.extend(segm_results)

        print("\n" + "="*60)
        print("Object Detection (Bounding Box) Results:")
        print("="*60)
        bbox_metrics = self._run_eval(all_bbox_results, 'bbox', img_ids=evaluated_ids)

        print("\n" + "="*60)
        print("Instance Segmentation (Mask) Results:")
        print("="*60)
        segm_metrics = self._run_eval(all_segm_results, 'segm', img_ids=evaluated_ids)

        return {
            'AP_bbox': bbox_metrics['AP'],
            'AP_bbox_50': bbox_metrics['AP_50'],
            'AP_bbox_75': bbox_metrics['AP_75'],
            'AP_mask': segm_metrics['AP'],
            'AP_mask_50': segm_metrics['AP_50'],
            'AP_mask_75': segm_metrics['AP_75'],
        }

    def _run_eval(self, results: List[Dict], iou_type: str, img_ids=None) -> Dict:
        """Run COCOeval for one IoU type and return AP, AP_50 and AP_75.

        Args:
            results: COCO result records (see ``convert_to_coco_format``).
            iou_type: ``'bbox'`` or ``'segm'``.
            img_ids: Optional iterable of image ids to evaluate on. When
                omitted or empty, COCOeval's default (every image in the
                ground truth) is used.

        Returns:
            Dict with ``'AP'``, ``'AP_50'`` and ``'AP_75'`` as plain floats.
            All three are 0.0 when ``results`` is empty or when COCOeval
            raises ``KeyError``.
        """
        if not results:
            print(f"No predictions for {iou_type}!")
            return {'AP': 0.0, 'AP_50': 0.0, 'AP_75': 0.0}

        try:
            coco_dt = self.coco_gt.loadRes(results)
            coco_eval = COCOeval(self.coco_gt, coco_dt, iou_type)
            if img_ids:
                coco_eval.params.imgIds = sorted({int(i) for i in img_ids})
            coco_eval.evaluate()
            coco_eval.accumulate()
            coco_eval.summarize()

            # Plain floats keep the result dict trivially JSON-serialisable.
            return {
                'AP': float(coco_eval.stats[0]),
                'AP_50': float(coco_eval.stats[1]),
                'AP_75': float(coco_eval.stats[2]),
            }
        except KeyError as e:
            # Deliberate best-effort: report and fall back to zeros instead
            # of aborting the whole run.
            print(f"Error evaluating {iou_type}: {e}")
            print("Some GT annotations may be missing 'segmentation' field.")
            print(f"Skipping {iou_type} evaluation.")
            return {'AP': 0.0, 'AP_50': 0.0, 'AP_75': 0.0}


def main():
    """Evaluate the patches configured in ``CONFIG`` and save the metrics.

    Falls back to the CPU when CUDA is requested but unavailable. The result
    file is named after the last path component of ``CONFIG['sr_dir']`` and
    written to the current working directory.
    """
    device = CONFIG['device']
    if device == 'cuda' and not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        device = 'cpu'

    # The output file name is derived from the last directory name of sr_dir
    # (e.g. 'sr', 'gt', 'bicubic').
    sr_dir = CONFIG['sr_dir']
    run_name = Path(sr_dir).name
    output_path = Path(f"./coco_patch_eval_results_{run_name}.json")

    print(f"Loading SR patches from: {sr_dir}")
    ann_file = CONFIG['ann_file']
    print(f"Loading annotations from: {ann_file}")
    print(f"Output will be saved to: {output_path}")

    dataset = PatchDataset(image_dir=sr_dir, ann_file=ann_file)
    print(f"Dataset size: {len(dataset)} patches")

    evaluator = PatchEvaluator(ann_file=ann_file, device=device)

    print("\nStarting evaluation...")
    metrics = evaluator.evaluate(
        dataset=dataset,
        batch_size=CONFIG['batch_size'],
        num_workers=CONFIG['num_workers'],
    )

    rule = "=" * 60
    print("\n" + rule)
    print("EVALUATION SUMMARY")
    print(rule)
    print("\nObject Detection (Bounding Box):")
    print(f"  AP^b        : {metrics['AP_bbox']*100:.2f}")
    print(f"  AP^b_50     : {metrics['AP_bbox_50']*100:.2f}")
    print(f"  AP^b_75     : {metrics['AP_bbox_75']*100:.2f}")
    print("\nInstance Segmentation (Mask):")
    print(f"  AP^m        : {metrics['AP_mask']*100:.2f}")
    print(f"  AP^m_50     : {metrics['AP_mask_50']*100:.2f}")
    print(f"  AP^m_75     : {metrics['AP_mask_75']*100:.2f}")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w') as out_file:
        json.dump(metrics, out_file, indent=2)
    print(f"\nResults saved to {output_path}")


# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()