# Style-Transfer / style_transfer.py
# (HuggingFace Spaces upload by hari31416, commit 44f9699 "Added files")
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from keras import backend as K
class StyleTransfer:
    """Neural style transfer built on InceptionV3 feature extraction.

    Content and style features are taken from intermediate convolutional
    layers of a pretrained (ImageNet) InceptionV3 network. A float copy of
    the content image is optimized with Adam to minimize a weighted sum of
    a content loss (MSE on content-layer activations) and a style loss
    (MSE on gram matrices of style-layer activations).
    """

    # Inception layer names used as feature taps.
    # NOTE(review): these names assume the exact auto-generated layer naming
    # produced by tf.keras.applications.InceptionV3 in this TF version — confirm.
    content_layers = ["conv2d_88", "conv2d_91", "conv2d_92", "conv2d_85", "conv2d_93"]
    style_layers = ["conv2d", "conv2d_1", "conv2d_2", "conv2d_3", "conv2d_4"]
    # Model output order: content layers first, then style layers.
    content_and_style_layers = content_layers + style_layers
    NUM_CONTENT_LAYERS = len(content_layers)
    NUM_STYLE_LAYERS = len(style_layers)

    def __init__(self, content_image_path, style_image_path) -> None:
        """Initializes the class.

        Parameters
        ----------
        content_image_path : str
            Path to the content image.
        style_image_path : str
            Path to the style image.

        Returns
        -------
        None
        """
        self.content_image_path = content_image_path
        self.style_image_path = style_image_path
        self.model = None  # built lazily by `load_model`

    def tensor_to_image(self, tensor):
        """Converts a (possibly batched) tensor to a PIL image.

        If the tensor has a leading batch dimension it must contain exactly
        one image; that dimension is stripped before conversion.
        """
        tensor_shape = tf.shape(tensor)
        # Rank of the tensor = length of its shape vector (scalar, unlike the
        # original 1-element tensor comparison, but equivalent in eager mode).
        rank = tf.shape(tensor_shape)[0]
        if rank > 3:
            assert tensor_shape[0] == 1, "There are more than one image"
            tensor = tensor[0]
        return tf.keras.preprocessing.image.array_to_img(tensor)

    def load_image(self, path_to_img):
        """Loads an image as a tensor and scales its longest side to 512 px.

        Parameters
        ----------
        path_to_img : str
            Path to a JPEG image on disk.

        Returns
        -------
        tensor
            Batched (1, H, W, C) uint8 image tensor.
        """
        max_dim = 512
        image = tf.io.read_file(path_to_img)
        image = tf.image.decode_jpeg(image)
        image = tf.image.convert_image_dtype(image, tf.float32)
        # Resize so the longest side becomes `max_dim`, preserving aspect ratio.
        # (A dead duplicate of the next line was removed here.)
        shape = tf.cast(tf.shape(image)[:-1], tf.float32)
        long_dim = tf.reduce_max(shape)  # graph-safe, unlike Python max(shape)
        scale = max_dim / long_dim
        new_shape = tf.cast(shape * scale, tf.int32)
        image = tf.image.resize(image, new_shape)
        image = image[tf.newaxis, :]  # add batch dimension
        # convert_image_dtype rescales the float [0, 1] range to uint8 [0, 255].
        image = tf.image.convert_image_dtype(image, tf.uint8)
        return image

    def imshow(self, image, title=""):
        """Displays an image, squeezing a leading batch dimension if present."""
        if len(image.shape) > 3:
            image = tf.squeeze(image, axis=0)
        plt.imshow(image)
        plt.title(title)

    def show_images_with_style(self, images, titles=None):
        """Displays a row of images with corresponding titles.

        Silently does nothing when the number of titles does not match the
        number of images (preserves the original best-effort behavior).
        """
        # `None` default avoids the mutable-default-argument pitfall of `titles=[]`.
        titles = [] if titles is None else titles
        if len(images) != len(titles):
            return
        plt.figure(figsize=(20, 12))
        for idx, (image, title) in enumerate(zip(images, titles)):
            plt.subplot(1, len(images), idx + 1)
            plt.xticks([])
            plt.yticks([])
            self.imshow(image, title)
        plt.show()

    def preprocess_image(self, image):
        """Preprocesses an image for the Inception model: maps [0, 255] to [-1, 1]."""
        image = tf.cast(image, dtype=tf.float32)
        image = (image / 127.5) - 1.0
        return image

    def display_images(self):
        """Displays the content and style images side by side."""
        content_image = self.load_image(self.content_image_path)
        style_image = self.load_image(self.style_image_path)
        self.show_images_with_style(
            [content_image, style_image],
            titles=["Content image", "Style image"],
        )

    def gram_matrix(self, input_tensor):
        """Calculates the gram matrix and divides by the number of locations.

        Parameters
        ----------
        input_tensor : tensor
            Batched (b, i, j, c) feature map to calculate the gram matrix from.

        Returns
        -------
        tensor
            (b, c, c) gram matrix, normalized by the i*j spatial locations.
        """
        gram = tf.linalg.einsum("bijc,bijd->bcd", input_tensor, input_tensor)
        input_shape = tf.shape(input_tensor)
        num_locations = tf.cast(input_shape[1] * input_shape[2], tf.float32)
        scaled_gram = gram / num_locations
        return scaled_gram

    def get_features(self, image, type=None):
        """Returns the features of the image.

        Parameters
        ----------
        image : tensor
            Image to extract features from.
        type : str
            Type of features to extract, either "style" or "content". Any
            other value (including `None` or "all") returns both, content
            features first. (Name shadows the builtin but is kept for
            caller compatibility.)

        Returns
        -------
        list
            Content-layer activations, gram matrices of style-layer
            activations, or their concatenation (content first).
        """
        preprocessed_image = self.preprocess_image(image)
        # Model outputs follow `content_and_style_layers`: content layers
        # occupy the first NUM_CONTENT_LAYERS slots.
        outputs = self.model(preprocessed_image)
        if type == "style":
            style_outputs = outputs[self.NUM_CONTENT_LAYERS :]
            gram_style_features = [
                self.gram_matrix(style_output) for style_output in style_outputs
            ]
            return gram_style_features
        elif type == "content":
            content_outputs = outputs[: self.NUM_CONTENT_LAYERS]
            return content_outputs
        else:
            style_outputs = outputs[self.NUM_CONTENT_LAYERS :]
            content_outputs = outputs[: self.NUM_CONTENT_LAYERS]
            gram_style_features = [
                self.gram_matrix(style_output) for style_output in style_outputs
            ]
            return content_outputs + gram_style_features

    def _loss(self, features, targets, type="style"):
        """Returns the loss of feature and target: the mean squared error,
        halved for content losses.

        Parameters
        ----------
        features : tensor
            Features of the generated image for one layer.
        targets : tensor
            Target features for the same layer.
        type : str
            Type of loss to calculate, either "style" or "content".
        """
        loss = tf.reduce_mean(tf.square(features - targets))
        if type == "content":
            loss = loss * 0.5  # conventional 1/2 factor on the content MSE
        return loss

    def get_loss(self, features, target, alpha, beta):
        """Returns the total loss of the style and content images.

        Parameters
        ----------
        features : list
            Features of the generated image (content first, then style grams).
        target : list
            Target features in the same order.
        alpha : float
            Weight of the content loss.
        beta : float
            Weight of the style loss.

        Returns
        -------
        loss : float
            Weighted total loss, each part averaged over its layers.
        """
        style_features = features[self.NUM_CONTENT_LAYERS :]
        content_features = features[: self.NUM_CONTENT_LAYERS]
        style_targets = target[self.NUM_CONTENT_LAYERS :]
        content_targets = target[: self.NUM_CONTENT_LAYERS]
        style_loss = 0
        content_loss = 0
        for i in range(self.NUM_STYLE_LAYERS):
            style_loss += self._loss(style_features[i], style_targets[i], type="style")
        for i in range(self.NUM_CONTENT_LAYERS):
            content_loss += self._loss(
                content_features[i], content_targets[i], type="content"
            )
        style_loss = beta * style_loss / self.NUM_STYLE_LAYERS
        content_loss = alpha * content_loss / self.NUM_CONTENT_LAYERS
        loss = content_loss + style_loss
        return loss

    def calculate_gradients(self, image, target, alpha, beta):
        """Calculates the gradients of the loss with respect to the image.

        `image` must be a `tf.Variable` so the tape tracks it automatically.
        """
        with tf.GradientTape() as tape:
            # "all" falls through to get_features' else-branch: content + style.
            features = self.get_features(image, "all")
            loss = self.get_loss(features, target, alpha, beta)
        gradients = tape.gradient(loss, image)
        return gradients, loss

    def update_image(self, image, target, alpha, beta, optimizer):
        """Performs one optimization step on `image` (a tf.Variable) and
        clips it back into the valid [0, 255] pixel range."""
        gradients, loss = self.calculate_gradients(image, target, alpha, beta)
        optimizer.apply_gradients([(gradients, image)])
        image.assign(tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=255.0))
        return loss

    def load_model(self):
        """Creates an Inception model that returns a list of intermediate
        output values (content layers first, then style layers) and stores
        it on `self.model`."""
        K.clear_session()
        inception = tf.keras.applications.InceptionV3(
            include_top=False, weights="imagenet"
        )
        inception.trainable = False  # feature extractor only; never trained
        output_layers = self.content_and_style_layers
        model = tf.keras.models.Model(
            [inception.input],
            [inception.get_layer(name).output for name in output_layers],
        )
        self.model = model
        return model

    def stylize_image(
        self,
        alpha=1,
        beta=0.1,
        epochs=10,
        steps_per_epoch=10,
        show_images=True,
        image_frequency=2,
        notebook=False,
        lr=None,
    ):
        """Stylizes the content image using the style image.

        Parameters
        ----------
        alpha : float, optional
            Content weight, by default 1
        beta : float, optional
            Style weight, by default 0.1
        epochs : int, optional
            Number of epochs, by default 10
        steps_per_epoch : int, optional
            Number of steps per epoch, by default 10
        show_images : bool, optional
            If True, periodically save the in-progress image to ".img.jpg",
            by default True
        image_frequency : int, optional
            Save every `image_frequency` steps, by default 2
        notebook : bool, optional
            Kept for backward compatibility; currently unused.
        lr : float, optional
            Initial learning rate (exponentially decayed), by default 40.0

        Returns
        -------
        PIL.Image
            The final stylized image. (The original docstring promised a
            list of images, but that code path was commented out.)
        """
        if self.model is None:
            K.clear_session()
            _ = self.load_model()
        style_image = self.load_image(self.style_image_path)
        content_image = self.load_image(self.content_image_path)
        # Targets are fixed features of the originals; concatenation order
        # (content first) matches get_features' combined output.
        style_target = self.get_features(style_image, "style")
        content_target = self.get_features(content_image, "content")
        target = content_target + style_target
        # Optimize a float copy of the content image.
        image = tf.Variable(tf.cast(content_image, dtype=tf.float32))
        if lr is None:
            lr = 40.0
        optimizer = tf.optimizers.Adam(
            tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=lr, decay_steps=100, decay_rate=0.80
            )
        )
        for epoch in range(epochs):
            for step in range(steps_per_epoch):
                self.update_image(image, target, alpha, beta, optimizer)
                if show_images and step % image_frequency == 0:
                    # Persist the current result so callers can watch progress.
                    # (Conversion hoisted here: no need to build a PIL image
                    # on steps that are never displayed or saved.)
                    self.tensor_to_image(image).save(".img.jpg")
        return self.tensor_to_image(image)