FBAGSTM's picture
STM32 AI Experimentation Hub
747451d
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2025 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
import os
import torch
import numpy as np
from torch.utils.data import Dataset
import cv2
import logging
from pycocotools.coco import COCO
class CocoDataset(Dataset):
    """Object-detection dataset reading a COCO-format annotation file.

    Every sample is a tuple ``(image, boxes, labels)``:

    * ``image``  -- the image read with OpenCV and converted from BGR to RGB.
    * ``boxes``  -- ``float32`` array of shape ``(N, 4)`` in corner format
      ``[x1, y1, x2, y2]`` (converted from COCO's ``[x, y, w, h]``).
    * ``labels`` -- ``int64`` array of shape ``(N,)`` holding continuous class
      indices starting at 1; index 0 is reserved for ``'BACKGROUND'``.

    The sparse COCO category ids are remapped to continuous indices through
    ``coco_id_to_continuous_id`` (and back via ``continuous_id_to_coco_id``).
    """

    def __init__(
        self,
        annotations_path="instances_train2017.json",  # full path to instances_*.json
        images_path="train2017",  # full path to image folder
        transform=None,
        target_transform=None,
        skip_annotations=False,
        filter_empty_gt=True,
    ):
        """Index the annotation file and build the category mappings.

        Args:
            annotations_path: Path to the COCO ``instances_*.json`` file.
            images_path: Directory containing the image files.
            transform: Optional callable ``(image, boxes, labels)`` ->
                ``(image, boxes, labels)`` applied to every sample.
            target_transform: Optional callable ``(boxes, labels)`` ->
                ``(boxes, labels)``; ignored when ``skip_annotations`` is set.
            skip_annotations: When true, samples carry empty boxes/labels and
                no image is filtered out.
            filter_empty_gt: When true (and annotations are not skipped), drop
                images that have no annotation at all.

        Raises:
            FileNotFoundError: ``annotations_path`` does not exist.
            NotADirectoryError: ``images_path`` is not a directory.
        """
        self.img_dir = images_path
        self.ann_file = annotations_path
        self.transform = transform
        self.target_transform = target_transform
        self.skip_annotations = skip_annotations

        # Fail early with a clear message instead of deep inside pycocotools/OpenCV.
        if not os.path.exists(self.ann_file):
            raise FileNotFoundError(f"COCO ann file not found: {self.ann_file}")
        if not os.path.isdir(self.img_dir):
            raise NotADirectoryError(f"Image dir not found: {self.img_dir}")

        self.coco = COCO(self.ann_file)

        # Keep only images that have at least one annotation, unless the
        # caller asked to keep everything or to ignore annotations altogether.
        ids = list(self.coco.imgs.keys())
        if filter_empty_gt and not skip_annotations:
            self.ids = [img_id for img_id in ids if self.coco.getAnnIds(imgIds=img_id)]
        else:
            self.ids = ids
        logging.info(
            f"CocoDataset: Filtered {len(ids) - len(self.ids)} images without annotations. "
            f"Remaining: {len(self.ids)}"
        )

        # Categories, sorted so the continuous mapping is deterministic.
        self.cat_ids = sorted(self.coco.getCatIds())
        self.cats = self.coco.loadCats(self.cat_ids)

        # COCO category ids are sparse, so training uses a continuous 1..N
        # index with 0 reserved for the background class.
        self.class_names = ['BACKGROUND'] + [cat['name'] for cat in self.cats]
        self.coco_id_to_continuous_id = {
            cat_id: index for index, cat_id in enumerate(self.cat_ids, start=1)
        }
        self.continuous_id_to_coco_id = {
            index: cat_id for cat_id, index in self.coco_id_to_continuous_id.items()
        }

    def __getitem__(self, index):
        """Return the (optionally transformed) ``(image, boxes, labels)`` sample."""
        image_id = self.ids[index]
        image, boxes, labels = self._getitem(image_id)
        if self.transform:
            image, boxes, labels = self.transform(image, boxes, labels)
        if self.target_transform and not self.skip_annotations:
            boxes, labels = self.target_transform(boxes, labels)
        return image, boxes, labels

    def _load_image(self, image_id):
        """Read the image belonging to ``image_id`` and return it in RGB order.

        Raises:
            FileNotFoundError: The image file referenced by the annotation
                file is not present in the image directory.
            OSError: The file exists but OpenCV could not decode it.
        """
        img_info = self.coco.loadImgs(image_id)[0]
        image_path = os.path.join(self.img_dir, img_info['file_name'])
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")
        # cv2.imread signals failure by returning None rather than raising;
        # check it here so the error names the offending file.
        image = cv2.imread(image_path)
        if image is None:
            raise OSError(f"Failed to read image (corrupt or unsupported): {image_path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _load_annotations(self, image_id):
        """Return ``(boxes, labels)`` arrays for ``image_id``.

        Annotations without a ``'bbox'`` entry or with a non-positive width
        or height are skipped. Both arrays are empty when annotations are
        skipped or none is usable.
        """
        boxes = []
        labels = []
        if not self.skip_annotations:
            ann_ids = self.coco.getAnnIds(imgIds=image_id)
            for ann in self.coco.loadAnns(ann_ids):
                if 'bbox' not in ann:
                    continue
                x, y, w, h = ann['bbox']
                if w <= 0 or h <= 0:
                    continue
                # COCO stores [x, y, w, h]; convert to [x1, y1, x2, y2].
                boxes.append([x, y, x + w, y + h])
                labels.append(self.coco_id_to_continuous_id[ann['category_id']])
        # reshape keeps the (0, 4) shape when there are no boxes.
        boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
        labels = np.array(labels, dtype=np.int64)
        return boxes, labels

    def _getitem(self, image_id):
        """Load the raw (untransformed) ``(image, boxes, labels)`` for ``image_id``."""
        image = self._load_image(image_id)
        boxes, labels = self._load_annotations(image_id)
        return image, boxes, labels

    def __len__(self):
        """Return the number of images kept in the dataset."""
        return len(self.ids)

    def get_image(self, index):
        """Return only the RGB image at position ``index`` (no transforms applied)."""
        return self._load_image(self.ids[index])

    def get_annotation(self, index):
        """Return ``(image_id, (image, boxes, labels))`` for position ``index``.

        The inner tuple is the raw sample, so the image is loaded as well.
        """
        image_id = self.ids[index]
        return image_id, self._getitem(image_id)