Spaces:
Running
Running
| """ | |
| FracAtlas DataLoader for YOLACT+ (ResNet-18 backbone) | |
| ====================================================== | |
| Provides: | |
| - FracAtlasDataset : torch.utils.data.Dataset over COCO-format splits | |
| - detection_collate : custom collate for variable-size masks/boxes | |
| - get_dataloader : factory function for train / val / test loaders | |
| """ | |
| import os | |
| import os | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import warnings | |
| warnings.filterwarnings("ignore", message=".*Premature end.*") | |
| warnings.filterwarnings("ignore", message=".*Corrupt JPEG.*") | |
| # Suppress OpenCV JPEG warnings | |
| try: | |
| cv2.setLogLevel(0) | |
| except AttributeError: | |
| os.environ["OPENCV_LOG_LEVEL"] = "SILENT" | |
| from torch.utils.data import Dataset, DataLoader | |
| from pycocotools.coco import COCO | |
| from pycocotools import mask as coco_mask | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
# ─── Augmentation pipelines ───────────────────────────────────────────────────
def get_train_transforms(img_size: int = 550):
    """Build the Albumentations pipeline used for the training split.

    The pipeline runs in three stages:

    1. Geometry normalisation: resize so the longest side equals
       ``img_size``, then pad to an ``img_size`` x ``img_size`` canvas
       (``fill=0``).
    2. Random augmentation: horizontal/vertical flips, brightness/contrast
       jitter, Gaussian noise and a small affine transform.
    3. Tensor conversion: per-channel mean/std normalisation followed by
       ``ToTensorV2``.

    Bounding boxes are expected in ``pascal_voc`` format (absolute
    ``x1, y1, x2, y2``) with their labels supplied under the
    ``class_labels`` key; boxes are filtered with ``min_visibility=0.3``.

    Args:
        img_size: Side length of the square output canvas, in pixels.

    Returns:
        An ``A.Compose`` instance.
    """
    resize_and_pad = [
        A.LongestMaxSize(max_size=img_size),
        A.PadIfNeeded(min_height=img_size, min_width=img_size, fill=0),
    ]
    random_augs = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.2),
        A.RandomBrightnessContrast(
            brightness_limit=0.2, contrast_limit=0.2, p=0.5
        ),
        A.GaussNoise(p=0.3),
        A.Affine(
            translate_percent=0.05,
            scale=(0.9, 1.1),
            rotate=(-10, 10),
            p=0.4,
        ),
    ]
    to_tensor = [
        # Mean/std look like the usual ImageNet statistics — TODO confirm
        # they match the backbone's pre-training.
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    box_config = A.BboxParams(
        format="pascal_voc",
        label_fields=["class_labels"],
        min_visibility=0.3,
    )
    return A.Compose(
        resize_and_pad + random_augs + to_tensor,
        bbox_params=box_config,
    )
def get_val_transforms(img_size: int = 550):
    """Build the deterministic Albumentations pipeline for val/test splits.

    Images are resized so the longest side equals ``img_size``, padded to an
    ``img_size`` x ``img_size`` canvas (``fill=0``), normalised per channel
    and converted with ``ToTensorV2``. No random augmentation is applied.

    Bounding boxes are expected in ``pascal_voc`` format with labels under
    the ``class_labels`` key; boxes are filtered with ``min_visibility=0.3``.

    Args:
        img_size: Side length of the square output canvas, in pixels.

    Returns:
        An ``A.Compose`` instance.
    """
    steps = []
    steps.append(A.LongestMaxSize(max_size=img_size))
    steps.append(
        A.PadIfNeeded(min_height=img_size, min_width=img_size, fill=0)
    )
    steps.append(
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    )
    steps.append(ToTensorV2())

    box_config = A.BboxParams(
        format="pascal_voc",
        label_fields=["class_labels"],
        min_visibility=0.3,
    )
    return A.Compose(steps, bbox_params=box_config)
# ─── Dataset ──────────────────────────────────────────────────────────────────
class FracAtlasDataset(Dataset):
    """
    COCO-format dataset for FracAtlas fracture detection.

    Each item returns ``(image, target)``:
        image  : FloatTensor [3, H, W] (normalised)
        target : dict with keys
            boxes    : FloatTensor [N, 4] (x1y1x2y2, normalised 0-1)
            labels   : LongTensor  [N]
            masks    : FloatTensor [N, H, W] (binary, same spatial size as image)
            image_id : the COCO image id

    ``boxes``, ``labels`` and ``masks`` always share the same leading
    dimension N and the same instance order. N is 0 when the augmentation
    pipeline filters out every box.
    """

    def __init__(
        self,
        image_dir: str,
        ann_file: str,
        img_size: int = 550,
        split: str = "train",
    ):
        """Load the COCO annotation index and pick the transform pipeline.

        Args:
            image_dir: Directory that ``file_name`` entries are relative to.
            ann_file: Path to the COCO-format annotation JSON.
            img_size: Square input resolution passed to the transforms.
            split: ``"train"`` selects the training augmentations; any other
                value selects the deterministic validation pipeline.
        """
        self.image_dir = image_dir
        self.img_size = img_size
        self.split = split
        self.coco = COCO(ann_file)
        self.image_ids = sorted(self.coco.imgs.keys())

        # Build category id -> 0-indexed label map. ``enumerate`` already
        # yields {0: 0} for a single category whose id is 0, so no special
        # case is needed for that layout.
        cats = self.coco.loadCats(self.coco.getCatIds())
        self.cat_id_to_label = {c["id"]: i for i, c in enumerate(cats)}
        self.num_classes = len(cats)
        self.class_names = [c["name"] for c in cats]

        self.transforms = (
            get_train_transforms(img_size)
            if split == "train"
            else get_val_transforms(img_size)
        )
        print(
            f"[{split}] {len(self.image_ids)} images | "
            f"{self.num_classes} classes: {self.class_names}"
        )

    def __len__(self):
        """Return the number of images in the split."""
        return len(self.image_ids)

    @staticmethod
    def _bbox_mask(bbox, h: int, w: int) -> np.ndarray:
        """Rasterise a COCO ``[x, y, w, h]`` box into an ``h`` x ``w`` mask.

        Coordinates are truncated to int and clipped to the image so that a
        negative origin cannot wrap around via negative slice indices.
        """
        x, y, bw, bh = [int(v) for v in bbox]
        x_lo, y_lo = max(x, 0), max(y, 0)
        x_hi = max(min(x + bw, w), 0)
        y_hi = max(min(y + bh, h), 0)
        mask = np.zeros((h, w), dtype=np.uint8)
        mask[y_lo:y_hi, x_lo:x_hi] = 1
        return mask

    def _decode_mask(self, ann: dict, h: int, w: int) -> np.ndarray:
        """Decode a COCO annotation's segmentation into a binary uint8 mask.

        Handles, in order:
          * missing or empty segmentation -> filled bounding box;
          * RLE dict (an uncompressed ``counts`` list is compressed first);
          * polygon list -> all parts merged into one mask.

        Args:
            ann: COCO annotation dict (needs ``bbox`` for the fallback).
            h: Image height in pixels.
            w: Image width in pixels.

        Returns:
            ``np.uint8`` array; ``(h, w)`` for the bbox and polygon paths.
        """
        seg = ann.get("segmentation", None)
        if not seg:
            # None, [] or {}: nothing to decode, fall back to the box.
            return self._bbox_mask(ann["bbox"], h, w)
        if isinstance(seg, dict):  # RLE
            if isinstance(seg.get("counts"), list):
                # Uncompressed RLE must be converted before decoding.
                seg = coco_mask.frPyObjects(seg, h, w)
            return coco_mask.decode(seg).astype(np.uint8)
        # Polygon(s): one object may consist of several parts.
        rles = coco_mask.frPyObjects(seg, h, w)
        merged = coco_mask.merge(rles)
        return coco_mask.decode(merged).astype(np.uint8)

    @staticmethod
    def _build_target(
        boxes_t, kept, class_labels, masks_t, height: int, width: int, img_id
    ) -> dict:
        """Assemble the target dict from transformed boxes and masks.

        Args:
            boxes_t: Transformed boxes (absolute x1y1x2y2), one per survivor.
            kept: Index into the pre-transform instance list for each
                surviving box, in the same order as ``boxes_t``.
            class_labels: Pre-transform labels, indexed by instance.
            masks_t: Transformed masks for *all* pre-transform instances.
            height: Transformed image height.
            width: Transformed image width.
            img_id: Image id stored under ``"image_id"``.

        Raises:
            RuntimeError: If ``kept`` and ``boxes_t`` differ in length.
        """
        if len(boxes_t) == 0:
            # All boxes removed by augmentation (e.g. min_visibility).
            return {
                "boxes": torch.zeros((0, 4), dtype=torch.float32),
                "labels": torch.zeros((0,), dtype=torch.long),
                "masks": torch.zeros((0, height, width), dtype=torch.float32),
                "image_id": img_id,
            }
        if len(kept) != len(boxes_t):
            raise RuntimeError(
                f"Transform returned {len(boxes_t)} boxes but "
                f"{len(kept)} instance indices"
            )

        boxes_np = np.array(boxes_t, dtype=np.float32)
        # Normalise to [0, 1].
        boxes_np[:, [0, 2]] /= width
        boxes_np[:, [1, 3]] /= height
        boxes_np = np.clip(boxes_np, 0.0, 1.0)

        def to_float_tensor(m):
            # Masks may already be tensors depending on the Albumentations
            # version; otherwise they are numpy arrays.
            if isinstance(m, torch.Tensor):
                return m.float()
            return torch.from_numpy(np.array(m, dtype=np.float32))

        # Select labels and masks by surviving index so that they stay
        # aligned with the filtered boxes.
        return {
            "boxes": torch.from_numpy(boxes_np),
            "labels": torch.tensor(
                [class_labels[i] for i in kept], dtype=torch.long
            ),
            "masks": torch.stack([to_float_tensor(masks_t[i]) for i in kept]),
            "image_id": img_id,
        }

    def __getitem__(self, idx: int):
        """Load, augment and tensorise the ``idx``-th image and its targets.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        img_id = self.image_ids[idx]
        img_info = self.coco.imgs[img_id]

        # ── Load image ──────────────────────────────────────────────────────
        img_path = os.path.join(self.image_dir, img_info["file_name"])
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Cannot read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        orig_h, orig_w = image.shape[:2]

        # ── Load annotations ────────────────────────────────────────────────
        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
        boxes, class_labels, raw_masks = [], [], []
        for ann in anns:
            x, y, bw, bh = ann["bbox"]
            # COCO xywh -> x1y1x2y2, clipped to image bounds.
            x1 = max(0.0, x)
            y1 = max(0.0, y)
            x2 = min(float(orig_w), x + bw)
            y2 = min(float(orig_h), y + bh)
            if x2 <= x1 or y2 <= y1:
                continue  # degenerate after clipping
            boxes.append([x1, y1, x2, y2])
            class_labels.append(self.cat_id_to_label[ann["category_id"]])
            raw_masks.append(self._decode_mask(ann, orig_h, orig_w))

        # Images without usable annotations get one full-image placeholder
        # instance with an all-zero mask so tensor shapes stay consistent.
        # NOTE(review): label 0 is also the first real category in
        # ``cat_id_to_label``, so by label alone this placeholder looks like
        # a real instance of that class — confirm the loss/eval code treats
        # empty-mask instances as negatives.
        if not boxes:
            boxes = [[0.0, 0.0, float(orig_w), float(orig_h)]]
            class_labels = [0]
            raw_masks = [np.zeros((orig_h, orig_w), dtype=np.uint8)]

        # ── Albumentations ──────────────────────────────────────────────────
        # The bbox processor can drop boxes (min_visibility) but leaves the
        # ``masks`` list untouched. Pass instance indices through the label
        # field so the surviving boxes can be matched back to their masks
        # and class labels.
        transformed = self.transforms(
            image=image,
            masks=raw_masks,
            bboxes=boxes,
            class_labels=list(range(len(boxes))),
        )
        image_t = transformed["image"]  # [3, H, W]
        _, H, W = image_t.shape
        kept = [int(i) for i in transformed["class_labels"]]

        # ── Build target tensors ────────────────────────────────────────────
        target = self._build_target(
            transformed["bboxes"],
            kept,
            class_labels,
            transformed["masks"],
            H,
            W,
            img_id,
        )
        return image_t, target
# ─── Collate ──────────────────────────────────────────────────────────────────
def detection_collate(batch):
    """Collate ``(image, target)`` samples for object detection.

    Args:
        batch: Sequence of ``(image_tensor, target)`` pairs; all image
            tensors must have the same shape so they can be stacked.

    Returns:
        ``(images, targets)`` where ``images`` is the image tensors stacked
        along a new leading batch dimension and ``targets`` is a plain list,
        kept unstacked because each sample may hold a different number of
        instances.
    """
    image_seq, target_seq = zip(*batch)
    stacked = torch.stack(image_seq, dim=0)  # [B, 3, H, W]
    return stacked, list(target_seq)
# ─── DataLoader factory ───────────────────────────────────────────────────────
def get_dataloader(
    image_dir: str,
    ann_file: str,
    split: str = "train",
    img_size: int = 550,
    batch_size: int = 8,
    num_workers: int = 4,
    pin_memory: bool = True,
) -> DataLoader:
    """
    Create a ``DataLoader`` over a ``FracAtlasDataset`` split.

    The training split is shuffled and drops its last incomplete batch;
    every other split is iterated in order and keeps all batches.

    Args:
        image_dir  : path to images/ folder for this split
        ann_file   : path to annotations.json for this split
        split      : "train" | "val" | "test"
        img_size   : input resolution fed to the network
        batch_size : mini-batch size
        num_workers: parallel data-loading workers
        pin_memory : request pinned host memory; only honoured when CUDA is
                     available
    Returns:
        torch.utils.data.DataLoader using ``detection_collate``
    """
    is_train = split == "train"
    dataset = FracAtlasDataset(
        image_dir=image_dir,
        ann_file=ann_file,
        img_size=img_size,
        split=split,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
        # Pinning is pointless without a CUDA device, so gate it on availability.
        pin_memory=pin_memory and torch.cuda.is_available(),
        collate_fn=detection_collate,
        drop_last=is_train,
    )