VJyzCELERY's picture
Commit to hf space
c3d45c0
raw
history blame
1.18 kB
from torch.utils.data import Dataset
import torch
import os
import numpy as np
import cv2
def collate_fn(batch):
    """Split a batch of ``(image, label)`` pairs into images and a label tensor.

    The images are returned as a plain Python list (they are not stacked),
    while the labels are packed into a single ``torch.Tensor``.

    Args:
        batch: sequence of two-element ``(image, label)`` items.

    Returns:
        A tuple ``(images, labels)`` where ``images`` is a list holding the
        first element of every item and ``labels`` is a tensor built from the
        second elements, in the same order.
    """
    images = []
    targets = []
    for image, target in batch:
        images.append(image)
        targets.append(target)
    return images, torch.tensor(targets)
class ImageDataset(Dataset):
    """Folder-per-class image dataset whose samples are loaded with OpenCV.

    ``root_path`` must contain one sub-directory per class. Every entry found
    inside a class directory is recorded as a sample of that class.

    Attributes:
        img_size: size passed to ``cv2.resize``; OpenCV reads it as
            ``(width, height)``.
        classes: class directory names in sorted order. The position of a
            name in this list is the integer label of that class.
        data: one dict per sample with the keys ``"image_path"``,
            ``"label"`` (class name) and ``"id"`` (integer label).
    """

    def __init__(self, root_path: str, img_size=(256, 256)):
        """Index every file below ``root_path``.

        Args:
            root_path: directory holding one sub-directory per class.
            img_size: target size given to ``cv2.resize`` for every image.
        """
        super().__init__()
        # os.listdir gives no ordering guarantee, so sort to keep the
        # class -> index mapping stable across machines and runs. Stray
        # files directly under root_path (e.g. .DS_Store) are not classes
        # and would make the per-class listdir below fail, so skip them.
        classes = sorted(
            entry
            for entry in os.listdir(root_path)
            if os.path.isdir(os.path.join(root_path, entry))
        )
        self.img_size = img_size
        self.classes = classes

        data = []
        for idx, class_name in enumerate(classes):
            class_path = os.path.join(root_path, class_name)
            # Sorted as well so that sample order is reproducible.
            for file_name in sorted(os.listdir(class_path)):
                data.append(
                    {
                        "image_path": os.path.join(class_path, file_name),
                        "label": class_name,
                        "id": idx,
                    }
                )
        self.data = data

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: index into ``self.data``.

        Returns:
            A tuple ``(img, label)``. ``img`` is the image resized to
            ``self.img_size``, converted from BGR to RGB and scaled to
            ``float32`` values in ``[0, 1]``; ``label`` is the integer
            class index.

        Raises:
            OSError: if OpenCV cannot read the file as an image.
        """
        curr = self.data[idx]
        label = curr["id"]
        img_path = curr["image_path"]

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cv2.resize.
            raise OSError(f"Could not read image file: {img_path}")

        img = cv2.resize(img, self.img_size)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) / 255.0
        return img, label