Spaces:
Runtime error
Runtime error
| import os | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| from torchvision import datasets | |
| from torch.utils.data import Dataset, DataLoader, Sampler | |
| import torchvision.transforms as T | |
# Per-channel ImageNet mean and standard deviation; handed to T.Normalize below
# and used again in post_process_img to undo that normalization.
Imagenet_Mean = np.asarray((0.485, 0.456, 0.406))
Imagenet_Std = np.asarray((0.229, 0.224, 0.225))
def process_img(image_path: str, batch_size: int = 1, target_shape=None) -> torch.Tensor:
    """
    Function to Pre-process image.

    Loads the image as 3-channel RGB, optionally resizes it, scales the pixel values
    to [0, 1], normalizes them with the ImageNet mean/std and repeats the result
    along a new leading batch dimension.

    Args:
        image_path   -> Path of image.
        batch_size   -> To create batch of images (number of copies of the image).
        target_shape -> Target size of output image.
                        None or -1 : keep the original size.
                        int        : new width; the height follows the aspect ratio.
                        pair       : passed to PIL's resize as (target_shape[0], target_shape[1]),
                                     i.e. (width, height).
    Returns:
        Float tensor of shape (batch_size, 3, height, width).
    Raises:
        FileNotFoundError -> If image_path does not exist.
    """
    if not os.path.exists(image_path):
        # FileNotFoundError subclasses Exception, so callers catching Exception still work.
        raise FileNotFoundError(f"Path doesn't exist: {image_path}")

    # The context manager releases the file handle. convert("RGB") returns an
    # independent 3-channel copy, so grayscale / RGBA / palette files match the
    # 3-channel mean and std used by T.Normalize below.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    # -1 means "no resize"; previously it fell through to the pair branch and
    # failed while indexing an int.
    keep_size = target_shape is None or (isinstance(target_shape, int) and target_shape == -1)
    if not keep_size:
        if isinstance(target_shape, int):
            width, height = image.size
            new_width = target_shape
            new_height = int(height * (new_width / width))
            image = image.resize((new_width, new_height))
        else:
            image = image.resize((target_shape[0], target_shape[1]))

    # ToTensor does not rescale float arrays, so scale to [0, 1] by hand first.
    image = np.array(image).astype(np.float32)
    image /= 255.0

    transforms = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=Imagenet_Mean, std=Imagenet_Std),
    ])
    image = transforms(image)

    # (3, H, W) -> (batch_size, 3, H, W)
    image = image.repeat(batch_size, 1, 1, 1)
    return image
| """________________________________________________________________________________________________________________________________________________________________""" | |
| """Class to create a subset of data to be used for training.""" | |
class SequentialSampler(Sampler):
    """Sampler that yields the indices 0 .. subset_size - 1 of a dataset, in order."""

    def __init__(self, datasource, subset_size: int = None):
        """
        Args:
            datasource  -> Torch Dataset (or Torchvision ImageFolder) to draw indices for.
            subset_size -> Number of samples to be used; None means the whole dataset.
        """
        is_supported = isinstance(datasource, Dataset) or isinstance(datasource, datasets.ImageFolder)
        assert is_supported, "Datasource must be either Torch Dataset or Torchvision ImageFolder."
        self.data_source = datasource

        # Fall back to the full dataset length when no subset size was requested.
        chosen_size = len(self.data_source) if subset_size is None else subset_size
        assert 0 < chosen_size <= len(self.data_source), f"Subset size must be between (0, {len(self.data_source)})."
        self.subset_size = chosen_size

    def __iter__(self):
        # Indices are produced in ascending order, without shuffling.
        index_range = range(self.subset_size)
        return iter(index_range)

    def __len__(self):
        return self.subset_size
| """________________________________________________________________________________________________________________________________________________________________""" | |
def Train_DataLoader(training_config: dict) -> DataLoader:
    """
    Build the DataLoader used for training.

    Reads 'image_size', 'dataset_path', 'subset_size' and 'batch_size' from the
    config, and writes the effective sampler length back into
    training_config['subset_size'].

    Args:
        training_config -> Dictionary of params config.
    Returns:
        DataLoader over an ImageFolder, sampled sequentially, with the last
        incomplete batch dropped.
    """
    side = training_config['image_size']
    pipeline = T.Compose([
        T.Resize(side),
        T.CenterCrop(side),
        T.ToTensor(),
        T.Normalize(mean=Imagenet_Mean, std=Imagenet_Std),
    ])

    image_folder = datasets.ImageFolder(training_config['dataset_path'], transform=pipeline)
    subset_sampler = SequentialSampler(image_folder, subset_size=training_config['subset_size'])

    # Store the resolved size so the config no longer holds None.
    training_config['subset_size'] = len(subset_sampler)

    return DataLoader(
        image_folder,
        batch_size=training_config['batch_size'],
        sampler=subset_sampler,
        drop_last=True,
    )
| """________________________________________________________________________________________________________________________________________________________________""" | |
def gram_matrix(feature_map: torch.Tensor, should_mormalize: bool = True):
    """
    Function to create Gram matrix.

    Gram matrix - It captures the texture information from the feature maps extracted from a
                  convolutional neural network. It is the batched dot product between every pair
                  of flattened channels, i.e. an uncentered covariance between feature maps.

    Args:
        feature_map      -> 4-D tensor (batch, channels, height, width) of features extracted
                            from pre-trained vision models.
        should_mormalize -> Should normalize the matrix (divide by channels * height * width).
                            The spelling is kept as-is because callers may pass it by keyword.
    Returns:
        Tensor of shape (batch, channels, channels).
    """
    (b, ch, h, w) = feature_map.size()

    # reshape instead of view: view raises on non-contiguous inputs (for example the
    # result of permute), while reshape copies only when it has to.
    features = feature_map.reshape(b, ch, h * w)
    features_t = features.transpose(1, 2)
    gram_mat = features.bmm(features_t)

    if should_mormalize:
        gram_mat = gram_mat / (ch * h * w)
    return gram_mat
| """________________________________________________________________________________________________________________________________________________________________""" | |
def total_variation_loss(image_batch) -> torch.Tensor:
    """
    Function for Total variation loss.

    Used for spatial continuity between the pixels of the generated image, thereby denoising it
    and giving it visual coherence. It is the sum of absolute differences between vertically and
    horizontally adjacent pixels, divided by the batch size.

    Args:
        image_batch -> Batch of image tensors, indexed as (batch, channels, height, width).
    Returns:
        Scalar tensor holding the per-image average total variation.
    """
    batch_size = image_batch.shape[0]

    # Differences between each row and the row below it.
    tv_height = torch.sum(torch.abs(image_batch[:, :, :-1, :] - image_batch[:, :, 1:, :]))
    # Differences between each column and the column to its right. abs() must wrap
    # the whole difference; taking abs of only the left slice gives a wrong,
    # possibly negative, term.
    tv_width = torch.sum(torch.abs(image_batch[:, :, :, :-1] - image_batch[:, :, :, 1:]))

    return (tv_height + tv_width) / batch_size
| """________________________________________________________________________________________________________________________________________________________________""" | |
def post_process_img(image: np.ndarray) -> np.ndarray:
    """
    Turn a normalized image array back into a displayable uint8 image.

    Undoes the ImageNet normalization, clips to [0, 1], scales to 0-255 and
    moves axis 0 to position 2.

    Args:
        image -> Numpy array of generated image.
                 NOTE(review): presumably channels-first (3, H, W) - confirm against callers.
    Returns:
        uint8 Numpy array with the original axis 0 moved to position 2.
    """
    assert isinstance(image, np.ndarray), f"Expected Numpy Array, but got {type(image)}"

    # Reshape the per-channel statistics so they broadcast over the two trailing axes.
    channel_std = Imagenet_Std.reshape(-1, 1, 1)
    channel_mean = Imagenet_Mean.reshape(-1, 1, 1)

    denormalized = (image * channel_std) + channel_mean
    as_uint8 = (np.clip(denormalized, 0., 1.) * 255).astype(np.uint8)
    return np.moveaxis(as_uint8, source=0, destination=2)
| """________________________________________________________________________________________________________________________________________________________________""" | |
def save_and_display(config: dict, image: np.ndarray, should_display: bool = False, index: int = None):
    """
    Function to save and display images.

    The image is post-processed with post_process_img and written as a .jpg into
    config['save_folder'], which is created if missing.

    Args:
        config         -> Dictionary of config. Reads 'save_folder', and 'content_img_name'
                          when index is None.
        image          -> Numpy array of image.
        should_display -> To display image.
        index          -> Index of generated image during training. When given, the file is
                          named 'generated-images-<index, zero-padded to 4 digits>.jpg';
                          otherwise 'Stylized-image-<content image stem>.jpg'.
    """
    assert isinstance(image, np.ndarray), f"Expected Numpy Array, but got {type(image)}"
    image = post_process_img(image)

    if index is not None:
        fake_fname = 'generated-images-{0:0=4d}.jpg'.format(index)
    else:
        # Strip only the final extension of the file's basename. Splitting on the
        # first "." truncated names such as "photo.v2.jpg" to "photo" and turned
        # "./photo.jpg" into an empty stem.
        content_stem = os.path.splitext(os.path.basename(config["content_img_name"]))[0]
        fake_fname = f'Stylized-image-{content_stem}.jpg'

    os.makedirs(config['save_folder'], exist_ok=True)
    plt.imsave(os.path.join(config['save_folder'], fake_fname), image)

    if should_display:
        plt.imshow(image)
        plt.show()
| """________________________________________________________________________________________________________________________________________________________________""" | |
def print_parameters(model: torch.nn.Module) -> None:
    """
    Print the total number of parameters of a model, in millions.

    Args:
        model -> Torch model
    """
    total_elements = 0
    for param in model.parameters():
        total_elements += param.numel()

    in_millions = total_elements / 1e6
    print("Model Parameters : {:.3f} M".format(in_millions))