# Snapshot of a Hugging Face Space source file (commit bf07f10, 1,943 bytes).
# The Space was showing "Runtime error" status at capture time.
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
def get_transforms(split: str, img_size: int = 224):
    """Build the preprocessing pipeline for a dataset split.

    Args:
        split: 'train' selects the augmentation pipeline; anything else
            gets the deterministic val/test pipeline.
        img_size: side length (pixels) of the square output image.

    Returns:
        A torchvision ``Compose`` ending in ImageNet normalization.
    """
    # Standard ImageNet statistics, shared by both pipelines.
    imagenet_norm = T.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
    if split == 'train':
        steps = [
            # Slightly oversize first so the random crop has room to move.
            T.Resize((int(img_size * 1.1), int(img_size * 1.1))),
            T.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
            T.RandomRotation(15),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            imagenet_norm,
        ]
    else:
        steps = [
            T.Resize((img_size, img_size)),
            # NOTE(review): the crop is a no-op here since Resize already
            # produces exactly img_size x img_size; kept for parity.
            T.CenterCrop(img_size),
            T.ToTensor(),
            imagenet_norm,
        ]
    return T.Compose(steps)
class FractureDataset(Dataset):
    """Dataset of fracture images with optional bounding-box cropping.

    Args:
        df: pandas DataFrame or sequence of dict-like rows. Each row must
            provide 'image_path' and 'label'; rows may additionally carry
            the four 'bbox_*' coordinate fields.
        img_root: directory prepended to relative image paths.
        transform: optional callable applied to the PIL image (e.g. the
            pipelines from ``get_transforms``).
        use_bbox: when True, crop each image to its bounding box before
            transforming, if the row carries all four bbox coordinates.

    ``__getitem__`` returns ``(image, label, resolved_image_path)``.
    """

    # Field names required (all together) to perform the bbox crop.
    _BBOX_KEYS = ('bbox_xmin', 'bbox_ymin', 'bbox_xmax', 'bbox_ymax')

    def __init__(self, df, img_root: str = '.', transform=None, use_bbox: bool = False):
        self.entries = df
        self.img_root = img_root
        self.transform = transform
        self.use_bbox = use_bbox

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        entries = self.entries
        # BUG FIX: `df[idx]` on a pandas DataFrame performs a COLUMN lookup
        # and raises KeyError for integer positions. Use positional .iloc
        # when available; fall back to plain indexing so sequences of
        # dict-like rows keep working.
        row = entries.iloc[idx] if hasattr(entries, 'iloc') else entries[idx]
        img_path = row['image_path']
        if not os.path.isabs(img_path):
            img_path = os.path.join(self.img_root, img_path)
        img = Image.open(img_path).convert('RGB')
        # Crop only when every bbox coordinate is present in this row.
        if self.use_bbox and all(k in row for k in self._BBOX_KEYS):
            xmin, ymin, xmax, ymax = (int(row[k]) for k in self._BBOX_KEYS)
            img = img.crop((xmin, ymin, xmax, ymax))
        label = int(row['label'])
        if self.transform:
            img = self.transform(img)
        return img, label, img_path
|