""" 验证脚本:检查 annotations 是否正确,以及在原始 COCO 上测试检测器 """ import os import json from pathlib import Path import matplotlib.pyplot as plt import matplotlib.patches as patches from PIL import Image import numpy as np import torch from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights from torchvision.transforms import functional as TF from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval # ============================================================================ # 配置 # ============================================================================ CONFIG = { # 原始 COCO 数据 'original_coco_dir': '/home/wanghongbo06/baipurui/DATA/COCO/val2017', 'original_ann_file': '/home/wanghongbo06/baipurui/DATA/COCO/annotations/instances_val2017.json', # prepare 生成的数据 'patch_gt_dir': '/home/wanghongbo06/baipurui/DATA/COCO_patch/gt', 'patch_ann_file': '/home/wanghongbo06/baipurui/DATA/COCO_patch/patch_annotations.json', 'device': 'cuda', } # ============================================================================ COCO_CATEGORY_IDS = [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90 ] def visualize_annotations(ann_file, image_dir, num_images=5): """可视化 GT annotations""" print("="*60) print("检查 Annotations 格式") print("="*60) coco = COCO(ann_file) image_ids = list(coco.imgs.keys())[:num_images] fig, axes = plt.subplots(1, num_images, figsize=(4*num_images, 4)) if num_images == 1: axes = [axes] for ax, img_id in zip(axes, image_ids): img_info = coco.imgs[img_id] img_path = Path(image_dir) / img_info['file_name'] if not img_path.exists(): print(f"图片不存在: {img_path}") continue img = Image.open(img_path) ax.imshow(img) ax.set_title(f"ID: {img_id}\n{img_info['file_name']}") # 绘制 GT bbox ann_ids = coco.getAnnIds(imgIds=img_id) anns = coco.loadAnns(ann_ids) for ann in anns: bbox = ann['bbox'] # [x, y, w, h] rect = patches.Rectangle( (bbox[0], bbox[1]), bbox[2], bbox[3], linewidth=2, edgecolor='r', facecolor='none' ) ax.add_patch(rect) ax.axis('off') print(f"Image {img_id}: {len(anns)} annotations") plt.tight_layout() plt.savefig('verify_annotations.png', dpi=150) print(f"\n可视化保存到: verify_annotations.png") def check_annotation_format(ann_file): """检查 annotation 格式""" print("\n" + "="*60) print("检查 Annotation JSON 格式") print("="*60) with open(ann_file, 'r') as f: data = json.load(f) print(f"Images 数量: {len(data['images'])}") print(f"Annotations 数量: {len(data['annotations'])}") print(f"Categories 数量: {len(data['categories'])}") # 检查 image ids image_ids = [img['id'] for img in data['images']] print(f"\nImage ID 范围: {min(image_ids)} - {max(image_ids)}") print(f"Image ID 是否连续: {image_ids == list(range(len(image_ids)))}") # 检查 annotation 的 image_id 是否都有对应的 image ann_image_ids = set(ann['image_id'] for ann in data['annotations']) missing = ann_image_ids - set(image_ids) if missing: print(f"警告: {len(missing)} 个 annotation 的 image_id 找不到对应的 image!") else: print("所有 annotation 的 image_id 都有对应的 image ✓") # 检查 category_ids ann_cat_ids = set(ann['category_id'] for ann in data['annotations']) print(f"\n使用的 category_id: {sorted(ann_cat_ids)[:10]}... (共 {len(ann_cat_ids)} 个)") # 检查第一个 annotation if data['annotations']: print("\n第一个 annotation 示例:") ann = data['annotations'][0] for key, value in ann.items(): if key == 'segmentation': print(f" {key}: {type(value)}, 长度={len(value) if value else 0}") else: print(f" {key}: {value}") def test_on_original_coco(coco_dir, ann_file, device='cuda', num_images=100): """在原始 COCO 上测试 Mask R-CNN""" print("\n" + "="*60) print("在原始 COCO 上测试 Mask R-CNN") print("="*60) coco = COCO(ann_file) image_ids = list(coco.imgs.keys())[:num_images] print(f"测试图片数: {num_images}") # 加载模型 print("加载 Mask R-CNN...") weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT model = maskrcnn_resnet50_fpn_v2(weights=weights) model.eval() model.to(device) # 注意:torchvision 的 Mask R-CNN 输出的 labels 已经是 COCO category_id # 不需要额外映射! bbox_results = [] print("运行检测...") from tqdm import tqdm for img_id in tqdm(image_ids): img_info = coco.imgs[img_id] img_path = Path(coco_dir) / img_info['file_name'] if not img_path.exists(): continue img = Image.open(img_path).convert('RGB') img_tensor = TF.to_tensor(img).unsqueeze(0).to(device) with torch.no_grad(): outputs = model(img_tensor)[0] boxes = outputs['boxes'].cpu().numpy() labels = outputs['labels'].cpu().numpy() scores = outputs['scores'].cpu().numpy() for i in range(len(boxes)): if scores[i] < 0.05: # 提高阈值过滤低置信度 continue x1, y1, x2, y2 = boxes[i] category_id = int(labels[i]) # 直接使用,已经是 COCO category_id bbox_results.append({ 'image_id': img_id, 'category_id': category_id, 'bbox': [float(x1), float(y1), float(x2-x1), float(y2-y1)], 'score': float(scores[i]), }) print(f"检测到 {len(bbox_results)} 个物体") # 评估 if bbox_results: coco_dt = coco.loadRes(bbox_results) coco_eval = COCOeval(coco, coco_dt, 'bbox') coco_eval.params.imgIds = image_ids coco_eval.evaluate() coco_eval.accumulate() coco_eval.summarize() print(f"\n原始 COCO 上的 AP: {coco_eval.stats[0]*100:.2f}%") return coco_eval.stats[0] else: print("没有检测到任何物体!") return 0 def main(): print("="*60) print("Annotations 验证工具") print("="*60) # 1. 检查 annotation 格式 if Path(CONFIG['patch_ann_file']).exists(): check_annotation_format(CONFIG['patch_ann_file']) # 2. 可视化 annotations if Path(CONFIG['patch_gt_dir']).exists() and Path(CONFIG['patch_ann_file']).exists(): visualize_annotations( CONFIG['patch_ann_file'], CONFIG['patch_gt_dir'], num_images=5 ) # 3. 在原始 COCO 上测试(作为 baseline) if Path(CONFIG['original_coco_dir']).exists(): device = CONFIG['device'] if device == 'cuda' and not torch.cuda.is_available(): device = 'cpu' ap = test_on_original_coco( CONFIG['original_coco_dir'], CONFIG['original_ann_file'], device=device, num_images=100 # 测试 100 张原图 ) print("\n" + "="*60) print("结论") print("="*60) if ap > 0.3: print(f"原始 COCO 上 AP={ap*100:.1f}%,检测器正常") print("问题可能在于 resize 后的 annotations 转换") else: print(f"原始 COCO 上 AP={ap*100:.1f}%,检测器或评估代码可能有问题") if __name__ == '__main__': main()