File size: 781 Bytes
c858478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import json
import os
import cv2
import torch
from torch.utils.data import Dataset

class RelationshipDataset(Dataset):
    """Image/label dataset backed by a JSON label file.

    The label file is parsed once at construction time. Every entry is
    expected to provide an ``"image"`` key (a file name relative to
    ``image_dir``) and a ``"label"`` key (a value convertible to a
    ``torch.long`` tensor).
    """

    def __init__(self, image_dir, label_path):
        """Store the image directory and load all label entries.

        Args:
            image_dir: Directory that the ``"image"`` file names are
                resolved against.
            label_path: Path of the JSON file holding the label entries.
        """
        self.image_dir = image_dir

        # Be explicit about the encoding so parsing does not depend on the
        # platform's locale default.
        with open(label_path, encoding="utf-8") as f:
            self.data = json.load(f)

    def __len__(self) -> int:
        """Return the number of entries loaded from the label file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Index into the loaded label entries.

        Returns:
            An ``(image, label)`` pair. ``image`` is a ``float32`` tensor in
            channel-first RGB order; ``label`` is a ``torch.long`` tensor
            built from the entry's ``"label"`` value.

        Raises:
            OSError: If the image file is missing or cannot be decoded.
        """
        item = self.data[idx]

        img_path = os.path.join(self.image_dir, item["image"])
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of
            # raising; fail here with the path rather than letting cvtColor
            # die with an opaque assertion error.
            raise OSError(
                f"Could not read image (missing or unreadable): {img_path}"
            )

        # OpenCV decodes to BGR; convert to the RGB channel order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Scale by 1/255, then shift/scale with mean 0.5 and std 0.5
        # (maps 8-bit pixel values into [-1, 1]).
        image = image / 255.0
        image = (image - 0.5) / 0.5

        # Move the channel axis first: (H, W, C) -> (C, H, W).
        image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)

        label = torch.tensor(item["label"], dtype=torch.long)

        return image, label