Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.optim as optim | |
| import torchvision.models as models | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import numpy as np | |
| import requests | |
| from io import BytesIO | |
def load_image(img_input, max_size=400, shape=None):
    """Load an image and turn it into a normalized, batched tensor.

    Args:
        img_input: Either a NumPy array of pixel data (converted to uint8 and
            then to RGB, so grayscale and RGBA arrays are accepted too) or a
            string holding a local file path or an ``http://``/``https://`` URL.
        max_size: Value handed to ``transforms.Resize`` when the image's
            longest side exceeds it. Ignored when ``shape`` is given.
        shape: Optional explicit target size passed straight to
            ``transforms.Resize``; overrides ``max_size``.

    Returns:
        A tensor with a leading batch dimension of 1 and 3 channels,
        normalized with the mean/std constants below.

    Raises:
        ValueError: If ``img_input`` is neither a NumPy array nor a string.
        requests.HTTPError: If a URL download answers with an error status.
    """
    if isinstance(img_input, np.ndarray):
        # Let Pillow infer the mode from the array shape, then force RGB so
        # grayscale / RGBA inputs do not crash the conversion.
        image = Image.fromarray(img_input.astype('uint8')).convert('RGB')
    elif isinstance(img_input, str):
        # Only treat real URL schemes as remote; a substring test would also
        # match local paths that merely contain "http".
        if img_input.lower().startswith(("http://", "https://")):
            # Bounded wait so a stalled server cannot hang the request forever.
            response = requests.get(img_input, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            # convert() loads the pixel data, so the file can be closed here.
            with Image.open(img_input) as source:
                image = source.convert('RGB')
    else:
        raise ValueError("Unsupported input type. Expected numpy array or string.")

    # Large images will slow down processing.
    # NOTE(review): per the torchvision docs an int size matches the image's
    # *shorter* edge, so the longer edge can still exceed max_size - confirm
    # whether that is intended.
    size = shape if shape is not None else min(max(image.size), max_size)

    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])

    # Keep at most the first three channels and add the batch dimension.
    image = in_transform(image)[:3, :, :].unsqueeze(0)
    return image
# Helper that undoes the input normalization and turns a tensor image
# into a NumPy image suitable for display.
def im_convert(tensor):
    """Convert a normalized image tensor into a displayable NumPy array.

    The tensor is copied to the CPU and detached, size-1 dimensions are
    squeezed away, the channel axis is moved last, the normalization is
    reversed with the mean/std constants below, and the result is clipped
    to the range [0, 1].
    """
    channel_std = np.array((0.229, 0.224, 0.225))
    channel_mean = np.array((0.485, 0.456, 0.406))

    pixels = tensor.to("cpu").clone().detach().numpy()
    # (C, H, W) after squeezing -> (H, W, C)
    pixels = np.transpose(np.squeeze(pixels), (1, 2, 0))
    return np.clip(pixels * channel_std + channel_mean, 0, 1)
def get_features(image, model, layers=None):
    """Feed ``image`` through ``model`` one child module at a time and
    collect the outputs of selected modules.

    ``layers`` maps a child-module name (a key of ``model._modules``) to the
    key under which that module's output is stored in the returned dict.
    When omitted, the mapping below is used.

    NOTE(review): stored tensors are not copied, so an in-place module that
    runs right after a captured one would also alter the stored tensor -
    confirm against the model in use.
    """
    layer_names = layers if layers is not None else {
        '0': 'conv1_1',
        '5': 'conv2_1',
        '10': 'conv3_1',
        '19': 'conv4_1',
        '21': 'conv4_2',  # content representation
        '28': 'conv5_1',
    }

    captured = {}
    activation = image
    # model._modules holds every registered child module, in order
    for index, module in model._modules.items():
        activation = module(activation)
        if index in layer_names:
            captured[layer_names[index]] = activation
    return captured
def gram_matrix(tensor):
    """Calculate the Gram matrix of a 4-D feature tensor.

    Gram Matrix: https://en.wikipedia.org/wiki/Gramian_matrix

    Args:
        tensor: A tensor of shape ``(batch, depth, height, width)``.

    Returns:
        A ``(batch * depth, batch * depth)`` tensor holding the inner
        products between the flattened feature maps. Batch and depth are
        flattened together, so with ``batch == 1`` this is the usual
        per-channel Gram matrix.
    """
    b, d, h, w = tensor.size()
    # reshape (rather than view) so non-contiguous inputs, e.g. transposed
    # or sliced feature maps, are handled instead of raising
    flat = tensor.reshape(b * d, h * w)
    return torch.mm(flat, flat.t())
# Resize an image file while maintaining its aspect ratio
def resize_image(image_path, max_size=400):
    """Load an image file and scale it so its longest side equals ``max_size``.

    Args:
        image_path: Path (or file object) accepted by ``Image.open``.
        max_size: Target length, in pixels, of the longest side. Smaller
            images are scaled up, larger ones scaled down.

    Returns:
        The resized RGB image as a NumPy array.
    """
    # convert() loads the pixel data, so the underlying file can be closed
    # as soon as the RGB copy exists.
    with Image.open(image_path) as source:
        img = source.convert('RGB')

    ratio = max_size / max(img.size)
    # Clamp to 1 px so extreme aspect ratios never yield a zero-sized side.
    new_size = tuple(max(1, int(side * ratio)) for side in img.size)
    img = img.resize(new_size, Image.Resampling.LANCZOS)
    return np.array(img)
def create_grid(images, rows, cols):
    """Tile equally sized images into one ``rows`` x ``cols`` uint8 grid.

    Args:
        images: Sequence of ``rows * cols`` arrays of shape ``(h, w, 3)``,
            all the same size as the first one. Integer arrays are copied
            as-is; floating-point arrays are interpreted as intensities in
            [0, 1] and scaled to 0-255 (writing them into the uint8 grid
            unscaled would truncate nearly every pixel to 0).
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A ``(h * rows, w * cols, 3)`` uint8 array, filled row by row.

    Raises:
        ValueError: If ``len(images)`` differs from ``rows * cols``.
    """
    # Explicit raise instead of assert: asserts vanish under ``python -O``.
    if len(images) != rows * cols:
        raise ValueError("Number of images doesn't match the grid size")
    if not images:
        return np.zeros((0, 0, 3), dtype=np.uint8)

    h, w = images[0].shape[0], images[0].shape[1]
    grid = np.zeros((h * rows, w * cols, 3), dtype=np.uint8)
    for i, img in enumerate(images):
        tile = np.asarray(img)
        if np.issubdtype(tile.dtype, np.floating):
            tile = np.rint(np.clip(tile, 0.0, 1.0) * 255.0)
        r, c = divmod(i, cols)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = tile
    return grid
def style_transfer(content_image, style_image, alpha, beta,
                   conv1_1, conv2_1, conv3_1, conv4_1, conv5_1, steps):
    """Optimize a copy of the content image towards the style image.

    Relies on the module-level ``vgg`` feature extractor and ``device``.

    Args:
        content_image: Content image in any form ``load_image`` accepts.
        style_image: Style image; resized to the content tensor's height
            and width.
        alpha: Weight of the content loss.
        beta: Weight of the style loss (multiplied by 1e6 internally).
        conv1_1, conv2_1, conv3_1, conv4_1, conv5_1: Per-layer weights of
            the style loss.
        steps: Number of optimizer steps (at least 1). Converted to int.

    Returns:
        ``(final_image, intermediate_grid)``: the image after the last step
        as returned by ``im_convert``, and a 3x3 uint8 grid of nine
        snapshots taken at evenly spaced steps (the last tile is the final
        image; with fewer than nine steps some snapshots repeat).

    Raises:
        ValueError: If ``steps`` is smaller than 1.
    """
    # UI sliders may deliver floats; range() needs an int.
    steps = int(steps)
    if steps < 1:
        raise ValueError("steps must be at least 1")

    content = load_image(content_image).to(device)
    style = load_image(style_image, shape=content.shape[-2:]).to(device)

    content_features = get_features(content, vgg)
    style_features = get_features(style, vgg)
    style_grams = {layer: gram_matrix(feat) for layer, feat in style_features.items()}

    # The optimized image starts as a copy of the content image. It must stay
    # a leaf tensor so the optimizer can update it.
    target = content.clone().requires_grad_(True)

    style_weights = {
        'conv1_1': conv1_1,
        'conv2_1': conv2_1,
        'conv3_1': conv3_1,
        'conv4_1': conv4_1,
        'conv5_1': conv5_1,
    }
    content_weight = alpha
    style_weight = beta * 1e6

    optimizer = optim.Adam([target], lr=0.003)

    # Always produce exactly nine tiles for the 3x3 grid: evenly spaced step
    # numbers, the last one being ``steps`` itself. Integer arithmetic avoids
    # both the modulo-by-zero for steps < 9 and the stray tenth snapshot when
    # steps is not a multiple of 9.
    grid_rows, grid_cols = 3, 3
    n_tiles = grid_rows * grid_cols
    snapshot_steps = [max(1, (steps * k) // n_tiles) for k in range(1, n_tiles + 1)]
    wanted = set(snapshot_steps)
    snapshots = {}

    for ii in range(1, steps + 1):
        target_features = get_features(target, vgg)

        content_loss = torch.mean(
            (target_features['conv4_2'] - content_features['conv4_2']) ** 2)

        style_loss = 0
        for layer, layer_weight in style_weights.items():
            target_feature = target_features[layer]
            target_gram = gram_matrix(target_feature)
            _, d, h, w = target_feature.shape
            style_gram = style_grams[layer]
            layer_style_loss = layer_weight * torch.mean((target_gram - style_gram) ** 2)
            style_loss += layer_style_loss / (d * h * w)

        total_loss = content_weight * content_loss + style_weight * style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if ii in wanted:
            snapshots[ii] = im_convert(target)

    final_image = snapshots[steps]
    # im_convert yields floats in [0, 1]; scale to uint8 before tiling so the
    # grid is not truncated to black.
    tiles = [np.rint(snapshots[s] * 255).astype(np.uint8) for s in snapshot_steps]
    intermediate_grid = create_grid(tiles, grid_rows, grid_cols)
    return final_image, intermediate_grid
def load_example(content, style, output):
    """Hand an example's content, style and output back unchanged, as a tuple."""
    example = (content, style, output)
    return example
# Example images: [content path, style path, precomputed result path]
examples = [
    ["assets/content_1.jpg", "assets/style_1.jpg", "assets/result_1.png"],
    ["assets/content_2.jpg", "assets/style_2.jpg", "assets/result_2.png"],
    ["assets/content_3.png", "assets/style_3.jpg", "assets/result_3.png"],
]

# Load the convolutional part of VGG19 with ImageNet weights. The ``weights=``
# enum replaces the deprecated ``pretrained=True`` flag (same weights).
vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
# Inference only: evaluation mode, and no gradients for the network weights
# (only the target image is optimized).
vgg.eval()
for param in vgg.parameters():
    param.requires_grad_(False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)

# Resize example images so every example is shown at a comparable size
resized_examples = [[resize_image(path) for path in example] for example in examples]
# Gradio interface
# NOTE(review): the block nesting below (which components sit inside which
# Row/Column) was reconstructed from a flattened copy of the file - confirm
# the layout against the running app.
with gr.Blocks() as demo:
    gr.Markdown("# Neural Style Transfer")
    with gr.Row():
        # Inputs: both images are requested as NumPy arrays (type="numpy"),
        # one of the input types load_image() accepts.
        with gr.Column():
            content_input = gr.Image(label="Content Image", type="numpy", image_mode="RGB", height=400, width=400)
            style_input = gr.Image(label="Style Image", type="numpy", image_mode="RGB", height=400, width=400)
        # Outputs: the final image and the grid of intermediate snapshots.
        with gr.Column():
            output_image = gr.Image(label="Output Image")
            intermediate_output = gr.Image(label="Intermediate Results")
    run_button = gr.Button("Run Style Transfer")
    # Overall content / style loss weights (alpha, beta in style_transfer)
    with gr.Row():
        alpha_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Content Weight (α)")
        beta_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Style Weight (β)")
    # Per-layer style weights, passed to style_transfer in this order
    with gr.Row():
        conv1_1_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Conv1_1 Weight")
        conv2_1_slider = gr.Slider(minimum=0, maximum=1, value=0.8, step=0.1, label="Conv2_1 Weight")
        conv3_1_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Conv3_1 Weight")
        conv4_1_slider = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="Conv4_1 Weight")
        conv5_1_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Conv5_1 Weight")
    # NOTE(review): with minimum=1 and step=100 the selectable values are
    # presumably 1, 101, 201, ... - confirm this is the intended range.
    steps_slider = gr.Slider(minimum=1, maximum=2000, value=1000, step=100, label="Number of Steps")
    # The order of `inputs` must match style_transfer's positional parameters;
    # `outputs` receives its two return values.
    run_button.click(
        style_transfer,
        inputs=[
            content_input,
            style_input,
            alpha_slider,
            beta_slider,
            conv1_1_slider,
            conv2_1_slider,
            conv3_1_slider,
            conv4_1_slider,
            conv5_1_slider,
            steps_slider
        ],
        outputs=[output_image, intermediate_output]
    )
    # Clickable examples: each row fills the content, style and output images;
    # load_example simply returns its three arguments unchanged.
    gr.Examples(
        resized_examples,
        inputs=[content_input, style_input, output_image],
        outputs=[content_input, style_input, output_image],
        fn=load_example,
        cache_examples=True
    )
demo.launch()