kuai / diffusion-dpo-coco-ade /test_coco_patch.py
Larer's picture
Add files using upload-large-folder tool
2214a66
"""
COCO Patch-based Evaluation

Takes the directory of super-resolved patches directly and evaluates
object-detection and instance-segmentation metrics on them, using
torchvision's Mask R-CNN ResNet50-FPN v2 as the detector and pycocotools'
COCOeval for scoring. Paths and inference settings are set in CONFIG.

Metrics:
- Object Detection: AP^b, AP^b_50, AP^b_75
- Instance Segmentation: AP^m, AP^m_50, AP^m_75
"""
import os
import json
from pathlib import Path
from typing import Dict, List, Tuple
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.models.detection import (
maskrcnn_resnet50_fpn_v2,
MaskRCNN_ResNet50_FPN_V2_Weights,
)
from torchvision.transforms import functional as TF
from PIL import Image
import numpy as np
from tqdm import tqdm
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from pycocotools import mask as mask_util
# ============================================================================
# Configuration - edit the values below
# ============================================================================
CONFIG = dict(
    # Directory of SR patches (your super-resolution model's output, 512x512).
    # File names must match the ones produced by prepare_coco_patches.py.
    sr_dir='/home/wanghongbo06/baipurui/results/COCO/DreamClear/results/output',
    # Patch annotation file (produced by prepare_coco_patches.py).
    ann_file='/home/wanghongbo06/baipurui/DATA/COCO_patch/patch_annotations.json',
    # Inference settings.
    device='cuda',
    batch_size=8,
    num_workers=4,
    # Output file path.
    output='./coco_patch_eval_results.json',
)
# ============================================================================
# COCO category ids: the 80 ids obtained by removing the ten unused slots
# (12, 26, 29, 30, 45, 66, 68, 69, 71, 83) from the range 1..90.
# convert_to_coco_format uses the detector's labels directly as category_id,
# so this table is not needed for that conversion.
COCO_CATEGORY_IDS = [
    cid for cid in range(1, 91)
    if cid not in (12, 26, 29, 30, 45, 66, 68, 69, 71, 83)
]
class PatchDataset(Dataset):
    """Dataset of image patches listed in a COCO-format annotation file.

    Only images whose file exists under ``image_dir`` are kept; missing files
    are counted, reported, and skipped.

    Each item is a dict with:
        'image':    tensor produced by ``TF.to_tensor`` from the RGB image
        'image_id': the COCO image id of the patch
    """

    def __init__(self, image_dir: str, ann_file: str):
        """
        Args:
            image_dir: directory holding the (super-resolved) patch images;
                file names must match ``file_name`` in the annotation file.
            ann_file: path to the COCO-format patch annotation JSON.
        """
        self.image_dir = Path(image_dir)
        self.coco = COCO(ann_file)

        # Keep only the images that are actually present on disk.
        all_image_ids = list(self.coco.imgs.keys())
        print("Checking image files...")
        self.image_ids = [
            img_id
            for img_id in all_image_ids
            if (self.image_dir / self.coco.imgs[img_id]['file_name']).exists()
        ]
        missing_count = len(all_image_ids) - len(self.image_ids)
        if missing_count > 0:
            print(f"Warning: {missing_count} images not found, skipped.")
        print(f"Valid images: {len(self.image_ids)} / {len(all_image_ids)}")

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx: int) -> Dict:
        image_id = self.image_ids[idx]
        img_info = self.coco.imgs[image_id]
        img_path = self.image_dir / img_info['file_name']
        # The context manager closes the file handle deterministically instead
        # of leaving it to the garbage collector (matters with many workers).
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        image_tensor = TF.to_tensor(image)
        return {
            'image': image_tensor,
            'image_id': image_id,
        }
def collate_fn(batch):
    """Hand the batch back untouched, as a list of per-sample dicts.

    This replaces the DataLoader's default collation, so image tensors are
    never stacked into one tensor; the detector in this file consumes a list
    of image tensors instead.
    """
    samples = batch
    return samples
class PatchEvaluator:
    """Evaluate detection and instance-segmentation AP on image patches.

    Runs torchvision's Mask R-CNN ResNet50-FPN v2 (default weights) over a
    dataset of patches and scores the predictions with pycocotools' COCOeval
    against the ground truth in ``ann_file``.
    """

    def __init__(self, ann_file: str, device: str = 'cuda'):
        """
        Args:
            ann_file: COCO-format ground-truth annotation JSON for the patches.
            device: torch device string the detector and images are moved to.
        """
        self.ann_file = ann_file
        self.device = device
        self.coco_gt = COCO(ann_file)
        print("Loading Mask R-CNN ResNet50 FPN v2...")
        weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
        self.detector = maskrcnn_resnet50_fpn_v2(weights=weights)
        self.detector.eval()
        self.detector.to(device)
        # NOTE: the labels emitted by torchvision's Mask R-CNN are already
        # COCO category_ids, so no extra label mapping is needed.

    @torch.no_grad()
    def predict(self, images: List[torch.Tensor]) -> List[Dict]:
        """Run the detector on a list of image tensors; one output dict per image."""
        images = [img.to(self.device) for img in images]
        return self.detector(images)

    def convert_to_coco_format(
        self,
        predictions: List[Dict],
        image_ids: List[int],
        score_thresh: float = 0.05,
        mask_thresh: float = 0.5,
    ) -> Tuple:
        """Convert detector outputs into COCO result lists.

        Args:
            predictions: detector outputs, one dict per image with 'boxes'
                (x1, y1, x2, y2), 'labels', 'scores' and 'masks'.
            image_ids: COCO image id for each entry of ``predictions``.
            score_thresh: detections scoring below this are dropped.
            mask_thresh: threshold used to binarize the soft masks.

        Returns:
            (bbox_results, segm_results): lists of COCO result dicts, with
            boxes as [x, y, w, h] and masks as RLE with str ``counts``.
        """
        bbox_results = []
        segm_results = []
        for pred, image_id in zip(predictions, image_ids):
            image_id = int(image_id)
            boxes = pred['boxes'].cpu().numpy()
            labels = pred['labels'].cpu().numpy()
            scores = pred['scores'].cpu().numpy()
            masks = pred['masks'].cpu().numpy()
            for box, label, score, mask in zip(boxes, labels, scores, masks):
                if score < score_thresh:
                    continue
                score = float(score)
                # Labels are used directly as COCO category_id (see __init__).
                category_id = int(label)
                x1, y1, x2, y2 = box
                bbox_results.append({
                    'image_id': image_id,
                    'category_id': category_id,
                    'bbox': [float(x1), float(y1), float(x2 - x1), float(y2 - y1)],
                    'score': score,
                })
                # Take channel 0 of this detection's mask, then binarize it.
                mask_binary = (mask[0] > mask_thresh).astype(np.uint8)
                rle = mask_util.encode(np.asfortranarray(mask_binary))
                # pycocotools yields bytes counts; decode so results are JSON-friendly.
                rle['counts'] = rle['counts'].decode('utf-8')
                segm_results.append({
                    'image_id': image_id,
                    'category_id': category_id,
                    'segmentation': rle,
                    'score': score,
                })
        return bbox_results, segm_results

    def evaluate(self, dataset: 'PatchDataset', batch_size: int = 8,
                 num_workers: int = 4) -> Dict:
        """Run inference over ``dataset`` and compute box and mask AP.

        Only the images that were actually run through the detector are
        scored. The GT file may list patches whose image was missing; those
        must not be counted as missed detections.

        Returns:
            dict with 'AP_bbox', 'AP_bbox_50', 'AP_bbox_75', 'AP_mask',
            'AP_mask_50', 'AP_mask_75' (fractions as reported by COCOeval).
        """
        dataloader = DataLoader(
            dataset, batch_size=batch_size, shuffle=False,
            num_workers=num_workers, collate_fn=collate_fn,
        )
        all_bbox_results = []
        all_segm_results = []
        evaluated_ids = []
        print("Running inference...")
        for batch in tqdm(dataloader):
            images = [item['image'] for item in batch]
            image_ids = [item['image_id'] for item in batch]
            predictions = self.predict(images)
            bbox_results, segm_results = self.convert_to_coco_format(predictions, image_ids)
            all_bbox_results.extend(bbox_results)
            all_segm_results.extend(segm_results)
            evaluated_ids.extend(int(i) for i in image_ids)

        print("\n" + "=" * 60)
        print("Object Detection (Bounding Box) Results:")
        print("=" * 60)
        bbox_metrics = self._run_eval(all_bbox_results, 'bbox', img_ids=evaluated_ids)

        print("\n" + "=" * 60)
        print("Instance Segmentation (Mask) Results:")
        print("=" * 60)
        segm_metrics = self._run_eval(all_segm_results, 'segm', img_ids=evaluated_ids)

        return {
            'AP_bbox': bbox_metrics['AP'],
            'AP_bbox_50': bbox_metrics['AP_50'],
            'AP_bbox_75': bbox_metrics['AP_75'],
            'AP_mask': segm_metrics['AP'],
            'AP_mask_50': segm_metrics['AP_50'],
            'AP_mask_75': segm_metrics['AP_75'],
        }

    def _run_eval(self, results: List[Dict], iou_type: str, img_ids=None) -> Dict:
        """Score ``results`` with COCOeval for ``iou_type`` ('bbox' or 'segm').

        Args:
            results: COCO-format result dicts.
            iou_type: COCOeval iouType.
            img_ids: if given, restrict evaluation to these image ids.
                Otherwise every image in the GT file is evaluated, which is
                COCOeval's default.

        Returns:
            dict with 'AP', 'AP_50', 'AP_75' (COCOeval stats[0..2]). All
            values are 0.0 when there are no predictions or when evaluation
            raised KeyError.
        """
        zeros = {'AP': 0.0, 'AP_50': 0.0, 'AP_75': 0.0}
        if not results:
            print(f"No predictions for {iou_type}!")
            return zeros
        try:
            coco_dt = self.coco_gt.loadRes(results)
            coco_eval = COCOeval(self.coco_gt, coco_dt, iou_type)
            if img_ids is not None:
                # Without this, GT objects of images that were never run
                # would count as misses and deflate AP.
                coco_eval.params.imgIds = sorted(set(img_ids))
            coco_eval.evaluate()
            coco_eval.accumulate()
            coco_eval.summarize()
        except KeyError as e:
            # Deliberate best-effort: report and fall back to zeros.
            print(f"Error evaluating {iou_type}: {e}")
            print("Some GT annotations may be missing 'segmentation' field.")
            print(f"Skipping {iou_type} evaluation.")
            return zeros
        return {
            'AP': float(coco_eval.stats[0]),
            'AP_50': float(coco_eval.stats[1]),
            'AP_75': float(coco_eval.stats[2]),
        }
def main():
    """Evaluate the SR patches configured in CONFIG and save the metrics as JSON."""
    device = CONFIG['device']
    if device == 'cuda' and not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        device = 'cpu'

    # Build the output path from CONFIG['output'] by appending the last
    # directory name of sr_dir (e.g. 'sr', 'gt', 'bicubic') to the file stem.
    # NOTE(review): only the final path component is used, so different
    # methods whose results all live in '.../results/output' map to one name.
    sr_dir = Path(CONFIG['sr_dir'])
    baseline_name = sr_dir.name
    base_output = Path(CONFIG['output'])
    output_path = base_output.with_name(
        f"{base_output.stem}_{baseline_name}{base_output.suffix}"
    )

    print(f"Loading SR patches from: {CONFIG['sr_dir']}")
    print(f"Loading annotations from: {CONFIG['ann_file']}")
    print(f"Output will be saved to: {output_path}")

    dataset = PatchDataset(
        image_dir=CONFIG['sr_dir'],
        ann_file=CONFIG['ann_file'],
    )
    print(f"Dataset size: {len(dataset)} patches")

    evaluator = PatchEvaluator(ann_file=CONFIG['ann_file'], device=device)
    print("\nStarting evaluation...")
    results = evaluator.evaluate(
        dataset=dataset,
        batch_size=CONFIG['batch_size'],
        num_workers=CONFIG['num_workers'],
    )

    print("\n" + "=" * 60)
    print("EVALUATION SUMMARY")
    print("=" * 60)
    print("\nObject Detection (Bounding Box):")
    print(f" AP^b : {results['AP_bbox']*100:.2f}")
    print(f" AP^b_50 : {results['AP_bbox_50']*100:.2f}")
    print(f" AP^b_75 : {results['AP_bbox_75']*100:.2f}")
    print("\nInstance Segmentation (Mask):")
    print(f" AP^m : {results['AP_mask']*100:.2f}")
    print(f" AP^m_50 : {results['AP_mask_50']*100:.2f}")
    print(f" AP^m_75 : {results['AP_mask_75']*100:.2f}")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {output_path}")


if __name__ == '__main__':
    main()