""" FracAtlas DataLoader for YOLACT+ (ResNet-18 backbone) ====================================================== Provides: - FracAtlasDataset : torch.utils.data.Dataset over COCO-format splits - detection_collate : custom collate for variable-size masks/boxes - get_dataloader : factory function for train / val / test loaders """ import os import os import cv2 import numpy as np import torch import warnings warnings.filterwarnings("ignore", message=".*Premature end.*") warnings.filterwarnings("ignore", message=".*Corrupt JPEG.*") # Suppress OpenCV JPEG warnings try: cv2.setLogLevel(0) except AttributeError: os.environ["OPENCV_LOG_LEVEL"] = "SILENT" from torch.utils.data import Dataset, DataLoader from pycocotools.coco import COCO from pycocotools import mask as coco_mask import albumentations as A from albumentations.pytorch import ToTensorV2 # ─── Augmentation pipelines ─────────────────────────────────────────────────── def get_train_transforms(img_size: int = 550): return A.Compose( [ A.LongestMaxSize(max_size=img_size), A.PadIfNeeded( min_height=img_size, min_width=img_size, fill=0, ), A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.2), A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5), A.GaussNoise(p=0.3), A.Affine( translate_percent=0.05, scale=(0.9, 1.1), rotate=(-10, 10), p=0.4, ), A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ToTensorV2(), ], bbox_params=A.BboxParams( format="pascal_voc", label_fields=["class_labels"], min_visibility=0.3, ), ) def get_val_transforms(img_size: int = 550): return A.Compose( [ A.LongestMaxSize(max_size=img_size), A.PadIfNeeded( min_height=img_size, min_width=img_size, fill=0, ), A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ToTensorV2(), ], bbox_params=A.BboxParams( format="pascal_voc", label_fields=["class_labels"], min_visibility=0.3, ), ) # ─── Dataset ────────────────────────────────────────────────────────────────── class FracAtlasDataset(Dataset): """ COCO-format dataset for FracAtlas fracture detection. Each item returns: image : FloatTensor [3, H, W] (normalised) target : dict with keys boxes : FloatTensor [N, 4] (x1y1x2y2, normalised 0-1) labels : LongTensor [N] masks : FloatTensor [N, H, W] (binary, same spatial size as image) image_id: int """ def __init__( self, image_dir: str, ann_file: str, img_size: int = 550, split: str = "train", ): self.image_dir = image_dir self.img_size = img_size self.split = split self.coco = COCO(ann_file) self.image_ids = sorted(self.coco.imgs.keys()) # Build category → 0-indexed label map # NOTE: FracAtlas uses category_id=0 ('fractured') — handle offset cats = self.coco.loadCats(self.coco.getCatIds()) self.cat_id_to_label = {c["id"]: i for i, c in enumerate(cats)} # If only one class and its id is 0, map it to label 0 if len(cats) == 1 and cats[0]["id"] == 0: self.cat_id_to_label = {0: 0} self.num_classes = len(cats) self.class_names = [c["name"] for c in cats] self.transforms = ( get_train_transforms(img_size) if split == "train" else get_val_transforms(img_size) ) print( f"[{split}] {len(self.image_ids)} images | " f"{self.num_classes} classes: {self.class_names}" ) def __len__(self): return len(self.image_ids) def _decode_mask(self, ann: dict, h: int, w: int) -> np.ndarray: """Decode COCO RLE or polygon segmentation to binary mask.""" seg = ann.get("segmentation", None) if seg is None: # Fall back: create mask from bbox x, y, bw, bh = [int(v) for v in ann["bbox"]] m = np.zeros((h, w), dtype=np.uint8) m[y : y + bh, x : x + bw] = 1 return m if isinstance(seg, dict): # RLE return coco_mask.decode(seg).astype(np.uint8) else: # polygon rle = coco_mask.frPyObjects(seg, h, w) merged = coco_mask.merge(rle) return coco_mask.decode(merged).astype(np.uint8) def __getitem__(self, idx: int): img_id = self.image_ids[idx] img_info = self.coco.imgs[img_id] # ── Load image ──────────────────────────────────────────────────────── img_path = os.path.join(self.image_dir, img_info["file_name"]) image = cv2.imread(img_path) if image is None: raise FileNotFoundError(f"Cannot read image: {img_path}") image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) orig_h, orig_w = image.shape[:2] # ── Load annotations ───────────────────────────────────────────────── ann_ids = self.coco.getAnnIds(imgIds=img_id) anns = self.coco.loadAnns(ann_ids) boxes, class_labels, raw_masks = [], [], [] for ann in anns: x, y, bw, bh = ann["bbox"] x1, y1, x2, y2 = x, y, x + bw, y + bh # Clip to image bounds x1 = max(0.0, x1) y1 = max(0.0, y1) x2 = min(float(orig_w), x2) y2 = min(float(orig_h), y2) if x2 <= x1 or y2 <= y1: continue boxes.append([x1, y1, x2, y2]) class_labels.append(self.cat_id_to_label[ann["category_id"]]) raw_masks.append(self._decode_mask(ann, orig_h, orig_w)) # Non-fractured images: create a dummy background instance so the # tensor shapes are consistent (YOLACT handles empty targets fine too, # but keeping consistent is safer). if len(boxes) == 0: boxes = [[0.0, 0.0, float(orig_w), float(orig_h)]] class_labels = [0] # background / non-fractured raw_masks = [np.zeros((orig_h, orig_w), dtype=np.uint8)] # ── Albumentations ─────────────────────────────────────────────────── transformed = self.transforms( image=image, masks=raw_masks, bboxes=boxes, class_labels=class_labels, ) image_t = transformed["image"] # [3, H, W] boxes_t = transformed["bboxes"] labels_t = transformed["class_labels"] masks_t = transformed["masks"] # list of H×W arrays _, H, W = image_t.shape # ── Build target tensors ───────────────────────────────────────────── if len(boxes_t) == 0: # All boxes removed by augmentation (e.g. min_visibility) boxes_out = torch.zeros((0, 4), dtype=torch.float32) labels_out = torch.zeros((0,), dtype=torch.long) masks_out = torch.zeros((0, H, W), dtype=torch.float32) else: boxes_np = np.array(boxes_t, dtype=np.float32) # Normalise to [0, 1] boxes_np[:, [0, 2]] /= W boxes_np[:, [1, 3]] /= H boxes_np = np.clip(boxes_np, 0.0, 1.0) boxes_out = torch.from_numpy(boxes_np) labels_out = torch.tensor(labels_t, dtype=torch.long) # Albumentations >=2.x returns masks as tensors; older versions return numpy. def to_float_tensor(m): if isinstance(m, torch.Tensor): return m.float() return torch.from_numpy(np.array(m, dtype=np.float32)) masks_out = torch.stack([to_float_tensor(m) for m in masks_t]) target = { "boxes": boxes_out, "labels": labels_out, "masks": masks_out, "image_id": img_id, } return image_t, target # ─── Collate ────────────────────────────────────────────────────────────────── def detection_collate(batch): """ Custom collate for object detection. Images are stacked; targets are kept as a list (variable number of instances). """ images, targets = zip(*batch) images = torch.stack(images, dim=0) # [B, 3, H, W] return images, list(targets) # ─── DataLoader factory ─────────────────────────────────────────────────────── def get_dataloader( image_dir: str, ann_file: str, split: str = "train", img_size: int = 550, batch_size: int = 8, num_workers: int = 4, pin_memory: bool = True, ) -> DataLoader: """ Returns a DataLoader for the given split. Args: image_dir : path to images/ folder for this split ann_file : path to annotations.json for this split split : "train" | "val" | "test" img_size : input resolution fed to the network (default 550 for YOLACT+) batch_size : mini-batch size num_workers: parallel data-loading workers pin_memory : pin CPU memory for faster GPU transfer Returns: torch.utils.data.DataLoader """ dataset = FracAtlasDataset( image_dir=image_dir, ann_file=ann_file, img_size=img_size, split=split, ) shuffle = split == "train" # pin_memory only works when CUDA is available import torch use_pin = pin_memory and torch.cuda.is_available() loader = DataLoader( dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=use_pin, collate_fn=detection_collate, drop_last=(split == "train"), ) return loader