MedAI-ACM / src /utils /data_utils.py
Tirath5504's picture
deploy
bf07f10
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
def get_transforms(split: str, img_size: int = 224):
    """Return the torchvision transform pipeline for a dataset split.

    Args:
        split: ``'train'`` selects the augmentation pipeline; any other
            value selects the deterministic eval/test pipeline.
        img_size: Side length in pixels of the final square output.

    Returns:
        A ``T.Compose`` that maps a PIL image to a normalized
        ``img_size`` x ``img_size`` float tensor (ImageNet mean/std).
    """
    # ImageNet statistics — shared by both branches; must match the
    # pretrained backbone's normalization.
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    if split == 'train':
        return T.Compose([
            # Upscale ~10% first so RandomResizedCrop has a margin to sample from.
            T.Resize((int(img_size * 1.1), int(img_size * 1.1))),
            T.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
            T.RandomRotation(15),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize,
        ])
    # Eval/test: Resize to an exact (img_size, img_size) square already
    # produces the final shape, so the original CenterCrop(img_size) was a
    # guaranteed no-op and has been removed (output is byte-identical).
    return T.Compose([
        T.Resize((img_size, img_size)),
        T.ToTensor(),
        normalize,
    ])
class FractureDataset(Dataset):
    """Dataset of fracture images with optional bounding-box cropping.

    Each record must provide an ``'image_path'`` and an integer-castable
    ``'label'``. When ``use_bbox`` is True and all four ``'bbox_*'`` keys
    are present in the record, the image is cropped to that box before the
    transform is applied.
    """

    def __init__(self, df, img_root: str = '.', transform=None, use_bbox: bool = False):
        # ``df`` may be a pandas DataFrame or any sequence of mapping-like
        # rows (e.g. a list of dicts); __getitem__ handles both.
        self.entries = df
        self.img_root = img_root       # prefix for relative image paths
        self.transform = transform     # callable applied to the PIL image (or None)
        self.use_bbox = use_bbox       # crop to bbox columns when available

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        # Bug fix: ``self.entries[idx]`` on a pandas DataFrame selects a
        # *column* by label (raising KeyError for integer positions), not a
        # row. Use positional row access when available; plain indexing
        # still works for list-of-dict inputs.
        if hasattr(self.entries, 'iloc'):
            row = self.entries.iloc[idx]
        else:
            row = self.entries[idx]
        img_path = row['image_path']
        if not os.path.isabs(img_path):
            img_path = os.path.join(self.img_root, img_path)
        img = Image.open(img_path).convert('RGB')
        # Crop to the annotated region only when every bbox key is present.
        if self.use_bbox and all(k in row for k in ('bbox_xmin', 'bbox_ymin', 'bbox_xmax', 'bbox_ymax')):
            box = tuple(int(row[k]) for k in ('bbox_xmin', 'bbox_ymin', 'bbox_xmax', 'bbox_ymax'))
            img = img.crop(box)  # PIL crop takes (left, upper, right, lower)
        label = int(row['label'])
        if self.transform:
            img = self.transform(img)
        return img, label, img_path