FracAtlas-YOLACT / dataloader.py
MuhammadAdil63's picture
deploy YOLACT+ fracture detection demo
fcec417
Raw
History Blame Contribute Delete
11 kB
"""
FracAtlas DataLoader for YOLACT+ (ResNet-18 backbone)
======================================================
Provides:
- FracAtlasDataset : torch.utils.data.Dataset over COCO-format splits
- detection_collate : custom collate for variable-size masks/boxes
- get_dataloader : factory function for train / val / test loaders
"""
import os
import os
import cv2
import numpy as np
import torch
import warnings
# Ignore Python-level warnings whose message matches the JPEG decode patterns below.
# NOTE(review): libjpeg usually writes these messages straight to stderr rather than
# through the `warnings` module, so these filters may have no visible effect - confirm.
warnings.filterwarnings("ignore", message=".*Premature end.*")
warnings.filterwarnings("ignore", message=".*Corrupt JPEG.*")
# Suppress OpenCV JPEG warnings
try:
    # presumably 0 is OpenCV's "silent" log level - TODO confirm against the cv2 build in use
    cv2.setLogLevel(0)
except AttributeError:
    # This cv2 build has no setLogLevel(); fall back to the environment variable.
    # NOTE(review): cv2 is already imported at this point - verify OpenCV still
    # honours OPENCV_LOG_LEVEL when it is set after import.
    os.environ["OPENCV_LOG_LEVEL"] = "SILENT"
from torch.utils.data import Dataset, DataLoader
from pycocotools.coco import COCO
from pycocotools import mask as coco_mask
import albumentations as A
from albumentations.pytorch import ToTensorV2
# ─── Augmentation pipelines ───────────────────────────────────────────────────
def get_train_transforms(img_size: int = 550):
    """
    Build the Albumentations pipeline used for the training split.

    The pipeline runs three stages in order:
      1. geometry normalisation - longest side resized to ``img_size``, then
         padded (fill value 0) up to at least ``img_size`` x ``img_size``;
      2. random augmentation - horizontal/vertical flips, brightness/contrast
         jitter, Gaussian noise and a small affine perturbation;
      3. model input conversion - per-channel mean/std normalisation followed
         by conversion to a tensor.

    Bounding boxes are expected in ``pascal_voc`` format with a parallel
    ``class_labels`` list; ``min_visibility=0.3`` is forwarded to
    ``A.BboxParams`` (see the Albumentations docs for the exact drop rule).

    Args:
        img_size: target side length in pixels.

    Returns:
        An ``A.Compose`` instance.
    """
    resize_and_pad = [
        A.LongestMaxSize(max_size=img_size),
        A.PadIfNeeded(min_height=img_size, min_width=img_size, fill=0),
    ]
    random_augmentations = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.2),
        A.RandomBrightnessContrast(
            brightness_limit=0.2,
            contrast_limit=0.2,
            p=0.5,
        ),
        A.GaussNoise(p=0.3),
        A.Affine(translate_percent=0.05, scale=(0.9, 1.1), rotate=(-10, 10), p=0.4),
    ]
    to_model_input = [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    box_settings = A.BboxParams(
        format="pascal_voc",
        label_fields=["class_labels"],
        min_visibility=0.3,
    )
    return A.Compose(
        resize_and_pad + random_augmentations + to_model_input,
        bbox_params=box_settings,
    )
def get_val_transforms(img_size: int = 550):
    """
    Build the deterministic Albumentations pipeline for validation / test.

    No random augmentation is applied: the image is resized so its longest
    side equals ``img_size``, padded (fill value 0) to at least
    ``img_size`` x ``img_size``, normalised per channel and converted to a
    tensor. Boxes use the ``pascal_voc`` format with a parallel
    ``class_labels`` list, and ``min_visibility=0.3`` is forwarded to
    ``A.BboxParams``.

    Args:
        img_size: target side length in pixels.

    Returns:
        An ``A.Compose`` instance.
    """
    steps = []
    steps.append(A.LongestMaxSize(max_size=img_size))
    steps.append(A.PadIfNeeded(min_height=img_size, min_width=img_size, fill=0))
    steps.append(A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))
    steps.append(ToTensorV2())

    box_settings = A.BboxParams(
        format="pascal_voc",
        label_fields=["class_labels"],
        min_visibility=0.3,
    )
    return A.Compose(steps, bbox_params=box_settings)
# ─── Dataset ──────────────────────────────────────────────────────────────────
class FracAtlasDataset(Dataset):
    """
    COCO-format dataset for FracAtlas fracture detection.

    Each item is a tuple ``(image, target)``:
        image  : FloatTensor [3, H, W] (normalised)
        target : dict with keys
            boxes   : FloatTensor [N, 4] (x1y1x2y2, normalised 0-1)
            labels  : LongTensor  [N]
            masks   : FloatTensor [N, H, W] (binary, same spatial size as image)
            image_id: int

    ``boxes``, ``labels`` and ``masks`` always describe the same N instances in
    the same order, including when the augmentation pipeline drops boxes.
    """

    def __init__(
        self,
        image_dir: str,
        ann_file: str,
        img_size: int = 550,
        split: str = "train",
    ):
        """
        Args:
            image_dir: directory that contains the image files of this split.
            ann_file : path to the COCO-format annotation JSON of this split.
            img_size : side length passed to the transform factories.
            split    : "train" selects the training transforms; any other
                       value selects the validation transforms.
        """
        self.image_dir = image_dir
        self.img_size = img_size
        self.split = split
        self.coco = COCO(ann_file)
        self.image_ids = sorted(self.coco.imgs.keys())

        # Build category id -> 0-indexed label map.
        # NOTE: FracAtlas uses category_id=0 ('fractured'); enumerate() already
        # maps that single category to label 0, so no special case is needed.
        cats = self.coco.loadCats(self.coco.getCatIds())
        self.cat_id_to_label = {c["id"]: i for i, c in enumerate(cats)}
        self.num_classes = len(cats)
        self.class_names = [c["name"] for c in cats]

        self.transforms = (
            get_train_transforms(img_size)
            if split == "train"
            else get_val_transforms(img_size)
        )
        print(
            f"[{split}] {len(self.image_ids)} images | "
            f"{self.num_classes} classes: {self.class_names}"
        )

    def __len__(self):
        """Return the number of images in the split."""
        return len(self.image_ids)

    def _decode_mask(self, ann: dict, h: int, w: int) -> np.ndarray:
        """
        Decode a COCO segmentation (RLE or polygon) to a binary uint8 mask.

        Falls back to a filled rectangle built from ``ann["bbox"]`` when the
        annotation has no segmentation (missing, ``None`` or empty).

        Args:
            ann: COCO annotation dict.
            h  : image height in pixels.
            w  : image width in pixels.

        Returns:
            uint8 array of 0/1 values; shape (h, w) for the bbox and polygon
            paths.
        """
        seg = ann.get("segmentation", None)
        if not seg:
            # Missing *or empty* segmentation: rasterise the bbox instead.
            x, y, bw, bh = [int(v) for v in ann["bbox"]]
            # Clamp at 0 so negative coordinates do not wrap around as
            # negative NumPy indices; the upper bound is clipped by slicing.
            x1, y1 = max(0, x), max(0, y)
            x2, y2 = max(0, x + bw), max(0, y + bh)
            m = np.zeros((h, w), dtype=np.uint8)
            m[y1:y2, x1:x2] = 1
            return m
        if isinstance(seg, dict):  # RLE
            rle = seg
            if isinstance(seg.get("counts"), list):
                # Uncompressed RLE must be converted before decode().
                rle = coco_mask.frPyObjects(seg, h, w)
            return coco_mask.decode(rle).astype(np.uint8)
        # Polygon(s): rasterise each part and merge them into one mask.
        rle = coco_mask.frPyObjects(seg, h, w)
        merged = coco_mask.merge(rle)
        return coco_mask.decode(merged).astype(np.uint8)

    def _load_instances(self, img_id: int, orig_h: int, orig_w: int):
        """Collect clipped pascal_voc boxes, labels and masks for one image."""
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)
        boxes, class_labels, raw_masks = [], [], []
        for ann in anns:
            x, y, bw, bh = ann["bbox"]
            # COCO xywh -> x1y1x2y2, clipped to the image bounds.
            x1 = max(0.0, x)
            y1 = max(0.0, y)
            x2 = min(float(orig_w), x + bw)
            y2 = min(float(orig_h), y + bh)
            if x2 <= x1 or y2 <= y1:
                continue  # degenerate after clipping
            boxes.append([x1, y1, x2, y2])
            class_labels.append(self.cat_id_to_label[ann["category_id"]])
            raw_masks.append(self._decode_mask(ann, orig_h, orig_w))
        return boxes, class_labels, raw_masks

    @staticmethod
    def _to_float_tensor(m):
        """Convert one mask (tensor or array-like) to a float32 tensor."""
        # Albumentations >=2.x returns masks as tensors; older versions return numpy.
        if isinstance(m, torch.Tensor):
            return m.float()
        return torch.from_numpy(np.array(m, dtype=np.float32))

    def _build_target(self, transformed: dict, class_labels: list, img_id: int) -> dict:
        """
        Turn the transform output into the target dict.

        ``transformed["class_labels"]`` is expected to hold the *indices* of
        the instances whose boxes survived the pipeline; labels and masks are
        selected through those indices so they stay aligned with the boxes.
        """
        _, H, W = transformed["image"].shape
        # Some Albumentations versions hand numeric label fields back as floats.
        kept = [int(round(float(i))) for i in transformed["class_labels"]]
        all_masks = transformed["masks"]  # one H x W mask per *input* instance

        if not kept:
            # All boxes removed by augmentation (e.g. min_visibility)
            boxes_out = torch.zeros((0, 4), dtype=torch.float32)
            labels_out = torch.zeros((0,), dtype=torch.long)
            masks_out = torch.zeros((0, H, W), dtype=torch.float32)
        else:
            boxes_np = np.array(transformed["bboxes"], dtype=np.float32)
            # Normalise to [0, 1]
            boxes_np[:, [0, 2]] /= W
            boxes_np[:, [1, 3]] /= H
            boxes_np = np.clip(boxes_np, 0.0, 1.0)
            boxes_out = torch.from_numpy(boxes_np)
            labels_out = torch.tensor(
                [class_labels[i] for i in kept], dtype=torch.long
            )
            masks_out = torch.stack(
                [self._to_float_tensor(all_masks[i]) for i in kept]
            )

        return {
            "boxes": boxes_out,
            "labels": labels_out,
            "masks": masks_out,
            "image_id": img_id,
        }

    def __getitem__(self, idx: int):
        """
        Load, augment and pack the sample at position ``idx``.

        Returns:
            ``(image, target)`` as described in the class docstring.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        img_id = self.image_ids[idx]
        img_info = self.coco.imgs[img_id]

        # ── Load image ────────────────────────────────────────────────────────
        img_path = os.path.join(self.image_dir, img_info["file_name"])
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Cannot read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        orig_h, orig_w = image.shape[:2]

        # ── Load annotations ─────────────────────────────────────────────────
        boxes, class_labels, raw_masks = self._load_instances(img_id, orig_h, orig_w)

        # Non-fractured images: create a dummy full-image instance with an
        # empty mask so the tensor shapes are consistent.
        # NOTE(review): label 0 is also the label of the first real category
        # in cat_id_to_label, so this dummy is indistinguishable from that
        # class by label alone - confirm the training code expects this.
        if not boxes:
            boxes = [[0.0, 0.0, float(orig_w), float(orig_h)]]
            class_labels = [0]
            raw_masks = [np.zeros((orig_h, orig_w), dtype=np.uint8)]

        # ── Albumentations ───────────────────────────────────────────────────
        # The pipeline may drop boxes (min_visibility) but returns every mask.
        # Passing instance indices through the label field tells us which
        # instances survived, so labels and masks can be filtered to match.
        transformed = self.transforms(
            image=image,
            masks=raw_masks,
            bboxes=boxes,
            class_labels=list(range(len(boxes))),
        )

        target = self._build_target(transformed, class_labels, img_id)
        return transformed["image"], target
# ─── Collate ──────────────────────────────────────────────────────────────────
def detection_collate(batch):
    """
    Collate function for detection samples.

    Args:
        batch: sequence of ``(image, target)`` pairs; all images must share
            one shape so they can be stacked.

    Returns:
        ``(images, targets)`` where ``images`` is the stacked tensor
        [B, 3, H, W] and ``targets`` is a plain list of the per-sample target
        objects (left unstacked because the instance count varies per image).
    """
    image_seq, target_seq = map(list, zip(*batch))
    return torch.stack(image_seq, dim=0), target_seq
# ─── DataLoader factory ───────────────────────────────────────────────────────
def get_dataloader(
    image_dir: str,
    ann_file: str,
    split: str = "train",
    img_size: int = 550,
    batch_size: int = 8,
    num_workers: int = 4,
    pin_memory: bool = True,
) -> DataLoader:
    """
    Create the DataLoader for one dataset split.

    The training split is shuffled and drops its last incomplete batch; every
    other split is iterated in order and keeps all samples.

    Args:
        image_dir  : path to images/ folder for this split
        ann_file   : path to annotations.json for this split
        split      : "train" | "val" | "test"
        img_size   : input resolution fed to the network (default 550 for YOLACT+)
        batch_size : mini-batch size
        num_workers: parallel data-loading workers
        pin_memory : request pinned CPU memory; only honoured when CUDA is available

    Returns:
        torch.utils.data.DataLoader that collates with ``detection_collate``.
    """
    is_train = split == "train"

    dataset = FracAtlasDataset(
        image_dir=image_dir,
        ann_file=ann_file,
        img_size=img_size,
        split=split,
    )

    # pin_memory only works when CUDA is available
    loader_options = {
        "batch_size": batch_size,
        "shuffle": is_train,
        "num_workers": num_workers,
        "pin_memory": pin_memory and torch.cuda.is_available(),
        "collate_fn": detection_collate,
        "drop_last": is_train,
    }
    return DataLoader(dataset, **loader_options)