import os

import pandas as pd
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset


class NosePointDataset(Dataset):
    """Dataset of PNG images paired with a single annotated (x, y) point.

    Every ``<name>.png`` in ``root`` is expected to have a sibling
    ``<name>.txt`` whose content starts with ``x,y`` (comma-separated pixel
    coordinates in the original image). Samples are ordered by sorted
    file name.

    Args:
        root: Directory holding the ``.png`` / ``.txt`` pairs.
        image_size: Target size passed to ``torchvision.transforms.Resize``.
        device: Device that returned tensors are moved to.
    """

    def __init__(self, root="/fs/scratch/PAS2099/danielf/medical/nose_clicks_lazy", image_size=(64, 64), device='cpu'):
        self.root = root
        # Only PNG files are samples; the .txt label files live alongside them.
        self.files = sorted(f for f in os.listdir(root) if f.endswith('.png'))
        self.device = device

        self.base_transform = T.Compose([
            T.Resize(image_size),
            T.ToTensor(),  # [0, 1], shape (3, H, W) because images are loaded as RGB
        ])

    def __len__(self):
        """Return the number of PNG images found in ``root``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A tuple ``(image, coord)``: ``image`` is the resized RGB image as
            a float tensor ``[C, H, W]`` in ``[0, 1]``; ``coord`` is a float32
            tensor ``[x / width, y / height]`` computed against the image's
            size *before* resizing. Both are on ``self.device``.

        Raises:
            FileNotFoundError: If the image or its ``.txt`` label is missing.
            IndexError, ValueError: If the label file is not ``x,y``.
        """
        file_name = self.files[idx]
        image_path = os.path.join(self.root, file_name)
        # Swap only the trailing extension; str.replace would also rewrite any
        # '.png' that happens to appear earlier in the file name.
        label_path = os.path.join(self.root, os.path.splitext(file_name)[0] + '.txt')

        # Image.open is lazy and keeps the file open; convert() returns an
        # independent in-memory copy, so the handle can be closed right away.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        orig_w, orig_h = image.size

        with open(label_path, 'r') as f:
            coords = f.read().strip().split(',')
        x, y = float(coords[0]), float(coords[1])

        # Normalise against the original size so the target is resize-invariant.
        x_norm = x / orig_w
        y_norm = y / orig_h

        image = self.base_transform(image).to(self.device)  # [C, H, W], [0, 1]
        coord = torch.tensor([x_norm, y_norm], dtype=torch.float32).to(self.device)  # [2], normalized coordinates

        return image, coord