import tensorflow as tf
import matplotlib.pyplot as plt


class NeuralStyleTransfer:
    def __init__(
        self,
        style_image,
        content_image,
        extractor,
        n_style_layers=5,
        n_content_layers=5,
    ):
        # load the feature-extraction backbone
        if extractor == "inception_v3":
            self.feature_extractor = tf.keras.applications.InceptionV3(
                include_top=False, weights="imagenet"
            )
        elif extractor == "vgg19":
            self.feature_extractor = tf.keras.applications.VGG19(
                include_top=False, weights="imagenet"
            )
        elif extractor == "resnet50":
            self.feature_extractor = tf.keras.applications.ResNet50(
                include_top=False, weights="imagenet"
            )
        elif extractor == "mobilenet_v2":
            self.feature_extractor = tf.keras.applications.MobileNetV2(
                include_top=False, weights="imagenet"
            )
        elif isinstance(extractor, tf.keras.Model):
            self.feature_extractor = extractor
        else:
            raise ValueError(f"Unknown feature extractor: {extractor!r}")
        # freeze the backbone; only the generated image is optimized
        self.feature_extractor.trainable = False
        # how many conv layers to use for style and content features
        self.n_style_layers = n_style_layers
        self.n_content_layers = n_content_layers
        self.style_image = self._load_img(style_image)
        self.content_image = self._load_img(content_image)
    def tensor_to_image(self, tensor):
        """Converts a tensor to a PIL image, dropping the batch dimension if present."""
        if len(tensor.shape) > 3:
            assert tensor.shape[0] == 1
            tensor = tensor[0]
        return tf.keras.preprocessing.image.array_to_img(tensor)
    def _load_img(self, image):
        """Loads an image file and resizes it so its longest side is max_dim pixels."""
        max_dim = 512
        image = tf.io.read_file(image)
        # expand_animations=False guarantees a 3-D tensor, which tf.image.resize needs
        image = tf.image.decode_image(image, channels=3, expand_animations=False)
        image = tf.image.convert_image_dtype(image, tf.float32)
        shape = tf.cast(tf.shape(image)[:-1], tf.float32)
        long_dim = tf.reduce_max(shape)
        scale = max_dim / long_dim
        new_shape = tf.cast(shape * scale, tf.int32)
        image = tf.image.resize(image, new_shape)
        image = image[tf.newaxis, :]
        # back to uint8 in [0, 255]; _preprocess_image rescales before inference
        image = tf.image.convert_image_dtype(image, tf.uint8)
        return image
    def imshow(self, image, title=None):
        """Displays an image with a corresponding title."""
        if len(image.shape) > 3:
            image = tf.squeeze(image, axis=0)
        plt.imshow(image)
        if title:
            plt.title(title)
    def show_images_with_objects(self, images, titles=[]):
        """Displays a row of images with corresponding titles."""
        if len(images) != len(titles):
            return
        plt.figure(figsize=(20, 12))
        for idx, (image, title) in enumerate(zip(images, titles)):
            plt.subplot(1, len(images), idx + 1)
            plt.xticks([])
            plt.yticks([])
            self.imshow(image, title)
    def _preprocess_image(self, image):
        """Maps pixel values from [0, 255] to [-1, 1]."""
        image = tf.cast(image, dtype=tf.float32)
        image = (image / 127.5) - 1.0
        return image
    def get_output_layers(self):
        """Selects style layers from the shallow end and content layers from the deep end."""
        # all conv layers, in network order
        all_layers = [
            layer.name
            for layer in self.feature_extractor.layers
            if "conv" in layer.name
        ]
        # style layers: the first n_style_layers conv layers
        style_layers = all_layers[: self.n_style_layers]
        # content layers: n_content_layers conv layers, walking backwards
        # from the second-to-last conv layer
        content_layers = all_layers[-2 : -self.n_content_layers - 2 : -1]
        # content first, then style -- get_features() relies on this ordering
        return content_layers + style_layers
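    # For illustration (added; not in the original file): with extractor="vgg19"
    # and the defaults n_style_layers=5, n_content_layers=5, the slices above
    # would select, assuming the stock tf.keras VGG19 layer names:
    #   style:   block1_conv1, block1_conv2, block2_conv1, block2_conv2, block3_conv1
    #   content: block5_conv3, block5_conv2, block5_conv1, block4_conv4, block4_conv3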
    def build(self, layers_name):
        """Rebuilds the extractor as a multi-output model over the given layer names."""
        output_layers = [
            self.feature_extractor.get_layer(name).output for name in layers_name
        ]
        self.feature_extractor = tf.keras.Model(
            self.feature_extractor.input, output_layers
        )
    def _loss(self, target_img, features_img, loss_type):
        """
        Calculates a style or content loss term as the mean squared error
        between the two feature sets.
        target_img:
            the target image (style or content) features
        features_img:
            the generated image features (style or content)
        """
        loss = tf.reduce_mean(tf.square(features_img - target_img))
        if loss_type == "content":
            return 0.5 * loss
        return loss
    def _gram_matrix(self, input_tensor):
        """
        Calculates the Gram matrix and divides by the number of spatial locations.
        input_tensor:
            the output of a conv layer of the style image,
            shape = (batch_size, height, width, channels)
        """
        result = tf.linalg.einsum("bijc,bijd->bcd", input_tensor, input_tensor)
        input_shape = tf.shape(input_tensor)
        num_locations = tf.cast(input_shape[1] * input_shape[2], tf.float32)
        return result / num_locations
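    # Worked form of the einsum above (added note): for features F of shape
    # (b, H, W, C) it computes G[b, c, d] = sum_{i,j} F[b, i, j, c] * F[b, i, j, d],
    # i.e. pairwise channel correlations, then normalizes by H * W, giving a
    # (batch_size, channels, channels) tensor.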
    def get_features(self, image, feature_type):
        """Extracts style (Gram matrix) or content features for an image."""
        preprocessed_image = self._preprocess_image(image)
        outputs = self.feature_extractor(preprocessed_image)
        # build() orders the outputs as content layers first, then style layers
        if feature_type == "style":
            outputs = outputs[self.n_content_layers :]
            features = [self._gram_matrix(style_output) for style_output in outputs]
        elif feature_type == "content":
            features = outputs[: self.n_content_layers]
        else:
            raise ValueError(f"Unknown feature type: {feature_type!r}")
        return features
    def _style_content_loss(
        self,
        style_targets,
        style_outputs,
        content_targets,
        content_outputs,
        style_weight,
        content_weight,
    ):
        """
        Calculates the total loss of the style transfer.
        style_targets:
            the style features of the style image
        style_outputs:
            the style features of the generated image
        content_targets:
            the content features of the content image
        content_outputs:
            the content features of the generated image
        style_weight:
            the weight of the style loss
        content_weight:
            the weight of the content loss
        """
        # sum the per-layer losses, then apply the weights
        style_loss = style_weight * tf.add_n(
            [
                self._loss(style_target, style_output, loss_type="style")
                for style_target, style_output in zip(style_targets, style_outputs)
            ]
        )
        content_loss = content_weight * tf.add_n(
            [
                self._loss(content_target, content_output, loss_type="content")
                for content_target, content_output in zip(
                    content_targets, content_outputs
                )
            ]
        )
        total_loss = style_loss + content_loss
        return total_loss, style_loss, content_loss
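    # In symbols (added note), with G_l the Gram matrix and F_l the raw features
    # at layer l:
    #   total = style_weight * sum_l MSE(G_l(generated), G_l(style))
    #         + content_weight * sum_l 0.5 * MSE(F_l(generated), F_l(content))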
    def _grad_loss(
        self,
        generated_image,
        style_target,
        content_target,
        style_weight,
        content_weight,
        var_weight,
    ):
        """
        Calculates the gradients of the total loss with respect to the generated image.
        generated_image:
            the image being optimized (a tf.Variable)
        """
        with tf.GradientTape() as tape:
            style_features = self.get_features(generated_image, feature_type="style")
            content_features = self.get_features(generated_image, feature_type="content")
            loss, style_loss, content_loss = self._style_content_loss(
                style_target,
                style_features,
                content_target,
                content_features,
                style_weight,
                content_weight,
            )
            # total-variation term penalizes high-frequency noise in the result
            variational_loss = var_weight * tf.image.total_variation(generated_image)
            loss += variational_loss
        grads = tape.gradient(loss, generated_image)
        return grads, loss, [style_loss, content_loss, variational_loss]
    def _update_image_with_style(
        self,
        generated_image,
        style_target,
        content_target,
        style_weight,
        content_weight,
        optimizer,
        var_weight,
    ):
        """Performs one optimization step and clips the image back to [0, 255]."""
        grads, loss, loss_list = self._grad_loss(
            generated_image,
            style_target,
            content_target,
            style_weight,
            content_weight,
            var_weight,
        )
        optimizer.apply_gradients([(grads, generated_image)])
        # keep pixel values in the valid range expected by _preprocess_image
        generated_image.assign(
            tf.clip_by_value(generated_image, clip_value_min=0.0, clip_value_max=255.0)
        )
        return loss_list
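
# Example usage (added for illustration; not part of the original file). A minimal
# training loop under a few assumptions: the generated image starts as a float
# copy of the content image, targets are computed once up front, and the file
# paths, step count, and weights below are hypothetical placeholders.
if __name__ == "__main__":
    nst = NeuralStyleTransfer(
        style_image="style.jpg",        # hypothetical input paths
        content_image="content.jpg",
        extractor="vgg19",
    )
    # wire the extractor to the selected content + style layers
    nst.build(nst.get_output_layers())

    # fixed targets: Gram matrices of the style image, raw features of the content image
    style_target = nst.get_features(nst.style_image, feature_type="style")
    content_target = nst.get_features(nst.content_image, feature_type="content")

    generated_image = tf.Variable(tf.cast(nst.content_image, tf.float32))
    optimizer = tf.keras.optimizers.Adam(learning_rate=4.0)

    for step in range(101):
        style_loss, content_loss, tv_loss = nst._update_image_with_style(
            generated_image,
            style_target,
            content_target,
            style_weight=1e-2,
            content_weight=1e4,
            optimizer=optimizer,
            var_weight=30.0,
        )
        if step % 10 == 0:
            print(
                f"step {step}: style={float(style_loss):.2f} "
                f"content={float(content_loss):.2f} tv={float(tv_loss):.2f}"
            )

    nst.tensor_to_image(generated_image).save("stylized.png")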