kuai / diffusion-dpo-coco-ade /verify_annotations.py
Larer's picture
Add files using upload-large-folder tool
2214a66
"""
验证脚本:检查 annotations 是否正确,以及在原始 COCO 上测试检测器
"""
import os
import json
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import numpy as np
import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
from torchvision.transforms import functional as TF
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
# ============================================================================
# Configuration
# ============================================================================
CONFIG = {
    # Original COCO data: image directory and its instances annotation file
    'original_coco_dir': '/home/wanghongbo06/baipurui/DATA/COCO/val2017',
    'original_ann_file': '/home/wanghongbo06/baipurui/DATA/COCO/annotations/instances_val2017.json',
    # Data generated by the "prepare" step: patch images and their annotations
    'patch_gt_dir': '/home/wanghongbo06/baipurui/DATA/COCO_patch/gt',
    'patch_ann_file': '/home/wanghongbo06/baipurui/DATA/COCO_patch/patch_annotations.json',
    # Requested torch device; main() switches to 'cpu' when CUDA is unavailable
    'device': 'cuda',
}
# ============================================================================
# The 80 COCO category ids, in ascending order: every integer in 1..90 except
# the ten ids excluded below (the id numbering has gaps).
COCO_CATEGORY_IDS = [
    cat_id
    for cat_id in range(1, 91)
    if cat_id not in {12, 26, 29, 30, 45, 66, 68, 69, 71, 83}
]
def visualize_annotations(ann_file, image_dir, num_images=5,
                          output_path='verify_annotations.png'):
    """Draw ground-truth boxes over the first few images and save the figure.

    The annotation file is loaded with pycocotools, the first ``num_images``
    image entries are selected, and every annotation ``bbox``
    (``[x, y, w, h]``) is drawn as a red rectangle on top of its image. All
    panels are laid out in a single row and written to ``output_path``.

    Args:
        ann_file: Path to a COCO-format annotation JSON file.
        image_dir: Directory holding the files named by each image entry's
            ``file_name``.
        num_images: Maximum number of images to draw.
        output_path: Where the figure is saved. Defaults to
            ``verify_annotations.png`` in the current working directory.
    """
    print("="*60)
    print("检查 Annotations 格式")
    print("="*60)

    coco = COCO(ann_file)
    image_ids = list(coco.imgs.keys())[:num_images]

    # Size the grid by the images actually available: the annotation file may
    # hold fewer than ``num_images`` entries, and plt.subplots rejects a
    # zero-column grid.
    n_cols = max(1, len(image_ids))
    # squeeze=False always yields a 2-D array of axes, so the single-panel
    # case needs no special handling.
    fig, axes = plt.subplots(1, n_cols, figsize=(4*n_cols, 4), squeeze=False)
    try:
        for ax, img_id in zip(axes[0], image_ids):
            # Hide ticks/frame first so panels of missing files stay blank.
            ax.axis('off')

            img_info = coco.imgs[img_id]
            img_path = Path(image_dir) / img_info['file_name']
            if not img_path.exists():
                print(f"图片不存在: {img_path}")
                continue

            # imshow copies the pixel data, so the file can be closed at once.
            with Image.open(img_path) as img:
                ax.imshow(img)
            ax.set_title(f"ID: {img_id}\n{img_info['file_name']}")

            # Draw the ground-truth boxes.
            anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
            for ann in anns:
                x, y, w, h = ann['bbox'][:4]  # COCO layout: [x, y, w, h]
                ax.add_patch(patches.Rectangle(
                    (x, y), w, h,
                    linewidth=2, edgecolor='r', facecolor='none'
                ))
            print(f"Image {img_id}: {len(anns)} annotations")

        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Figures created through pyplot stay registered until closed.
        plt.close(fig)

    print(f"\n可视化保存到: {output_path}")
def check_annotation_format(ann_file):
    """Print a structural summary of a COCO-format annotation JSON file.

    Reports the number of images, annotations and categories, the range of
    image ids and whether they are exactly ``0..N-1`` in file order, whether
    every annotation refers to a known image id, the category ids in use, and
    the fields of the first annotation.

    Args:
        ann_file: Path to a JSON file with top-level ``images``,
            ``annotations`` and ``categories`` lists.

    Raises:
        KeyError: If one of the three top-level keys, an image ``id``, or an
            annotation ``image_id``/``category_id`` is missing.
    """
    print("\n" + "="*60)
    print("检查 Annotation JSON 格式")
    print("="*60)

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(ann_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    images = data['images']
    annotations = data['annotations']

    print(f"Images 数量: {len(images)}")
    print(f"Annotations 数量: {len(annotations)}")
    print(f"Categories 数量: {len(data['categories'])}")

    # Image ids: range and 0-based contiguity. Skipped for an empty image
    # list, where min()/max() would raise ValueError.
    image_ids = [img['id'] for img in images]
    if image_ids:
        print(f"\nImage ID 范围: {min(image_ids)} - {max(image_ids)}")
        print(f"Image ID 是否连续: {image_ids == list(range(len(image_ids)))}")

    # Every image_id referenced by an annotation must exist in `images`.
    missing = {ann['image_id'] for ann in annotations} - set(image_ids)
    if missing:
        print(f"警告: {len(missing)} 个 annotation 的 image_id 找不到对应的 image!")
    else:
        print("所有 annotation 的 image_id 都有对应的 image ✓")

    # Category ids actually used by the annotations (first ten shown).
    ann_cat_ids = {ann['category_id'] for ann in annotations}
    print(f"\n使用的 category_id: {sorted(ann_cat_ids)[:10]}... (共 {len(ann_cat_ids)} 个)")

    # Dump the first annotation as a sample; segmentation is summarised by
    # type and length instead of being printed in full.
    if annotations:
        print("\n第一个 annotation 示例:")
        for key, value in annotations[0].items():
            if key == 'segmentation':
                print(f" {key}: {type(value)}, 长度={len(value) if value else 0}")
            else:
                print(f" {key}: {value}")
def test_on_original_coco(coco_dir, ann_file, device='cuda', num_images=100):
    """Run torchvision's Mask R-CNN on COCO images and report bbox AP.

    The first ``num_images`` image entries of ``ann_file`` are selected; each
    one whose file exists under ``coco_dir`` is passed through the pretrained
    ``maskrcnn_resnet50_fpn_v2`` model. Detections scoring at least 0.05 are
    converted to COCO ``[x, y, w, h]`` results and scored with ``COCOeval``
    (``'bbox'`` mode), restricted to the images that were actually processed.

    Args:
        coco_dir: Directory containing the COCO image files.
        ann_file: Path to the COCO instances annotation JSON.
        device: Torch device the model and inputs are moved to.
        num_images: Maximum number of images to evaluate.

    Returns:
        ``coco_eval.stats[0]`` (the first summary metric, a fraction in
        [0, 1]), or ``0.0`` when no detection was produced.
    """
    from tqdm import tqdm

    print("\n" + "="*60)
    print("在原始 COCO 上测试 Mask R-CNN")
    print("="*60)

    coco = COCO(ann_file)
    image_ids = list(coco.imgs.keys())[:num_images]
    # Report the number actually selected; the file may hold fewer entries.
    print(f"测试图片数: {len(image_ids)}")

    # Load the pretrained model.
    print("加载 Mask R-CNN...")
    weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = maskrcnn_resnet50_fpn_v2(weights=weights)
    model.eval()
    model.to(device)

    # Note: the labels emitted by torchvision's Mask R-CNN are already COCO
    # category ids, so no extra mapping is applied.
    score_threshold = 0.05  # drop very low-confidence detections
    bbox_results = []
    evaluated_ids = []

    print("运行检测...")
    for img_id in tqdm(image_ids):
        img_path = Path(coco_dir) / coco.imgs[img_id]['file_name']
        if not img_path.exists():
            continue
        # Only images that were really processed take part in the evaluation;
        # otherwise the ground truth of skipped images would count as misses.
        evaluated_ids.append(img_id)

        with Image.open(img_path) as img:
            img_tensor = TF.to_tensor(img.convert('RGB')).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(img_tensor)[0]

        boxes = outputs['boxes'].cpu().tolist()
        labels = outputs['labels'].cpu().tolist()
        scores = outputs['scores'].cpu().tolist()

        for (x1, y1, x2, y2), label, score in zip(boxes, labels, scores):
            if score < score_threshold:
                continue
            bbox_results.append({
                'image_id': img_id,
                'category_id': int(label),
                # Model boxes are corner format; COCO results use x, y, w, h.
                'bbox': [float(x1), float(y1), float(x2 - x1), float(y2 - y1)],
                'score': float(score),
            })

    print(f"检测到 {len(bbox_results)} 个物体")

    # Evaluate.
    if not bbox_results:
        print("没有检测到任何物体!")
        return 0.0

    coco_dt = coco.loadRes(bbox_results)
    coco_eval = COCOeval(coco, coco_dt, 'bbox')
    coco_eval.params.imgIds = evaluated_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    print(f"\n原始 COCO 上的 AP: {coco_eval.stats[0]*100:.2f}%")
    return coco_eval.stats[0]
def main():
    """Run the verification steps whose input paths exist.

    1. Print a format summary of the patch annotation file.
    2. Save a visualization of the patch ground-truth boxes.
    3. Evaluate Mask R-CNN on the original COCO images as a baseline and
       print a conclusion based on the resulting AP.

    Each step is skipped when any path it needs (taken from ``CONFIG``) is
    missing.
    """
    print("="*60)
    print("Annotations 验证工具")
    print("="*60)

    has_patch_ann = Path(CONFIG['patch_ann_file']).exists()

    # 1. Check the annotation format.
    if has_patch_ann:
        check_annotation_format(CONFIG['patch_ann_file'])

    # 2. Visualize the annotations.
    if Path(CONFIG['patch_gt_dir']).exists() and has_patch_ann:
        visualize_annotations(
            CONFIG['patch_ann_file'],
            CONFIG['patch_gt_dir'],
            num_images=5
        )

    # 3. Evaluate on the original COCO data (as a baseline). The annotation
    #    file is checked too: loading a missing file would raise instead of
    #    skipping the step like the other steps do.
    has_original = (
        Path(CONFIG['original_coco_dir']).exists()
        and Path(CONFIG['original_ann_file']).exists()
    )
    if has_original:
        device = CONFIG['device']
        if device == 'cuda' and not torch.cuda.is_available():
            device = 'cpu'

        ap = test_on_original_coco(
            CONFIG['original_coco_dir'],
            CONFIG['original_ann_file'],
            device=device,
            num_images=100  # evaluate 100 original images
        )

        print("\n" + "="*60)
        print("结论")
        print("="*60)
        if ap > 0.3:
            print(f"原始 COCO 上 AP={ap*100:.1f}%,检测器正常")
            print("问题可能在于 resize 后的 annotations 转换")
        else:
            print(f"原始 COCO 上 AP={ap*100:.1f}%,检测器或评估代码可能有问题")


if __name__ == '__main__':
    main()