Spaces:
Build error
Build error
| import gradio as gr | |
| import spaces | |
| import tensorflow as tf | |
| import numpy as np | |
| from PIL import Image | |
| import logging | |
| import time | |
| from tqdm import tqdm | |
# Logging: INFO-level root configuration plus a named logger for this app.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("style_transfer_app")

# TensorFlow threading: the inter-op and intra-op pools share one size.
_TF_THREAD_COUNT = 8
tf.config.threading.set_inter_op_parallelism_threads(_TF_THREAD_COUNT)
tf.config.threading.set_intra_op_parallelism_threads(_TF_THREAD_COUNT)
def load_img(image):
    """Load and preprocess an image for style transfer.

    The image is converted to a float32 tensor, resized so that its longer
    side is ``max_dim`` (512) pixels while keeping the aspect ratio (up to
    integer truncation of the new size), and given a leading batch axis.

    Args:
        image: A PIL image, or any object ``np.asarray`` can turn into a
            channels-last pixel array. PIL images that are not already in
            RGB mode are converted to RGB first.

    Returns:
        A float32 tensor with a leading batch axis of size 1.
    """
    max_dim = 512

    # Greyscale, palette, 1-bit or RGBA inputs would yield an array with the
    # wrong rank or channel count for the shape slicing below (and for the
    # 3-channel VGG19 network downstream), so normalise PIL inputs to RGB.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # Convert the pixel data to a float32 tensor.
    img = tf.convert_to_tensor(np.asarray(image))
    img = tf.image.convert_image_dtype(img, tf.float32)

    # Spatial size only: the trailing (channel) axis is dropped.
    shape = tf.cast(tf.shape(img)[:-1], tf.float32)
    # tf.reduce_max instead of the builtin max(), which would iterate over
    # the tensor element by element in Python.
    long_dim = tf.reduce_max(shape)
    scale = max_dim / long_dim
    new_shape = tf.cast(shape * scale, tf.int32)

    img = tf.image.resize(img, new_shape)
    # Prepend the batch axis.
    return img[tf.newaxis, ...]
def vgg_layers(layer_names):
    """Build a frozen VGG19 model that returns the requested layer outputs.

    Args:
        layer_names: Names of the VGG19 layers whose outputs are wanted, in
            the order they should be returned.

    Returns:
        A ``tf.keras.Model`` mapping the VGG19 input to the list of outputs
        of the named layers.
    """
    # ImageNet weights, without the classification head.
    backbone = tf.keras.applications.VGG19(weights='imagenet', include_top=False)
    backbone.trainable = False
    feature_maps = [backbone.get_layer(layer_name).output for layer_name in layer_names]
    return tf.keras.Model([backbone.input], feature_maps)
def gram_matrix(input_tensor):
    """Compute the Gram matrix of a batch of feature maps.

    The einsum contracts axes 1 and 2 of ``input_tensor`` against itself,
    leaving a (batch, channels, channels) result, which is then divided by
    the product of the sizes of axes 1 and 2.

    Args:
        input_tensor: A 4-D tensor indexed as ``b, i, j, c`` in the einsum.

    Returns:
        The averaged channel-by-channel product tensor.
    """
    dims = tf.shape(input_tensor)
    position_count = tf.cast(dims[1] * dims[2], tf.float32)
    channel_products = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
    return channel_products / position_count
class StyleContentModel(tf.keras.models.Model):
    """Keras model returning style (Gram matrix) and content features.

    The wrapped VGG19 extractor is built for ``style_layers`` followed by
    ``content_layers``, so its outputs can be split by position.
    """

    def __init__(self, style_layers, content_layers):
        """Create the extractor for the given style and content layer names."""
        super().__init__()
        self.vgg = vgg_layers(style_layers + content_layers)
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.num_style_layers = len(style_layers)
        self.vgg.trainable = False

    def call(self, inputs):
        """Extract features from ``inputs``.

        ``inputs`` is multiplied by 255 before VGG19 preprocessing, so it is
        presumably expected in the [0, 1] range — confirm against callers.

        Returns:
            A dict with keys ``'content'`` and ``'style'``, each mapping
            layer names to that layer's content activation or style Gram
            matrix respectively.
        """
        rescaled = inputs * 255.0
        vgg_input = tf.keras.applications.vgg19.preprocess_input(rescaled)
        features = self.vgg(vgg_input)

        # Style layers come first in the extractor's outputs, content after.
        boundary = self.num_style_layers
        style_grams = [gram_matrix(feature) for feature in features[:boundary]]
        content_features = features[boundary:]

        content_by_layer = dict(zip(self.content_layers, content_features))
        style_by_layer = dict(zip(self.style_layers, style_grams))
        return {'content': content_by_layer, 'style': style_by_layer}
def clip_0_1(image):
    """Return ``image`` with every value limited to the range [0, 1]."""
    lower, upper = 0.0, 1.0
    return tf.clip_by_value(image, clip_value_min=lower, clip_value_max=upper)
def style_content_loss(outputs, style_targets, content_targets, style_weight, content_weight):
    """Combine the weighted style and content losses.

    Each part is the sum over layers of the mean squared difference between
    the output and its target, multiplied by its weight divided by the
    number of layers in that part.

    Args:
        outputs: Dict with ``'style'`` and ``'content'`` entries, each a
            dict of per-layer tensors.
        style_targets: Per-layer style targets, keyed like ``outputs['style']``.
        content_targets: Per-layer content targets, keyed like
            ``outputs['content']``.
        style_weight: Multiplier for the style part.
        content_weight: Multiplier for the content part.

    Returns:
        The sum of the weighted style and content losses.
    """
    def _summed_mse(actual, expected):
        # Sum over layers of the mean squared difference for that layer.
        return tf.add_n([
            tf.reduce_mean((actual[layer] - expected[layer]) ** 2)
            for layer in actual
        ])

    style_actual = outputs['style']
    content_actual = outputs['content']

    style_term = _summed_mse(style_actual, style_targets)
    style_term = style_term * (style_weight / len(style_actual))

    content_term = _summed_mse(content_actual, content_targets)
    content_term = content_term * (content_weight / len(content_actual))

    return style_term + content_term
def train_step(image, extractor, style_targets, content_targets, opt, style_weight, content_weight, total_variation_weight):
    """Run one optimisation step on ``image`` in place.

    Args:
        image: The variable being optimised; it is updated through
            ``opt.apply_gradients`` and then overwritten with its clipped
            value via ``assign``.
        extractor: Callable producing the style/content outputs for ``image``.
        style_targets: Style targets passed to ``style_content_loss``.
        content_targets: Content targets passed to ``style_content_loss``.
        opt: Optimizer exposing ``apply_gradients``.
        style_weight: Weight of the style loss.
        content_weight: Weight of the content loss.
        total_variation_weight: Weight of the total-variation term.

    Returns:
        The loss computed before the update was applied.
    """
    with tf.GradientTape() as tape:
        features = extractor(image)
        objective = style_content_loss(
            features, style_targets, content_targets, style_weight, content_weight
        )
        variation_penalty = total_variation_weight * tf.image.total_variation(image)
        objective = objective + variation_penalty

    gradient = tape.gradient(objective, image)
    opt.apply_gradients([(gradient, image)])
    # Keep the pixel values inside [0, 1] after the update.
    image.assign(clip_0_1(image))
    return objective
def tensor_to_image(tensor):
    """Turn a tensor into a PIL image.

    The values are multiplied by 255 and cast to ``uint8``; when the result
    has more than three dimensions, only its first element is used.
    """
    pixels = np.array(tensor * 255, dtype=np.uint8)
    has_batch_axis = np.ndim(pixels) > 3
    return Image.fromarray(pixels[0] if has_batch_axis else pixels)
# Style is taken from the first convolution of each of VGG19 blocks 1-5.
style_layers = [f'block{block}_conv1' for block in range(1, 6)]
# Content is taken from a single layer in block 5.
content_layers = ['block5_conv2']
# Shared style/content model, built once when the module is loaded.
extractor = StyleContentModel(style_layers=style_layers, content_layers=content_layers)
# Style transfer typically needs more than 60s
# NOTE(review): `spaces` is imported but never used in the visible code; if
# this Space is meant to run on ZeroGPU hardware, this function presumably
# needs a `@spaces.GPU(duration=...)` decorator — confirm against the
# deployment target.
def style_transfer_fn(content_image, style_image, progress=gr.Progress(track_tqdm=True)):
    """Main style transfer function for the Gradio interface.

    Starting from the content image, runs ``epochs * steps_per_epoch``
    optimisation steps against the style targets of ``style_image`` and the
    content targets of ``content_image``.

    Args:
        content_image: Image supplying the content (a PIL image, as
            configured on the interface inputs).
        style_image: Image supplying the style.
        progress: Gradio progress tracker. It is not referenced in the body;
            the ``gr.Progress(track_tqdm=True)`` default is Gradio's
            convention for surfacing the ``tqdm`` bars below in the UI.

    Returns:
        The stylized result as a PIL image.

    Raises:
        gr.Error: If either image is missing, or if anything fails during
            preprocessing or optimisation.
    """
    # Validate outside the try block so this specific message is not
    # replaced by the generic one raised in the except clause.
    if content_image is None or style_image is None:
        raise gr.Error("Please provide both a content image and a style image.")

    try:
        # Preprocess images
        content_img = load_img(content_image)
        style_img = load_img(style_image)

        # Extract style and content features
        style_targets = extractor(style_img)['style']
        content_targets = extractor(content_img)['content']

        # The optimised image starts out as a copy of the content image.
        image = tf.Variable(content_img)

        # Set optimization parameters
        opt = tf.keras.optimizers.Adam(learning_rate=0.02, beta_1=0.99, epsilon=1e-1)
        style_weight = 1e-2
        content_weight = 1e4
        total_variation_weight = 30
        epochs = 10
        steps_per_epoch = 100

        start_time = time.perf_counter()

        # Training loop
        for _epoch in tqdm(range(epochs), desc="Epochs"):
            for _step in tqdm(range(steps_per_epoch), desc="Steps", leave=False):
                train_step(image, extractor, style_targets, content_targets,
                           opt, style_weight, content_weight, total_variation_weight)

        logger.info("Style transfer finished in %.1fs", time.perf_counter() - start_time)

        # Convert result to image
        return tensor_to_image(image)
    except Exception as e:
        # Log with traceback and keep the cause chained, while showing the
        # user only the generic message.
        logger.exception("Error during style transfer: %s", e)
        raise gr.Error("An error occurred during style transfer.") from e
# Gradio UI: two PIL image inputs (content, style) and one image output.
_content_input = gr.Image(label="Content Image", type="pil")
_style_input = gr.Image(label="Style Image", type="pil")
_stylized_output = gr.Image(label="Stylized Image")

iface = gr.Interface(
    fn=style_transfer_fn,
    inputs=[_content_input, _style_input],
    outputs=_stylized_output,
    title="Neural Style Transfer - Ty Chermsirivatana",
    description="Upload a content image and a style image to create a stylized image in context.",
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()