"""
Verification script: check whether the annotations are correct, and test the
detector on the original COCO dataset.
"""
|
|
| import os |
| import json |
| from pathlib import Path |
| import matplotlib.pyplot as plt |
| import matplotlib.patches as patches |
| from PIL import Image |
| import numpy as np |
|
|
| import torch |
| from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights |
| from torchvision.transforms import functional as TF |
| from pycocotools.coco import COCO |
| from pycocotools.cocoeval import COCOeval |
|
|
|
|
| |
| |
| |
# Input locations and runtime settings for this verification script.
# Edit the paths to match the machine the script runs on.
CONFIG = dict(
    # Original COCO val2017 images and their instance annotation JSON.
    original_coco_dir='/home/wanghongbo06/baipurui/DATA/COCO/val2017',
    original_ann_file='/home/wanghongbo06/baipurui/DATA/COCO/annotations/instances_val2017.json',
    # Patch dataset: GT image directory and its annotation JSON.
    patch_gt_dir='/home/wanghongbo06/baipurui/DATA/COCO_patch/gt',
    patch_ann_file='/home/wanghongbo06/baipurui/DATA/COCO_patch/patch_annotations.json',
    # Torch device string; main() switches to 'cpu' when CUDA is unavailable.
    device='cuda',
)
| |
|
|
# The 80 category ids used by COCO 2017 instance annotations, in ascending
# order. The ids span 1..90 but skip the ten values listed below, so they are
# not contiguous and cannot be used directly as 0..79 class indices.
COCO_CATEGORY_IDS = [
    cat_id
    for cat_id in range(1, 91)
    if cat_id not in (12, 26, 29, 30, 45, 66, 68, 69, 71, 83)
]
|
|
|
|
def visualize_annotations(ann_file, image_dir, num_images=5,
                          output_path='verify_annotations.png'):
    """Draw the ground-truth boxes of the first few images in a COCO annotation file.

    The first ``num_images`` image entries of ``ann_file`` are read from
    ``image_dir``. Every annotation's ``bbox`` (COCO ``[x, y, width, height]``)
    is drawn as a red rectangle, and the row of panels is saved to
    ``output_path``. Images listed in the JSON but missing on disk are reported
    and left as blank panels.

    Args:
        ann_file: Path to a COCO-format annotation JSON.
        image_dir: Directory containing the files named by each image's
            ``file_name``.
        num_images: Maximum number of images to draw.
        output_path: File the figure is written to. Defaults to the
            previously hard-coded ``verify_annotations.png``.
    """
    print("="*60)
    print("检查 Annotations 格式")
    print("="*60)

    coco = COCO(ann_file)
    image_ids = list(coco.imgs.keys())[:num_images]
    if not image_ids:
        # plt.subplots(1, 0) would raise, and there is nothing to draw anyway.
        print("没有可供可视化的图片")
        return

    n_panels = len(image_ids)
    # squeeze=False always returns a 2-D array of Axes, so a single image needs
    # no special case. Sizing by the real count avoids empty trailing panels
    # when the file holds fewer images than requested.
    fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)
    try:
        for ax, img_id in zip(axes[0], image_ids):
            # Hide the frame first so panels of missing images are blank too.
            ax.axis('off')
            img_info = coco.imgs[img_id]
            img_path = Path(image_dir) / img_info['file_name']

            if not img_path.exists():
                print(f"图片不存在: {img_path}")
                continue

            # Converting to RGB stops grayscale images from being colour-mapped.
            # The with-block closes the file once the pixels are copied out.
            with Image.open(img_path) as img:
                ax.imshow(np.asarray(img.convert('RGB')))
            ax.set_title(f"ID: {img_id}\n{img_info['file_name']}")

            anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
            for ann in anns:
                bbox = ann['bbox']
                ax.add_patch(patches.Rectangle(
                    (bbox[0], bbox[1]), bbox[2], bbox[3],
                    linewidth=2, edgecolor='r', facecolor='none',
                ))
            print(f"Image {img_id}: {len(anns)} annotations")

        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure; otherwise it stays registered with pyplot.
        plt.close(fig)

    print(f"\n可视化保存到: {output_path}")
|
|
|
|
def check_annotation_format(ann_file):
    """Print a summary of a COCO-format annotation JSON plus basic consistency checks.

    Reports the following:

    * the number of images, annotations and categories;
    * the image-id range, and whether ids are exactly ``0..N-1`` in file order;
    * duplicate image ids;
    * how many annotations reference an ``image_id`` with no image entry;
    * the category ids in use;
    * the fields of the first annotation.

    Missing top-level sections are treated as empty lists.

    Args:
        ann_file: Path to the annotation JSON file.
    """
    print("\n" + "="*60)
    print("检查 Annotation JSON 格式")
    print("="*60)

    with open(ann_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    images = data.get('images', [])
    annotations = data.get('annotations', [])
    categories = data.get('categories', [])

    print(f"Images 数量: {len(images)}")
    print(f"Annotations 数量: {len(annotations)}")
    print(f"Categories 数量: {len(categories)}")

    image_ids = [img['id'] for img in images]
    if image_ids:
        print(f"\nImage ID 范围: {min(image_ids)} - {max(image_ids)}")
    else:
        # min()/max() would raise ValueError on an empty list.
        print("\nImage ID 范围: (无 image)")
    # True only when the ids are exactly 0, 1, ..., N-1 in file order.
    print(f"Image ID 是否连续: {image_ids == list(range(len(image_ids)))}")

    known_ids = set(image_ids)
    if len(known_ids) != len(image_ids):
        # Duplicate ids make id-based image lookups ambiguous.
        print(f"警告: {len(image_ids) - len(known_ids)} 个重复的 image id!")

    # Count annotations (not distinct ids) whose image_id has no image entry,
    # so the number matches the wording of the warning.
    n_orphans = sum(1 for ann in annotations if ann['image_id'] not in known_ids)
    if n_orphans:
        print(f"警告: {n_orphans} 个 annotation 的 image_id 找不到对应的 image!")
    else:
        print("所有 annotation 的 image_id 都有对应的 image ✓")

    ann_cat_ids = sorted({ann['category_id'] for ann in annotations})
    print(f"\n使用的 category_id: {ann_cat_ids[:10]}... (共 {len(ann_cat_ids)} 个)")

    if annotations:
        print("\n第一个 annotation 示例:")
        for key, value in annotations[0].items():
            if key == 'segmentation':
                # Polygons / RLE can be huge, so only the type and length are shown.
                print(f" {key}: {type(value)}, 长度={len(value) if value else 0}")
            else:
                print(f" {key}: {value}")
|
|
|
|
def _outputs_to_coco_results(img_id, outputs, score_thresh=0.05):
    """Convert one torchvision detection output dict into COCO bbox result entries.

    Boxes arrive as ``(x1, y1, x2, y2)`` and are converted to COCO
    ``[x, y, width, height]``. Detections with ``score < score_thresh`` are
    dropped. Labels are passed through unchanged as ``category_id``.
    NOTE(review): this assumes the model's label indices already are COCO
    category ids (true for torchvision's COCO-pretrained weights); confirm
    before using other weights.
    """
    boxes = outputs['boxes'].detach().cpu().numpy()
    labels = outputs['labels'].detach().cpu().numpy()
    scores = outputs['scores'].detach().cpu().numpy()

    results = []
    for (x1, y1, x2, y2), label, score in zip(boxes, labels, scores):
        if score < score_thresh:
            continue
        results.append({
            'image_id': img_id,
            'category_id': int(label),
            'bbox': [float(x1), float(y1), float(x2 - x1), float(y2 - y1)],
            'score': float(score),
        })
    return results


def test_on_original_coco(coco_dir, ann_file, device='cuda', num_images=100,
                          score_thresh=0.05):
    """Run a pretrained Mask R-CNN (v2) on the first COCO images and report bbox AP.

    Args:
        coco_dir: Directory with the images named in ``ann_file``.
        ann_file: COCO instance annotation JSON used as ground truth.
        device: Torch device string the model and inputs are moved to.
        num_images: Number of leading images (in annotation-file order) to test.
        score_thresh: Detections scoring below this are discarded.

    Returns:
        ``COCOeval.stats[0]`` (AP@[.50:.95]) as a fraction, or ``0`` when
        nothing was detected. Only images that were actually found and run
        are evaluated.
    """
    print("\n" + "="*60)
    print("在原始 COCO 上测试 Mask R-CNN")
    print("="*60)

    coco = COCO(ann_file)
    image_ids = list(coco.imgs.keys())[:num_images]

    print(f"测试图片数: {len(image_ids)}")

    print("加载 Mask R-CNN...")
    weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = maskrcnn_resnet50_fpn_v2(weights=weights)
    model.eval()
    model.to(device)

    from tqdm import tqdm

    bbox_results = []
    evaluated_ids = []

    print("运行检测...")
    for img_id in tqdm(image_ids):
        img_path = Path(coco_dir) / coco.imgs[img_id]['file_name']
        if not img_path.exists():
            continue

        # The with-block closes the file handle after the pixels are converted.
        with Image.open(img_path) as img:
            img_tensor = TF.to_tensor(img.convert('RGB')).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(img_tensor)[0]

        evaluated_ids.append(img_id)
        bbox_results.extend(_outputs_to_coco_results(img_id, outputs, score_thresh))

    skipped = len(image_ids) - len(evaluated_ids)
    if skipped:
        print(f"警告: {skipped} 张图片不存在,已跳过")
    print(f"检测到 {len(bbox_results)} 个物体")

    if not bbox_results:
        print("没有检测到任何物体!")
        return 0

    coco_dt = coco.loadRes(bbox_results)
    coco_eval = COCOeval(coco, coco_dt, 'bbox')
    # Evaluate only images that were actually run. Otherwise GT on skipped
    # images would count as missed detections and drag AP down.
    coco_eval.params.imgIds = evaluated_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

    print(f"\n原始 COCO 上的 AP: {coco_eval.stats[0]*100:.2f}%")
    return coco_eval.stats[0]
|
|
|
|
def main():
    """Run every verification step whose inputs exist on disk, driven by ``CONFIG``.

    Steps:
      1. Summarize the patch annotation JSON (``check_annotation_format``).
      2. Draw GT boxes for the first few patch images (``visualize_annotations``).
      3. Evaluate a pretrained Mask R-CNN on original COCO val images
         (``test_on_original_coco``) and print a verdict. A healthy AP there
         points at the resized-annotation conversion rather than at the
         detector or the evaluation code.

    Steps whose input files are missing are skipped with a message.
    """
    print("="*60)
    print("Annotations 验证工具")
    print("="*60)

    patch_ann_file = CONFIG['patch_ann_file']
    patch_gt_dir = CONFIG['patch_gt_dir']
    has_patch_ann = Path(patch_ann_file).exists()

    if has_patch_ann:
        check_annotation_format(patch_ann_file)
    else:
        print(f"\n跳过 annotation 检查: 找不到 {patch_ann_file}")

    if has_patch_ann and Path(patch_gt_dir).exists():
        visualize_annotations(patch_ann_file, patch_gt_dir, num_images=5)

    coco_dir = CONFIG['original_coco_dir']
    coco_ann_file = CONFIG['original_ann_file']
    # Both the image dir and the annotation JSON are required. Checking only
    # the directory let a missing JSON crash inside COCO(...).
    if not (Path(coco_dir).exists() and Path(coco_ann_file).exists()):
        print(f"\n跳过原始 COCO 测试: 找不到 {coco_dir} 或 {coco_ann_file}")
        return

    device = CONFIG['device']
    # Fall back to CPU for any CUDA device string ('cuda', 'cuda:1', ...).
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'

    ap = test_on_original_coco(
        coco_dir,
        coco_ann_file,
        device=device,
        num_images=100,
    )

    print("\n" + "="*60)
    print("结论")
    print("="*60)
    if ap > 0.3:
        print(f"原始 COCO 上 AP={ap*100:.1f}%,检测器正常")
        print("问题可能在于 resize 后的 annotations 转换")
    else:
        print(f"原始 COCO 上 AP={ap*100:.1f}%,检测器或评估代码可能有问题")


if __name__ == '__main__':
    main()
|
|
|
|