import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from keras import backend as K


class StyleTransfer:
    """Neural style transfer driven by InceptionV3 feature maps.

    The model built by :meth:`load_model` returns the activations of
    ``content_layers`` followed by those of ``style_layers``. Every method
    that splits model outputs relies on that ordering: the first
    ``NUM_CONTENT_LAYERS`` entries are content features, the rest are style
    features.
    """

    # Auto-generated Keras layer names inside InceptionV3. They are only
    # valid when the model is built in a fresh Keras session (see load_model).
    content_layers = ["conv2d_88", "conv2d_91", "conv2d_92", "conv2d_85", "conv2d_93"]
    style_layers = ["conv2d", "conv2d_1", "conv2d_2", "conv2d_3", "conv2d_4"]
    content_and_style_layers = content_layers + style_layers

    NUM_CONTENT_LAYERS = len(content_layers)
    NUM_STYLE_LAYERS = len(style_layers)

    def __init__(self, content_image_path, style_image_path) -> None:
        """Store the image paths; the model is built lazily on first use.

        Parameters
        ----------
        content_image_path : str
            path to the content image
        style_image_path : str
            path to the style image

        Returns
        -------
        None
        """
        self.content_image_path = content_image_path
        self.style_image_path = style_image_path
        self.model = None

    def tensor_to_image(self, tensor):
        """Convert an image tensor to a PIL image.

        Parameters
        ----------
        tensor : tensor
            image tensor of rank 3, or of higher rank with a leading
            dimension of size 1 (which is dropped)

        Returns
        -------
        PIL.Image
            the converted image

        Raises
        ------
        ValueError
            if the tensor has more than three dimensions and its leading
            dimension is not 1
        """
        if int(tf.rank(tensor)) > 3:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``.
            if int(tf.shape(tensor)[0]) != 1:
                raise ValueError("There are more than one image")
            tensor = tensor[0]
        return tf.keras.preprocessing.image.array_to_img(tensor)

    def load_image(self, path_to_img):
        """Load a JPEG file and resize it so its longest side is 512 pixels.

        Parameters
        ----------
        path_to_img : str
            path to the JPEG file

        Returns
        -------
        tensor
            uint8 tensor with a leading batch dimension of size 1
        """
        max_dim = 512
        image = tf.io.read_file(path_to_img)
        # Force three channels so grayscale JPEGs are expanded to RGB.
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.convert_image_dtype(image, tf.float32)

        # Height and width only; the channel dimension is not resized.
        shape = tf.cast(tf.shape(image)[:-1], tf.float32)
        scale = max_dim / tf.reduce_max(shape)
        new_shape = tf.cast(shape * scale, tf.int32)

        image = tf.image.resize(image, new_shape)
        image = image[tf.newaxis, :]
        return tf.image.convert_image_dtype(image, tf.uint8)

    def imshow(self, image, title=""):
        """Draw an image on the current matplotlib axes.

        A leading batch dimension, if present, is squeezed away first.
        """
        if len(image.shape) > 3:
            image = tf.squeeze(image, axis=0)
        plt.imshow(image)
        plt.title(title)

    def show_images_with_style(self, images, titles=None):
        """Display a row of images with corresponding titles.

        Parameters
        ----------
        images : list
            images to display
        titles : list of str, optional
            one title per image. When omitted, the images are shown without
            titles. When given with a length different from ``images``,
            nothing is displayed.
        """
        if titles is None:
            titles = [""] * len(images)
        if len(images) != len(titles):
            return

        plt.figure(figsize=(20, 12))
        for idx, (image, title) in enumerate(zip(images, titles)):
            plt.subplot(1, len(images), idx + 1)
            plt.xticks([])
            plt.yticks([])
            self.imshow(image, title)
        plt.show()

    def preprocess_image(self, image):
        """Cast an image to float32 and map the range 0..255 to -1..1."""
        image = tf.cast(image, dtype=tf.float32)
        image = (image / 127.5) - 1.0
        return image

    def display_images(self):
        """Load and display the content and style images side by side."""
        content_image = self.load_image(self.content_image_path)
        style_image = self.load_image(self.style_image_path)
        self.show_images_with_style(
            [content_image, style_image],
            titles=["Content image", "Style image"],
        )

    def gram_matrix(self, input_tensor):
        """Calculate the gram matrix and divide by the number of locations.

        Parameters
        ----------
        input_tensor : tensor
            rank-4 feature map; the einsum contracts axes 1 and 2 and keeps
            axis 0 and the last axis

        Returns
        -------
        tensor
            gram matrix of the input tensor, scaled by the product of the
            sizes of axes 1 and 2
        """
        gram = tf.linalg.einsum("bijc,bijd->bcd", input_tensor, input_tensor)
        input_shape = tf.shape(input_tensor)
        num_locations = tf.cast(input_shape[1] * input_shape[2], tf.float32)
        return gram / num_locations

    def get_features(self, image, type=None):
        """Return the model features of an image.

        Parameters
        ----------
        image : tensor
            image to extract features from
        type : str
            ``"style"`` returns the gram matrices of the style layers,
            ``"content"`` returns the raw content-layer outputs. Any other
            value (including ``None``) returns content features followed by
            style gram matrices.

        Returns
        -------
        list
            the requested features
        """
        if self.model is None:
            self.load_model()

        outputs = self.model(self.preprocess_image(image))
        content_outputs = outputs[: self.NUM_CONTENT_LAYERS]
        if type == "content":
            return content_outputs

        gram_style_features = [
            self.gram_matrix(style_output)
            for style_output in outputs[self.NUM_CONTENT_LAYERS :]
        ]
        if type == "style":
            return gram_style_features
        return content_outputs + gram_style_features

    def _loss(self, features, targets, type="style"):
        """Return the mean squared error between features and targets.

        Parameters
        ----------
        features : tensor
            features of the image being optimized
        targets : tensor
            target features, same shape as ``features``
        type : str
            ``"content"`` halves the loss; any other value leaves it as is

        Returns
        -------
        tensor
            scalar loss
        """
        loss = tf.reduce_mean(tf.square(features - targets))
        if type == "content":
            loss = loss * 0.5
        return loss

    def get_loss(self, features, target, alpha, beta):
        """Return the weighted sum of content and style losses.

        Parameters
        ----------
        features : list
            content features followed by style features of the current image
        target : list
            content targets followed by style targets, same layout
        alpha : float
            weight of the content loss
        beta : float
            weight of the style loss

        Returns
        -------
        loss : tensor
            ``alpha * mean(content losses) + beta * mean(style losses)``
        """
        content_features = features[: self.NUM_CONTENT_LAYERS]
        style_features = features[self.NUM_CONTENT_LAYERS :]
        content_targets = target[: self.NUM_CONTENT_LAYERS]
        style_targets = target[self.NUM_CONTENT_LAYERS :]

        style_loss = sum(
            self._loss(feature, goal, type="style")
            for feature, goal in zip(style_features, style_targets)
        )
        content_loss = sum(
            self._loss(feature, goal, type="content")
            for feature, goal in zip(content_features, content_targets)
        )

        style_loss = beta * style_loss / self.NUM_STYLE_LAYERS
        content_loss = alpha * content_loss / self.NUM_CONTENT_LAYERS
        return content_loss + style_loss

    def calculate_gradients(self, image, target, alpha, beta):
        """Calculate the gradients of the loss with respect to the image.

        Returns
        -------
        tuple
            ``(gradients, loss)``
        """
        with tf.GradientTape() as tape:
            features = self.get_features(image)
            loss = self.get_loss(features, target, alpha, beta)
        gradients = tape.gradient(loss, image)
        return gradients, loss

    def update_image(self, image, target, alpha, beta, optimizer):
        """Apply one optimizer step to the image and clip it to [0, 255].

        ``image`` is updated in place via ``assign``.

        Returns
        -------
        tensor
            the loss computed before the update
        """
        gradients, loss = self.calculate_gradients(image, target, alpha, beta)
        optimizer.apply_gradients([(gradients, image)])
        image.assign(tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=255.0))
        return loss

    def load_model(self):
        """Create an InceptionV3 model that returns intermediate outputs.

        The model outputs the activations of ``content_and_style_layers`` in
        order. It is stored on ``self.model`` and also returned.
        """
        # Reset the Keras session so auto-generated layer names start from
        # scratch and match the hard-coded names in content/style_layers.
        K.clear_session()
        inception = tf.keras.applications.InceptionV3(
            include_top=False, weights="imagenet"
        )
        inception.trainable = False

        output_layers = self.content_and_style_layers
        model = tf.keras.models.Model(
            [inception.input],
            [inception.get_layer(name).output for name in output_layers],
        )
        self.model = model
        return model

    def stylize_image(
        self,
        alpha=1,
        beta=0.1,
        epochs=10,
        steps_per_epoch=10,
        show_images=True,
        image_frequency=2,
        notebook=False,
        lr=None,
    ):
        """Stylize the content image using the style image.

        Optimization starts from the content image and runs
        ``epochs * steps_per_epoch`` Adam steps with an exponentially
        decaying learning rate.

        Parameters
        ----------
        alpha : float, optional
            Content weight, by default 1
        beta : float, optional
            Style weight, by default 0.1
        epochs : int, optional
            Number of epochs, by default 10
        steps_per_epoch : int, optional
            Number of steps per epoch, by default 10
        show_images : bool, optional
            When true, each snapshot is also written to ``.img.jpg`` in the
            current working directory (overwriting the previous one),
            by default True
        image_frequency : int, optional
            A snapshot is taken whenever the step index within an epoch is a
            multiple of this value; 0 disables intermediate snapshots,
            by default 2
        notebook : bool, optional
            Accepted for backward compatibility; currently unused,
            by default False
        lr : float, optional
            Initial learning rate; 40.0 is used when None, by default None

        Returns
        -------
        [PIL.Image]
            Snapshots in chronological order; the last element is always the
            final stylized image
        """
        if self.model is None:
            self.load_model()

        style_image = self.load_image(self.style_image_path)
        content_image = self.load_image(self.content_image_path)

        style_target = self.get_features(style_image, "style")
        content_target = self.get_features(content_image, "content")
        # Same layout as get_features(image): content first, then style.
        target = content_target + style_target

        image = tf.Variable(tf.cast(content_image, dtype=tf.float32))

        if lr is None:
            lr = 40.0
        optimizer = tf.optimizers.Adam(
            tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=lr, decay_steps=100, decay_rate=0.80
            )
        )

        images = []
        final_captured = False
        for _epoch in range(epochs):
            for step in range(steps_per_epoch):
                self.update_image(image, target, alpha, beta, optimizer)
                final_captured = False
                # Convert to PIL only when a snapshot is actually due.
                if image_frequency and step % image_frequency == 0:
                    snapshot = self.tensor_to_image(image)
                    images.append(snapshot)
                    final_captured = True
                    if show_images:
                        snapshot.save(".img.jpg")

        # Make sure the caller always receives the final result.
        if not final_captured:
            images.append(self.tensor_to_image(image))
        return images