ControlNet / data /mnist_dataset.py
YashNagraj75's picture
Add dataloader and add logging for training script
9774d79
import glob
import os
import cv2
import numpy as np
import torchvision
from PIL import Image
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
class MnistDataset(Dataset):
    r"""
    Simple dataset class for mnist images stored as individual image files.
    Implemented as a custom dataset (rather than using torchvision's MNIST)
    so that it can be pointed at any other image folder laid out the same
    way: ``<im_path>/<sub_directory>/<image>.<im_ext>``.
    """

    def __init__(self, split, im_path, im_ext="png", im_size=28, return_hints=False):
        r"""
        Init method for initializing the dataset properties
        :param split: train/test; only used in the summary message printed
        after the image files have been collected
        :param im_path: root folder of images; images are looked up one
        directory level below it
        :param im_ext: image extension. assumes all
        images would be this type.
        :param im_size: accepted for interface compatibility; this class
        does not read it (no resizing is performed here)
        :param return_hints: if True, ``__getitem__`` additionally returns
        a canny edge map of the image
        """
        self.split = split
        self.im_ext = im_ext
        self.return_hints = return_hints
        self.images = self.load_images(im_path)

    @staticmethod
    def _find_images(class_dir, im_ext):
        r"""
        Lists the image files sitting directly inside one sub directory
        :param class_dir: directory to search (not searched recursively)
        :param im_ext: image extension to match, without the leading dot
        :return: sorted list of matching file paths
        """
        # glob returns matches in arbitrary order; sort so that a given
        # index always maps to the same file.
        return sorted(glob.glob(os.path.join(class_dir, "*.{}".format(im_ext))))

    def load_images(self, im_path):
        r"""
        Gets the paths of all images under the path specified.
        Only paths are collected here; pixels are read lazily
        in ``__getitem__``.
        :param im_path: root folder whose sub directories hold the images
        :return: list of image file paths, ordered by sub directory name
        and then by file name
        """
        assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
        ims = []
        # os.listdir order is arbitrary as well, so sort the sub directories
        # to keep the dataset ordering reproducible across runs and machines.
        for d_name in tqdm(sorted(os.listdir(im_path))):
            ims.extend(self._find_images(os.path.join(im_path, d_name), self.im_ext))
        print("Found {} images for split {}".format(len(ims), self.split))
        return ims

    def __len__(self):
        r"""
        :return: number of image files found at construction time
        """
        return len(self.images)

    def __getitem__(self, index):
        r"""
        Loads one image and, optionally, its canny edge hint
        :param index: position of the image in ``self.images``
        :return: the image tensor rescaled to the -1 to 1 range, or a tuple
        (image tensor, canny hint tensor) when ``return_hints`` is set
        """
        to_tensor = torchvision.transforms.ToTensor()
        # Decode the file once and release the handle as soon as the pixel
        # data has been copied out.
        with Image.open(self.images[index]) as im:
            im_tensor = to_tensor(im)
            hint_source = np.array(im) if self.return_hints else None

        # Convert input to -1 to 1 range.
        im_tensor = (2 * im_tensor) - 1

        if not self.return_hints:
            return im_tensor

        # Canny produces a single channel edge map; replicate it across
        # three channels before converting it to a tensor.
        canny_image = cv2.Canny(hint_source, 100, 200)
        canny_image = canny_image[:, :, None]
        canny_image = np.concatenate(
            [canny_image, canny_image, canny_image], axis=2
        )
        canny_image_tensor = to_tensor(canny_image)
        return im_tensor, canny_image_tensor