| """ |
| COCO Patch-based Evaluation |
| Takes the directory of super-resolved (SR) patches as input and evaluates object-detection and instance-segmentation metrics on them. |
| |
| Metrics: |
| - Object Detection: AP^b, AP^b_50, AP^b_75 |
| - Instance Segmentation: AP^m, AP^m_50, AP^m_75 |
| """ |
|
|
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from PIL import Image
from pycocotools import mask as mask_util
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from torch.utils.data import DataLoader, Dataset
from torchvision.models.detection import (
    MaskRCNN_ResNet50_FPN_V2_Weights,
    maskrcnn_resnet50_fpn_v2,
)
from torchvision.transforms import functional as TF
from tqdm import tqdm
|
|
|
|
| |
| |
| |
# Run-time settings for the evaluation script.
CONFIG = {
    # Directory holding the super-resolved patches that are fed to the detector.
    'sr_dir': '/home/wanghongbo06/baipurui/results/COCO/DreamClear/results/output',

    # COCO-format annotation file describing the patches (loaded with pycocotools).
    'ann_file': '/home/wanghongbo06/baipurui/DATA/COCO_patch/patch_annotations.json',

    # Inference settings: torch device plus DataLoader batch size / worker count.
    'device': 'cuda',
    'batch_size': 8,
    'num_workers': 4,

    # Presumably the results JSON path.
    # NOTE(review): the visible code never reads this key -- confirm intended use.
    'output': './coco_patch_eval_results.json',
}
| |
|
|
|
|
| |
# The 80 category ids used by COCO annotations; the id space 1..90 has gaps
# (e.g. 12, 26, 29, 30, 45, 66, 68, 69, 71 and 83 are not present).
# NOTE(review): not referenced anywhere in the visible code.
COCO_CATEGORY_IDS = [
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17,
    18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
    35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49,
    50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
    64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
    82, 84, 85, 86, 87, 88, 89, 90,
]
|
|
|
|
class PatchDataset(Dataset):
    """Dataset of image patches listed in a COCO-format annotation file.

    Only annotation entries whose image file is present under ``image_dir``
    are kept; the rest are counted and reported once at construction time.

    Attributes:
        image_dir: Directory that is joined with each entry's ``file_name``.
        coco: The ``pycocotools`` ``COCO`` index built from ``ann_file``.
        image_ids: Ids of the images that were found on disk, in the order
            they appear in ``coco.imgs``.
    """

    def __init__(self, image_dir: str, ann_file: str):
        """Index ``ann_file`` and keep the ids whose image file exists.

        Args:
            image_dir: Directory containing the patch images.
            ann_file: Path to the COCO-format annotation JSON.
        """
        self.image_dir = Path(image_dir)
        self.coco = COCO(ann_file)

        print("Checking image files...")
        total = len(self.coco.imgs)
        # is_file() (rather than exists()) so a directory that happens to share
        # the name is not mistaken for a readable image.
        self.image_ids = [
            img_id
            for img_id, img_info in self.coco.imgs.items()
            if (self.image_dir / img_info['file_name']).is_file()
        ]

        missing_count = total - len(self.image_ids)
        if missing_count > 0:
            print(f"Warning: {missing_count} images not found, skipped.")
        print(f"Valid images: {len(self.image_ids)} / {total}")

    def __len__(self) -> int:
        """Return the number of images that were found on disk."""
        return len(self.image_ids)

    def __getitem__(self, idx: int) -> Dict:
        """Load one patch.

        Args:
            idx: Position in ``image_ids``.

        Returns:
            A dict with ``'image'`` (the RGB image converted by
            ``torchvision.transforms.functional.to_tensor``) and
            ``'image_id'`` (the COCO image id).
        """
        image_id = self.image_ids[idx]
        img_info = self.coco.imgs[image_id]
        img_path = self.image_dir / img_info['file_name']

        # The context manager releases the file handle deterministically;
        # convert() returns an independent in-memory copy that outlives it.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')
        image_tensor = TF.to_tensor(image)

        return {
            'image': image_tensor,
            'image_id': image_id,
        }
|
|
|
|
def collate_fn(batch: List[Dict]) -> List[Dict]:
    """Return the list of samples unchanged instead of stacking them.

    Used as the DataLoader ``collate_fn`` so each batch stays a plain list of
    per-sample dicts; no tensors are stacked and the same list object is
    handed back.
    """
    return batch
|
|
|
|
class PatchEvaluator:
    """Run Mask R-CNN on patches and score the output with ``COCOeval``.

    Attributes:
        ann_file: Path of the COCO-format ground-truth annotation file.
        device: Torch device string the detector and inputs are moved to.
        coco_gt: ``COCO`` index built from ``ann_file``.
        detector: ``maskrcnn_resnet50_fpn_v2`` with its default weights, in
            eval mode.
    """

    def __init__(self, ann_file: str, device: str = 'cuda'):
        """Load the ground truth and the pretrained detector.

        Args:
            ann_file: Path to the COCO-format annotation JSON.
            device: Torch device to run inference on.
        """
        self.ann_file = ann_file
        self.device = device
        self.coco_gt = COCO(ann_file)

        print("Loading Mask R-CNN ResNet50 FPN v2...")
        weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
        self.detector = maskrcnn_resnet50_fpn_v2(weights=weights)
        self.detector.eval()
        self.detector.to(device)

    @torch.no_grad()
    def predict(self, images: List[torch.Tensor]) -> List[Dict]:
        """Move ``images`` to the configured device and run the detector.

        Returns:
            Whatever the detector returns for the list of images; the rest of
            this class reads the ``'boxes'``, ``'labels'``, ``'scores'`` and
            ``'masks'`` entries of each element.
        """
        images = [img.to(self.device) for img in images]
        return self.detector(images)

    def convert_to_coco_format(
        self,
        predictions: List[Dict],
        image_ids: List[int],
        score_thresh: float = 0.05,
    ) -> Tuple:
        """Turn detector output into COCO result records.

        Args:
            predictions: One dict per image with ``'boxes'`` (x1, y1, x2, y2),
                ``'labels'``, ``'scores'`` and ``'masks'`` tensors.
            image_ids: COCO image id for each entry of ``predictions``.
            score_thresh: Detections scoring below this value are dropped.

        Returns:
            ``(bbox_results, segm_results)``: two parallel lists of dicts in
            the COCO results format (boxes as ``[x, y, w, h]``, masks as RLE
            with ``counts`` decoded to ``str`` so they are JSON-serializable).
        """
        bbox_results = []
        segm_results = []

        for pred, image_id in zip(predictions, image_ids):
            keep = pred['scores'] >= score_thresh
            if not bool(keep.any()):
                continue

            image_id = int(image_id)
            boxes = pred['boxes'][keep].cpu().numpy()
            labels = pred['labels'][keep].cpu().numpy()
            scores = pred['scores'][keep].cpu().numpy()
            # Binarize before the host copy so only the kept masks are
            # transferred, and as booleans rather than floats.
            masks = (pred['masks'][keep][:, 0] > 0.5).cpu().numpy().astype(np.uint8)

            for (x1, y1, x2, y2), label, score, mask in zip(boxes, labels, scores, masks):
                # Assumes the detector's labels are already COCO category ids.
                category_id = int(label)
                score = float(score)

                bbox_results.append({
                    'image_id': image_id,
                    'category_id': category_id,
                    'bbox': [float(x1), float(y1), float(x2 - x1), float(y2 - y1)],
                    'score': score,
                })

                # pycocotools expects a Fortran-ordered uint8 array.
                rle = mask_util.encode(np.asfortranarray(mask))
                rle['counts'] = rle['counts'].decode('utf-8')

                segm_results.append({
                    'image_id': image_id,
                    'category_id': category_id,
                    'segmentation': rle,
                    'score': score,
                })

        return bbox_results, segm_results

    def evaluate(self, dataset: "PatchDataset", batch_size: int = 8, num_workers: int = 4) -> Dict:
        """Run inference over ``dataset`` and compute box and mask AP.

        Args:
            dataset: Dataset yielding dicts with ``'image'`` and ``'image_id'``.
            batch_size: DataLoader batch size.
            num_workers: DataLoader worker count.

        Returns:
            Dict with ``AP_bbox``, ``AP_bbox_50``, ``AP_bbox_75``, ``AP_mask``,
            ``AP_mask_50`` and ``AP_mask_75``.
        """
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collate_fn,
        )

        all_bbox_results = []
        all_segm_results = []
        # Ids of the images actually run through the detector; evaluation is
        # restricted to these so images missing from disk do not count as
        # false negatives.
        evaluated_ids = []

        print("Running inference...")
        for batch in tqdm(dataloader):
            images = [item['image'] for item in batch]
            image_ids = [item['image_id'] for item in batch]
            evaluated_ids.extend(image_ids)

            predictions = self.predict(images)
            bbox_results, segm_results = self.convert_to_coco_format(predictions, image_ids)

            all_bbox_results.extend(bbox_results)
            all_segm_results.extend(segm_results)

        print("\n" + "="*60)
        print("Object Detection (Bounding Box) Results:")
        print("="*60)
        bbox_metrics = self._run_eval(all_bbox_results, 'bbox', image_ids=evaluated_ids)

        print("\n" + "="*60)
        print("Instance Segmentation (Mask) Results:")
        print("="*60)
        segm_metrics = self._run_eval(all_segm_results, 'segm', image_ids=evaluated_ids)

        return {
            'AP_bbox': bbox_metrics['AP'],
            'AP_bbox_50': bbox_metrics['AP_50'],
            'AP_bbox_75': bbox_metrics['AP_75'],
            'AP_mask': segm_metrics['AP'],
            'AP_mask_50': segm_metrics['AP_50'],
            'AP_mask_75': segm_metrics['AP_75'],
        }

    def _run_eval(
        self,
        results: List[Dict],
        iou_type: str,
        image_ids: Optional[List[int]] = None,
    ) -> Dict:
        """Score ``results`` against the ground truth with ``COCOeval``.

        Args:
            results: COCO-format result records.
            iou_type: ``'bbox'`` or ``'segm'``.
            image_ids: If given and non-empty, evaluation is limited to these
                image ids; otherwise every ground-truth image is used.

        Returns:
            Dict with ``AP``, ``AP_50`` and ``AP_75`` (the first three
            ``COCOeval.stats`` entries). All zeros when ``results`` is empty
            or when evaluation raises ``KeyError``.
        """
        if not results:
            print(f"No predictions for {iou_type}!")
            return {'AP': 0.0, 'AP_50': 0.0, 'AP_75': 0.0}

        try:
            coco_dt = self.coco_gt.loadRes(results)
            coco_eval = COCOeval(self.coco_gt, coco_dt, iou_type)
            if image_ids:
                coco_eval.params.imgIds = sorted({int(i) for i in image_ids})
            coco_eval.evaluate()
            coco_eval.accumulate()
            coco_eval.summarize()

            # Plain floats keep the returned dict trivially JSON-serializable.
            return {
                'AP': float(coco_eval.stats[0]),
                'AP_50': float(coco_eval.stats[1]),
                'AP_75': float(coco_eval.stats[2]),
            }
        except KeyError as e:
            # Best-effort: report and fall back to zeros rather than abort.
            print(f"Error evaluating {iou_type}: {e}")
            print("Some GT annotations may be missing 'segmentation' field.")
            print(f"Skipping {iou_type} evaluation.")
            return {'AP': 0.0, 'AP_50': 0.0, 'AP_75': 0.0}
|
|
|
|
def main():
    """Evaluate the SR patches named in ``CONFIG`` and save the metrics.

    Falls back to CPU when CUDA is requested but unavailable, builds the
    dataset and evaluator, prints a summary, and writes the metric dict as
    JSON to ``./coco_patch_eval_results_<name>.json`` where ``<name>`` is the
    last path component of ``CONFIG['sr_dir']``.
    """
    device = CONFIG['device']
    if device == 'cuda' and not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        device = 'cpu'

    # NOTE: CONFIG['output'] is not consulted here; the file name is derived
    # from the SR directory name instead.
    sr_dir = Path(CONFIG['sr_dir'])
    baseline_name = sr_dir.name
    output_path = Path(f"./coco_patch_eval_results_{baseline_name}.json")

    print(f"Loading SR patches from: {CONFIG['sr_dir']}")
    print(f"Loading annotations from: {CONFIG['ann_file']}")
    print(f"Output will be saved to: {output_path}")

    dataset = PatchDataset(
        image_dir=CONFIG['sr_dir'],
        ann_file=CONFIG['ann_file'],
    )
    print(f"Dataset size: {len(dataset)} patches")

    evaluator = PatchEvaluator(ann_file=CONFIG['ann_file'], device=device)

    print("\nStarting evaluation...")
    results = evaluator.evaluate(
        dataset=dataset,
        batch_size=CONFIG['batch_size'],
        num_workers=CONFIG['num_workers'],
    )

    # Metrics are fractions in [0, 1]; they are printed as percentages.
    print("\n" + "="*60)
    print("EVALUATION SUMMARY")
    print("="*60)
    print("\nObject Detection (Bounding Box):")
    print(f" AP^b : {results['AP_bbox']*100:.2f}")
    print(f" AP^b_50 : {results['AP_bbox_50']*100:.2f}")
    print(f" AP^b_75 : {results['AP_bbox_75']*100:.2f}")
    print("\nInstance Segmentation (Mask):")
    print(f" AP^m : {results['AP_mask']*100:.2f}")
    print(f" AP^m_50 : {results['AP_mask_50']*100:.2f}")
    print(f" AP^m_75 : {results['AP_mask_75']*100:.2f}")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {output_path}")


if __name__ == '__main__':
    main()
|
|