Daniel-F's picture
add train folder
5be3c34
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as T
class NosePointDataset(Dataset):
    """Dataset of PNG images, each paired with a single annotated (x, y) point.

    Every ``<name>.png`` in ``root`` is expected to have a sibling
    ``<name>.txt`` whose content is ``"x,y"`` in pixel coordinates of the
    original (un-resized) image.

    Each item is ``(image, coord)``:
      * ``image`` -- float tensor of shape (3, H, W) with values in [0, 1],
        resized to ``image_size`` and moved to ``device``.
      * ``coord`` -- float32 tensor ``[x / orig_w, y / orig_h]`` (the point
        divided by the ORIGINAL image width/height), moved to ``device``.
    """

    def __init__(self, root="/fs/scratch/PAS2099/danielf/medical/nose_clicks_lazy", image_size=(64, 64), device='cpu'):
        """Index all ``.png`` files in ``root``.

        Args:
            root: Directory containing the ``.png`` images and ``.txt`` labels.
            image_size: Target size passed to ``torchvision.transforms.Resize``.
            device: Device every returned tensor is moved to.
        """
        self.root = root
        # Sorted so the index -> file mapping is deterministic across runs.
        self.files = [f for f in sorted(os.listdir(root)) if f.endswith('.png')]
        self.device = device
        self.base_transform = T.Compose([
            T.Resize(image_size),
            T.ToTensor(),  # RGB PIL image -> float tensor in [0, 1], shape (3, H, W)
        ])

    def __len__(self):
        """Return the number of ``.png`` images found in ``root``."""
        return len(self.files)

    @staticmethod
    def _label_name(image_name):
        """Return the annotation filename for ``image_name`` (``x.png`` -> ``x.txt``)."""
        # Swap only the final extension; str.replace('.png', '.txt') would also
        # rewrite any '.png' occurring earlier in the name (e.g. 'a.png_v2.png').
        stem, _ = os.path.splitext(image_name)
        return stem + '.txt'

    @staticmethod
    def _parse_coords(text):
        """Parse ``"x,y"`` (surrounding whitespace allowed) into two floats.

        Only the first two comma-separated fields are used.
        """
        parts = text.strip().split(',')
        return float(parts[0]), float(parts[1])

    def __getitem__(self, idx):
        """Load image ``idx`` and its annotated point.

        Returns:
            A tuple ``(image, coord)``; see the class docstring for details.
        """
        name = self.files[idx]

        # The context manager closes the underlying file handle. convert()
        # loads the pixel data and returns a copy, so ``image`` stays usable
        # after the file is closed.
        with Image.open(os.path.join(self.root, name)) as img:
            image = img.convert('RGB')
        orig_w, orig_h = image.size  # PIL reports (width, height)

        with open(os.path.join(self.root, self._label_name(name)), 'r') as f:
            x, y = self._parse_coords(f.read())

        # Normalize by the pre-resize size: fractional coordinates still line
        # up with the image after Resize, which only rescales (no crop).
        coord = torch.tensor([x / orig_w, y / orig_h], dtype=torch.float32).to(self.device)

        # NOTE(review): moving tensors to a CUDA device inside __getitem__ can
        # fail with DataLoader(num_workers > 0); consider device='cpu' and
        # moving batches in the training loop instead.
        image = self.base_transform(image).to(self.device)
        return image, coord