|
|
|
|
|
|
|
|
import torch |
|
|
from torchvision import transforms |
|
|
from torch.utils.data import Dataset |
|
|
|
|
|
import os |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
class CrosswalkDataset(Dataset):
    """Binary-classification dataset pairing images with one-hot labels.

    Expects ``src_dir`` to contain image files (``.png``/``.jpg``/``.jpeg``)
    and label ``.txt`` files, each holding a single integer:
    ``1`` -> label ``[1, 0]``, anything else -> ``[0, 1]``.

    NOTE(review): images and labels are paired purely by sorted directory
    order, which assumes a strict one-to-one naming correspondence between
    image and ``.txt`` files — verify the directory layout guarantees this.
    """

    def __init__(self, src_dir, transform=None):
        """
        Args:
            src_dir: Directory holding both the image and label files.
            transform: Optional transform applied to each loaded image;
                defaults to ``transforms.ToTensor()``.
        """
        self.src_dir = src_dir
        # Resolve the default here instead of lazily inside __getitem__:
        # the original mutated self.transform during iteration, a side
        # effect that is unsafe with multi-worker DataLoaders.
        self.transform = transform if transform is not None else transforms.ToTensor()

        dir_files = sorted(os.listdir(src_dir))
        self.image_paths = [f for f in dir_files if f.endswith((".png", ".jpg", ".jpeg"))]
        self.label_paths = [f for f in dir_files if f.endswith(".txt")]

    def __len__(self):
        """Number of samples (one per image file found)."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return ``(transformed_image, one_hot_label_tensor)`` for *index*.

        A missing or malformed label file yields the fallback label
        ``[0, 0]`` (preserves the original best-effort behavior).
        """
        image_path = os.path.join(self.src_dir, self.image_paths[index])
        label_path = os.path.join(self.src_dir, self.label_paths[index])

        label = [0, 0]
        try:
            # `with` closes the handle (the original leaked it); the
            # narrowed except no longer hides unrelated bugs. The scalar
            # comparison replaces the original's needless np.array wrapper.
            with open(label_path) as label_file:
                label = [1, 0] if int(label_file.read().strip()) == 1 else [0, 1]
        except (OSError, ValueError):
            pass

        image = Image.open(image_path)
        return (self.transform(image), torch.FloatTensor(label))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Preprocessing pipeline for VGG-style models: fixed 224x224 input, tensor
# conversion, then per-channel normalization with this project's chosen
# mean 0.5 / std 0.3 constants (not the ImageNet statistics).
_vgg_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.3, 0.3, 0.3]),
]
vgg_transform = transforms.Compose(_vgg_steps)
|
|
|
|
|
# Preprocessing pipeline for ResNet-style models: 256x256 input with the
# same 0.5/0.3 per-channel normalization used by the VGG pipeline.
_res_mean = [0.5, 0.5, 0.5]
_res_std = [0.3, 0.3, 0.3]
res_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_res_mean, std=_res_std),
])
|
|
|
|
|
# Preprocessing pipeline for MobileNetV3: 224x224 input normalized with the
# standard ImageNet channel statistics, matching pretrained-weight training.
mob3_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])