| import torch |
| import torch.nn as nn |
| from torchvision import models, transforms |
| import numpy as np |
| from PIL import Image |
| |
| from torchvision.models import vgg19 as vggm, VGG19_Weights |
| import os |
| from pathlib import Path |
| import gradio as gr |
| from tqdm.auto import tqdm |
|
|
# Working resolution for all images; (height, width) as expected by
# torchvision.transforms.Resize.
SIZE = (640, 480)

# Run on GPU when available, otherwise fall back to CPU.
# (The original had a bare `device` expression here — a no-op notebook
# leftover — which has been removed.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained VGG-19 whose convolutional activations drive the style and
# content losses.  eval() switches off dropout; the weights themselves are
# frozen inside the training loop.
vgg19 = vggm(weights=VGG19_Weights.IMAGENET1K_V1).to(device)
vgg19.eval()

# ImageNet channel statistics matching the pretrained VGG-19 weights.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# Standard preprocessing: resize to the working resolution, convert to a
# float tensor in [0, 1], then normalize with the ImageNet statistics.
transforms_preprocess = transforms.Compose([
    transforms.Resize(SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
|
|
| |
def image_load(img, transform = transforms_preprocess, verbose= True):
    """Load an image from a path, a PIL image, or a tensor onto `device`.

    Paths and PIL images are converted to RGB, pushed through `transform`,
    and given a leading batch dimension.  Tensors bypass the transform
    entirely (assumed to be preprocessed already) and are only promoted
    to 4-D NCHW.

    Args:
        img: file path (str), PIL.Image.Image, or torch.Tensor.
        transform: preprocessing pipeline for non-tensor inputs.
        verbose: when True, print the resulting tensor shape.

    Returns:
        A 4-D tensor on `device`.

    Raises:
        ValueError: for tensors that are neither 3-D nor 4-D.
        TypeError: for any other unsupported input type.
    """
    # Tensor fast path: no transform is applied — presumably the caller
    # already normalized it (TODO confirm with call sites).
    if isinstance(img, torch.Tensor):
        if img.ndim == 3:
            img = img.unsqueeze(0)
        elif img.ndim != 4:
            raise ValueError(f"Unexpected tensor shape: {img.shape}")
        if verbose:
            print(img.shape)
        return img.to(device)

    if isinstance(img, str):
        pil_img = Image.open(img).convert('RGB')
    elif isinstance(img, Image.Image):
        pil_img = img.convert('RGB')
    else:
        raise TypeError(f"Unsupported input type: {type(img)}")

    batch = transform(pil_img).unsqueeze(0)
    if verbose:
        print(batch.shape)
    return batch.to(device)
|
|
|
|
| |
# Directories holding the example style / content images shown in the UI.
STYLE_IMAGE_PATH = './assets/style/'
CONTENT_IMAGE_PATH = './assets/target/'

# Collect the example files as Path objects.  Path.iterdir() replaces the
# os.listdir + manual join, and sorted() makes the example-gallery order
# deterministic (os.listdir order is filesystem-dependent).
style_files = sorted(Path(STYLE_IMAGE_PATH).iterdir())
content_files = sorted(Path(CONTENT_IMAGE_PATH).iterdir())
|
|
|
|
| |
| |
|
|
|
|
| |
# Map from VGG-19 `features` submodule index (as a string name) to the
# canonical layer name from the neural-style literature (Gatys et al.).
# The conv*_1 layers carry the style loss; conv4_2 carries the content loss.
LOSS_LAYERS = { '0': 'conv1_1',
                '5': 'conv2_1',
                '10': 'conv3_1',
                '19': 'conv4_1',
                '21': 'conv4_2',
                '28': 'conv5_1'}


def feature_extractor(x, model_features, layers=None):
    """Run `x` through `model_features`, collecting selected activations.

    Args:
        x: input tensor (NCHW batch).
        model_features: container module (e.g. ``vgg19.features``) whose
            children are applied sequentially in registration order.
        layers: optional mapping of child name -> output key; defaults to
            LOSS_LAYERS, preserving the original behavior.

    Returns:
        Dict mapping layer name (e.g. 'conv4_2') to its activation tensor.
    """
    if layers is None:
        layers = LOSS_LAYERS
    extracted_features = {}
    # named_children() is the public API for ordered (name, module) pairs;
    # the original reached into the private `_modules` attribute.
    for name, layer in model_features.named_children():
        x = layer(x)
        if name in layers:
            extracted_features[layers[name]] = x
    return extracted_features
|
|
|
|
|
|
| |
def gram_matrix_calculator(feature_tensor):
    """Compute the normalized Gram matrix of a feature map.

    Generalized to any batch size by folding the batch and channel
    dimensions together (the classic neural-style formulation).  The
    original unpacked `n` but never used it and relied on `squeeze(0)`,
    which silently breaks for n > 1; for n == 1 this version is
    numerically identical (view(c, h*w) and division by c*h*w).

    Args:
        feature_tensor: activation tensor of shape (n, c, h, w).

    Returns:
        Gram matrix of shape (n*c, n*c), divided by n*c*h*w so its scale
        is independent of the feature-map resolution.
    """
    n, c, h, w = feature_tensor.size()
    feat = feature_tensor.view(n * c, h * w)
    gram = torch.mm(feat, feat.t())
    return gram.div(n * c * h * w)
|
|
def training_loop(style_image, content_image, num_steps=300, style_weight=1e8, content_weight=1.0):
    """Optimize a copy of the content image toward the style image's statistics.

    Implements Gatys-style neural style transfer: the pixels of `target`
    (initialized from the content image) are the only optimized parameters.

    Args:
        style_image: style source (path, PIL image, or tensor — see image_load).
        content_image: content source (same accepted types).
        num_steps: number of outer L-BFGS step() calls.
        style_weight: multiplier on the summed style (Gram) losses.
        content_weight: multiplier on the conv4_2 content loss.

    Returns:
        The optimized image tensor, detached and moved to CPU (still normalized).
    """
    style_tr = image_load(style_image)
    content_tr = image_load(content_image)

    # Freeze VGG: only the generated image is optimized.
    for p in vgg19.parameters():
        p.requires_grad = False

    # Precompute the reference activations once.  no_grad() keeps autograd
    # from tracking these fixed targets (same values, less bookkeeping).
    with torch.no_grad():
        style_features = feature_extractor(style_tr, vgg19.features)
        content_features = feature_extractor(content_tr, vgg19.features)
        style_gram = {layer: gram_matrix_calculator(style_features[layer])
                      for layer in style_features}

    # Start the optimization from the content image.
    target = content_tr.clone().detach().requires_grad_(True)

    # L-BFGS is the standard choice for this loss landscape; note each
    # step() may evaluate the closure multiple times.
    optimizer = torch.optim.LBFGS([target])

    # Per-layer style weights: earlier layers (finer textures) count more.
    weights = {'conv1_1': 1.0, 'conv2_1': 0.8, 'conv3_1': 0.6, 'conv4_1': 0.4, 'conv5_1': 0.2}
    loss_fn = nn.functional.mse_loss

    # Closure-call counter for periodic logging (nonlocal replaces the
    # original Python-2-style mutable-list workaround).
    step_count = 0

    def closure():
        nonlocal step_count
        optimizer.zero_grad()

        target_features = feature_extractor(target, vgg19.features)

        # Content loss: match conv4_2 activations directly.
        c_loss = loss_fn(target_features['conv4_2'], content_features['conv4_2'])

        # Style loss: match Gram matrices on the conv*_1 layers.
        s_loss = 0.0
        for layer in weights:
            target_gram = gram_matrix_calculator(target_features[layer])
            style_gram_matrix = style_gram[layer]
            s_loss += loss_fn(target_gram, style_gram_matrix) * weights[layer]

        total_loss = style_weight * s_loss + content_weight * c_loss
        total_loss.backward()

        if step_count % 50 == 0:
            print(f"Step {step_count} | Total: {total_loss.item():.4f} | "
                  f"Style: {s_loss.item():.4f} | Content: {c_loss.item():.4f}")
        step_count += 1

        return total_loss

    with tqdm(range(num_steps), desc="Optimizing") as pbar:
        for _ in pbar:
            optimizer.step(closure)

    return target.detach().cpu()
|
|
| |
# ImageNet normalization statistics (same values used by the preprocessing
# pipeline); needed here to invert the normalization for display.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)


def tensor_to_image(image_tensor):
    """Convert a normalized (1, C, H, W) tensor back into a PIL image.

    Undoes the ImageNet normalization channel-wise, clamps to [0, 1],
    and rescales to 8-bit RGB.
    """
    # Drop the batch dim and move channels last: (H, W, C).
    arr = image_tensor.clone().detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
    # De-normalize: x * std + mean (broadcasts over the trailing channel axis).
    arr = arr * np.asarray(std) + np.asarray(mean)
    arr = np.clip(arr, 0.0, 1.0)
    return Image.fromarray((arr * 255).astype(np.uint8))
|
|
|
|
|
|
def style_transfer(style_image, content_image):
    """Gradio callback: run the optimization and return a displayable PIL image.

    NOTE(review): num_steps=2 keeps the demo responsive; raise it for
    real output quality.
    """
    result = training_loop(style_image, content_image, num_steps=2)
    return tensor_to_image(result)
|
|
| |
| |
if __name__ == "__main__":
    # Close any interfaces left over from a previous run (useful when
    # re-executing during development).
    gr.close_all()
    with gr.Blocks(theme=gr.themes.Glass()) as interface:
        with gr.Row():
            gr.Markdown("<h2 style='color: blue;'>Vanilla CNN Style Transfer</h2>")
        # Three side-by-side panels: style input, content input, generated output.
        with gr.Row():
            with gr.Column():
                style_image=gr.Image(type="pil", label="Style Image",height=300,width=300)
            with gr.Column():
                content_image=gr.Image(type="pil", label="Content Image",height=300,width=300)
            with gr.Column():
                image_output=gr.Image(type="pil", label="Generated Image",height=300,width=300)

        with gr.Row():
            style_transfer_button=gr.Button("Style Transfer",variant= "primary")
            reset_button=gr.Button("Reset",variant= "secondary")

        # Clickable example galleries populated from the ./assets directories.
        with gr.Row():
            style_examples=gr.Examples(examples=[f for f in style_files], inputs=[style_image], label="Style Images")
            content_examples=gr.Examples(examples=[f for f in content_files], inputs=[content_image], label="Content Images")

        # Run the optimization on click; the UI blocks until it finishes.
        style_transfer_button.click(
            fn=style_transfer,
            inputs=[style_image, content_image],
            outputs=[image_output]
        )
        # Reset clears only the generated image, not the two inputs.
        reset_button.click(
            fn=lambda: gr.update(value=None),
            outputs=[image_output]
        )
    # Local-only server on the default Gradio port.
    interface.launch(share=False, server_port=7860)
| |
|
|