"""COCO patch-based evaluation.

Takes the directory of super-resolved (SR) patches directly and evaluates
object detection and instance segmentation metrics on it with a pretrained
torchvision Mask R-CNN.

Metrics:
    - Object Detection: AP^b, AP^b_50, AP^b_75
    - Instance Segmentation: AP^m, AP^m_50, AP^m_75
"""

import os
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.models.detection import (
    maskrcnn_resnet50_fpn_v2,
    MaskRCNN_ResNet50_FPN_V2_Weights,
)
from torchvision.transforms import functional as TF
from PIL import Image
import numpy as np
from tqdm import tqdm
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from pycocotools import mask as mask_util

# ============================================================================
# Configuration - edit these values
# ============================================================================
CONFIG = {
    # Directory of SR patches (output of your super-resolution model, 512x512).
    # File names must match those generated by prepare_coco_patches.py.
    'sr_dir': '/home/wanghongbo06/baipurui/results/COCO/DreamClear/results/output',

    # Patch annotation file generated by prepare_coco_patches.py.
    'ann_file': '/home/wanghongbo06/baipurui/DATA/COCO_patch/patch_annotations.json',

    # Inference settings
    'device': 'cuda',
    'batch_size': 8,
    'num_workers': 4,

    # Output template. main() appends '_<last directory name of sr_dir>' to
    # the stem, e.g. './coco_patch_eval_results_output.json'.
    'output': './coco_patch_eval_results.json',
}
# ============================================================================

# Detections scoring below this are dropped before evaluation (same value as
# torchvision's default box_score_thresh for Mask R-CNN).
SCORE_THRESHOLD = 0.05
# Probability cutoff used to binarize the soft instance masks.
MASK_THRESHOLD = 0.5

# The 80 (non-contiguous) COCO category ids. Kept for reference only: the
# torchvision COCO-pretrained detector already emits these ids as labels.
COCO_CATEGORY_IDS = [
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21,
    22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
    43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
    62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84,
    85, 86, 87, 88, 89, 90
]


class PatchDataset(Dataset):
    """Dataset of SR patches listed in a COCO-style annotation file.

    Only images whose file exists under ``image_dir`` are kept; their COCO
    image ids are exposed as ``self.image_ids`` so the evaluator can restrict
    scoring to exactly these images.
    """

    def __init__(self, image_dir: str, ann_file: str):
        self.image_dir = Path(image_dir)
        self.coco = COCO(ann_file)

        # Keep only the annotated images whose patch file actually exists.
        all_image_ids = list(self.coco.imgs.keys())
        self.image_ids = []
        missing_count = 0

        print("Checking image files...")
        for img_id in all_image_ids:
            img_info = self.coco.imgs[img_id]
            img_path = self.image_dir / img_info['file_name']
            if img_path.exists():
                self.image_ids.append(img_id)
            else:
                missing_count += 1

        if missing_count > 0:
            print(f"Warning: {missing_count} images not found, skipped.")
        print(f"Valid images: {len(self.image_ids)} / {len(all_image_ids)}")

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx: int) -> Dict:
        """Load one patch as a CHW float tensor in [0, 1] plus its image id."""
        image_id = self.image_ids[idx]
        img_info = self.coco.imgs[image_id]
        img_path = self.image_dir / img_info['file_name']

        # Use a context manager so the underlying file handle is released
        # immediately (matters with many DataLoader workers).
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        image_tensor = TF.to_tensor(image)

        return {
            'image': image_tensor,
            'image_id': image_id,
        }


def collate_fn(batch):
    """Return the batch unchanged as a list of sample dicts.

    Detection models take a list of (possibly differently sized) images, so
    samples are not stacked into a single tensor.
    """
    return batch


class PatchEvaluator:
    """Run Mask R-CNN on patches and score the detections with COCOeval."""

    def __init__(self, ann_file: str, device: str = 'cuda'):
        self.ann_file = ann_file
        self.device = device
        self.coco_gt = COCO(ann_file)

        print("Loading Mask R-CNN ResNet50 FPN v2...")
        weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
        self.detector = maskrcnn_resnet50_fpn_v2(weights=weights)
        self.detector.eval()
        self.detector.to(device)

        # NOTE: the labels output by torchvision's COCO-pretrained Mask R-CNN
        # are already COCO category_ids, so no extra mapping is needed.

    @torch.no_grad()
    def predict(self, images: List[torch.Tensor]) -> List[Dict]:
        """Run the detector on a list of image tensors; one dict per image."""
        images = [img.to(self.device) for img in images]
        return self.detector(images)

    def convert_to_coco_format(
        self, predictions: List[Dict], image_ids: List[int]
    ) -> Tuple[List[Dict], List[Dict]]:
        """Convert detector outputs into COCO result lists.

        Args:
            predictions: Detector outputs with 'boxes' (x1, y1, x2, y2),
                'labels', 'scores' and 'masks' (indexed as [i, 0] below).
            image_ids: COCO image id for each prediction, in the same order.

        Returns:
            ``(bbox_results, segm_results)`` lists suitable for
            ``COCO.loadRes``.
        """
        bbox_results = []
        segm_results = []

        for pred, image_id in zip(predictions, image_ids):
            boxes = pred['boxes'].cpu().numpy()
            labels = pred['labels'].cpu().numpy()
            scores = pred['scores'].cpu().numpy()
            masks = pred['masks'].cpu().numpy()

            for box, label, score, mask in zip(boxes, labels, scores, masks):
                if score < SCORE_THRESHOLD:
                    continue

                x1, y1, x2, y2 = box
                # COCO boxes are [x, y, width, height].
                bbox = [float(x1), float(y1), float(x2 - x1), float(y2 - y1)]
                category_id = int(label)  # already a COCO category_id
                score = float(score)

                bbox_results.append({
                    'image_id': int(image_id),
                    'category_id': category_id,
                    'bbox': bbox,
                    'score': score,
                })

                # Binarize the soft mask (channel 0) and RLE-encode it.
                mask_binary = (mask[0] > MASK_THRESHOLD).astype(np.uint8)
                rle = mask_util.encode(np.asfortranarray(mask_binary))
                # pycocotools returns bytes; decode so results are JSON-safe.
                rle['counts'] = rle['counts'].decode('utf-8')

                segm_results.append({
                    'image_id': int(image_id),
                    'category_id': category_id,
                    'segmentation': rle,
                    'score': score,
                })

        return bbox_results, segm_results

    def evaluate(self, dataset: PatchDataset, batch_size: int = 8,
                 num_workers: int = 4) -> Dict:
        """Run inference over ``dataset`` and return bbox/mask AP metrics.

        Scoring is restricted to ``dataset.image_ids`` so annotated images
        whose patch file is missing do not count as missed detections.
        """
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collate_fn,
        )

        all_bbox_results = []
        all_segm_results = []

        print("Running inference...")
        for batch in tqdm(dataloader):
            images = [item['image'] for item in batch]
            image_ids = [item['image_id'] for item in batch]
            predictions = self.predict(images)
            bbox_results, segm_results = self.convert_to_coco_format(
                predictions, image_ids)
            all_bbox_results.extend(bbox_results)
            all_segm_results.extend(segm_results)

        img_ids = list(dataset.image_ids)

        print("\n" + "="*60)
        print("Object Detection (Bounding Box) Results:")
        print("="*60)
        bbox_metrics = self._run_eval(all_bbox_results, 'bbox', img_ids)

        print("\n" + "="*60)
        print("Instance Segmentation (Mask) Results:")
        print("="*60)
        segm_metrics = self._run_eval(all_segm_results, 'segm', img_ids)

        return {
            'AP_bbox': bbox_metrics['AP'],
            'AP_bbox_50': bbox_metrics['AP_50'],
            'AP_bbox_75': bbox_metrics['AP_75'],
            'AP_mask': segm_metrics['AP'],
            'AP_mask_50': segm_metrics['AP_50'],
            'AP_mask_75': segm_metrics['AP_75'],
        }

    def _run_eval(self, results: List[Dict], iou_type: str,
                  img_ids: Optional[List[int]] = None) -> Dict:
        """Run COCOeval for one IoU type and return AP, AP_50, AP_75.

        Args:
            results: COCO-format detections.
            iou_type: 'bbox' or 'segm'.
            img_ids: If given, restrict evaluation to these image ids;
                otherwise every image in the GT file is scored.

        Returns:
            Dict with plain-float 'AP', 'AP_50', 'AP_75'. All zero if there
            are no predictions or GT lacks the fields needed for ``iou_type``.
        """
        if not results:
            print(f"No predictions for {iou_type}!")
            return {'AP': 0.0, 'AP_50': 0.0, 'AP_75': 0.0}

        try:
            coco_dt = self.coco_gt.loadRes(results)
            coco_eval = COCOeval(self.coco_gt, coco_dt, iou_type)
            if img_ids is not None:
                # Without this, GT images that were never run through the
                # detector count as images with zero detections.
                coco_eval.params.imgIds = sorted(img_ids)
            coco_eval.evaluate()
            coco_eval.accumulate()
            coco_eval.summarize()
        except KeyError as e:
            print(f"Error evaluating {iou_type}: {e}")
            print("Some GT annotations may be missing 'segmentation' field.")
            print(f"Skipping {iou_type} evaluation.")
            return {'AP': 0.0, 'AP_50': 0.0, 'AP_75': 0.0}

        # Cast numpy scalars to plain floats for clean JSON output.
        return {
            'AP': float(coco_eval.stats[0]),
            'AP_50': float(coco_eval.stats[1]),
            'AP_75': float(coco_eval.stats[2]),
        }


def main():
    """Evaluate the patches in CONFIG['sr_dir'] and save metrics as JSON."""
    device = CONFIG['device']
    if device == 'cuda' and not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        device = 'cpu'

    # Derive the output path from CONFIG['output'] by appending the last
    # directory name of sr_dir (e.g. 'sr', 'gt', 'bicubic'), so results for
    # different baselines do not overwrite each other.
    sr_dir = Path(CONFIG['sr_dir'])
    baseline_name = sr_dir.name
    output_template = Path(CONFIG['output'])
    output_path = output_template.with_name(
        f"{output_template.stem}_{baseline_name}{output_template.suffix}"
    )

    print(f"Loading SR patches from: {CONFIG['sr_dir']}")
    print(f"Loading annotations from: {CONFIG['ann_file']}")
    print(f"Output will be saved to: {output_path}")

    dataset = PatchDataset(
        image_dir=CONFIG['sr_dir'],
        ann_file=CONFIG['ann_file'],
    )
    print(f"Dataset size: {len(dataset)} patches")

    evaluator = PatchEvaluator(ann_file=CONFIG['ann_file'], device=device)

    print("\nStarting evaluation...")
    results = evaluator.evaluate(
        dataset=dataset,
        batch_size=CONFIG['batch_size'],
        num_workers=CONFIG['num_workers'],
    )

    print("\n" + "="*60)
    print("EVALUATION SUMMARY")
    print("="*60)
    print("\nObject Detection (Bounding Box):")
    print(f"  AP^b    : {results['AP_bbox']*100:.2f}")
    print(f"  AP^b_50 : {results['AP_bbox_50']*100:.2f}")
    print(f"  AP^b_75 : {results['AP_bbox_75']*100:.2f}")
    print("\nInstance Segmentation (Mask):")
    print(f"  AP^m    : {results['AP_mask']*100:.2f}")
    print(f"  AP^m_50 : {results['AP_mask_50']*100:.2f}")
    print(f"  AP^m_75 : {results['AP_mask_75']*100:.2f}")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {output_path}")


if __name__ == '__main__':
    main()