File size: 1,308 Bytes
5be3c34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as T

class NosePointDataset(Dataset):
    """Dataset of PNG images, each paired with one annotated (x, y) point.

    Every sample is a ``<name>.png`` file in ``root`` with a sibling
    ``<name>.txt`` file whose content is ``"x,y"``: the point in pixel
    coordinates of the original (un-resized) image.

    ``__getitem__`` returns ``(image, coord)``:

    * ``image`` -- float tensor with values in [0, 1], shape (3, H, W), where
      (H, W) comes from ``image_size``.
    * ``coord`` -- float32 tensor ``[x / orig_w, y / orig_h]``.

    The coordinates are normalized by the original image size, so they stay
    valid after the resize.
    """

    def __init__(self, root="/fs/scratch/PAS2099/danielf/medical/nose_clicks_lazy", image_size=(64, 64), device='cpu'):
        """Index the PNG files in ``root`` and build the image transform.

        Args:
            root: Directory containing ``*.png`` images and matching ``*.txt``
                point annotations.
            image_size: Size passed to ``torchvision.transforms.Resize``.
            device: Device both returned tensors are moved to.
                NOTE(review): moving to a CUDA device inside ``__getitem__``
                does not mix well with ``DataLoader(num_workers>0)`` using
                forked workers. ``'cpu'`` plus moving batches in the training
                loop is the usual pattern.
        """
        self.root = root
        # Only PNG files are samples. Sorting keeps the index -> file mapping
        # deterministic across runs.
        self.files = sorted(f for f in os.listdir(root) if f.endswith('.png'))
        self.device = device

        self.base_transform = T.Compose([
            T.Resize(image_size),
            T.ToTensor(),  # float in [0, 1]; shape (3, H, W) because images are RGB
        ])

    def __len__(self):
        """Return the number of indexed PNG samples."""
        return len(self.files)

    @staticmethod
    def _label_path(image_path):
        """Return the ``.txt`` annotation path that sits next to ``image_path``."""
        # splitext strips only the final extension. str.replace('.png', ...)
        # would also rewrite any '.png' appearing earlier in the path.
        return os.path.splitext(image_path)[0] + '.txt'

    @staticmethod
    def _read_coords(label_path):
        """Parse the first two comma-separated values in ``label_path`` as floats."""
        with open(label_path, 'r', encoding='utf-8') as f:
            coords = f.read().strip().split(',')
        return float(coords[0]), float(coords[1])

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, normalized_coord)``."""
        image_path = os.path.join(self.root, self.files[idx])

        # The context manager releases the file handle as soon as the pixels
        # are decoded. convert() returns a new, fully loaded image.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        orig_w, orig_h = image.size
        x, y = self._read_coords(self._label_path(image_path))

        # Normalize by the original size so the target is independent of the resize.
        x_norm = x / orig_w
        y_norm = y / orig_h

        image = self.base_transform(image).to(self.device)  # [3, H, W], [0, 1]
        coord = torch.tensor([x_norm, y_norm], dtype=torch.float32).to(self.device)  # [2]

        return image, coord