import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from PIL import Image import torchvision.transforms as transforms import torchvision.models as models import copy import time import os import io # Check if GPU is available device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # Image loading and preprocessing def image_loader(image_path, imsize=512): loader = transforms.Compose([ transforms.Resize(imsize), # Scale imported image transforms.CenterCrop(imsize), # Ensure square size transforms.ToTensor(), # Transform into torch tensor transforms.Lambda(lambda x: x.repeat(1, 1, 1) if x.size(0) == 1 else x) # Convert grayscale to RGB if needed ]) image = Image.open(image_path).convert('RGB') # Ensure image is RGB # Add batch dimension (1, 3, h, w) image = loader(image).unsqueeze(0) return image.to(device, torch.float) def load_image_from_bytes(image_bytes, imsize=512): loader = transforms.Compose([ transforms.Resize(imsize), transforms.CenterCrop(imsize), transforms.ToTensor(), transforms.Lambda(lambda x: x.repeat(1, 1, 1) if x.size(0) == 1 else x) ]) image = Image.open(io.BytesIO(image_bytes)).convert('RGB') image = loader(image).unsqueeze(0) return image.to(device, torch.float) # Content Loss: Measures content similarity class ContentLoss(nn.Module): def __init__(self, target): super(ContentLoss, self).__init__() # Detach the target content from the tree used to dynamically compute gradients self.target = target.detach() def forward(self, input): self.loss = F.mse_loss(input, self.target) return input # Gram matrix calculation for style representation def gram_matrix(input): batch_size, n_channels, height, width = input.size() features = input.view(batch_size * n_channels, height * width) G = torch.mm(features, features.t()) # Normalize by total number of elements return G.div(batch_size * n_channels * height * width) # Style Loss: Measures style similarity using Gram matrices class StyleLoss(nn.Module): def __init__(self, target_feature): super(StyleLoss, self).__init__() self.target = gram_matrix(target_feature).detach() self.weight = 1.0 # Default weight for this layer def forward(self, input): G = gram_matrix(input) self.loss = F.mse_loss(G, self.target) return input # Normalization layer for VGG compatibility class Normalization(nn.Module): def __init__(self, mean, std): super(Normalization, self).__init__() # View the mean and std as 1x3x1x1 tensors self.mean = mean.clone().detach().view(-1, 1, 1).to(device) self.std = std.clone().detach().view(-1, 1, 1).to(device) def forward(self, img): # Normalize img return (img - self.mean) / self.std # Build model with content and style losses def get_style_model_and_losses(cnn, normalization_mean, normalization_std, style_img, content_img, content_layers=['conv_4'], style_layers=['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'], layer_weights=None): normalization = Normalization(normalization_mean, normalization_std) # Set default layer weights if not provided if layer_weights is None: layer_weights = {layer: 1.0 for layer in style_layers} # Lists to keep track of losses content_losses = [] style_losses = [] # Create a "sequential" module with added content/style loss layers model = nn.Sequential(normalization) i = 0 # Increment for each conv layer for layer in cnn.children(): if isinstance(layer, nn.Conv2d): i += 1 name = f'conv_{i}' elif isinstance(layer, nn.ReLU): name = f'relu_{i}' # Replace in-place version with out-of-place layer = nn.ReLU(inplace=False) elif isinstance(layer, nn.MaxPool2d): name = f'pool_{i}' elif isinstance(layer, nn.BatchNorm2d): name = f'bn_{i}' else: raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}') model.add_module(name, layer) # Add content loss if name in content_layers: # Add content loss: target = model(content_img).detach() content_loss = ContentLoss(target) model.add_module(f"content_loss_{i}", content_loss) content_losses.append(content_loss) # Add style loss if name in style_layers: # Add style loss: target_feature = model(style_img).detach() style_loss = StyleLoss(target_feature) # Apply customized layer weight style_loss.weight = layer_weights.get(name, 1.0) model.add_module(f"style_loss_{i}", style_loss) style_losses.append(style_loss) # Trim off the layers after the last content and style losses for i in range(len(model) - 1, -1, -1): if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss): break model = model[:(i + 1)] return model, style_losses, content_losses # Optimization loop for style transfer def run_style_transfer(cnn, normalization_mean, normalization_std, content_img, style_img, input_img, num_steps=300, style_weight=1000000, content_weight=1, layer_weights=None, progress_callback=None): """Run the style transfer.""" num_steps = min(num_steps, 400) print('Building the style transfer model...') model, style_losses, content_losses = get_style_model_and_losses( cnn, normalization_mean, normalization_std, style_img, content_img, layer_weights=layer_weights ) # We want to optimize the input image only input_img.requires_grad_(True) model.eval() # We don't need gradients for the model parameters model.requires_grad_(False) optimizer = optim.LBFGS([input_img]) best_img = None best_loss = float('inf') prev_loss = float('inf') current_step = 0 start_time = time.time() # Function to be used with optimizer def closure(): nonlocal current_step # Correct the values of updated input image with torch.no_grad(): input_img.clamp_(0, 1) optimizer.zero_grad() model(input_img) style_score = 0 content_score = 0 for sl in style_losses: # Apply per-layer weight style_score += sl.loss * sl.weight for cl in content_losses: content_score += cl.loss style_score *= style_weight content_score *= content_weight loss = style_score + content_score loss.backward() current_step += 1 if current_step % 50 == 0: elapsed = time.time() - start_time print(f"Iteration: {current_step}, Style Loss: {style_score.item():.2f}, Content Loss: {content_score.item():.2f}, Total Loss: {loss.item():.2f}, Time: {elapsed:.1f}s") if progress_callback: progress = { 'iteration': current_step, 'style_loss': style_score.item(), 'content_loss': content_score.item(), 'elapsed_time': elapsed } progress_callback(progress) # Save best result so far nonlocal best_loss, best_img, prev_loss current_loss = loss.item() if current_loss < best_loss: best_loss = current_loss best_img = input_img.clone() # Update previous loss for next iteration prev_loss = current_loss return loss # Run optimization with early stopping while current_step < num_steps: optimizer.step(closure) # Check stopping conditions after minimum iterations if current_step >= 50 and prev_loss > 1000: print(f"Stopping early at iteration {current_step} due to high loss: {prev_loss:.2f}") break # A final correction with torch.no_grad(): input_img.clamp_(0, 1) print(f"Total time: {time.time() - start_time:.1f}s") print(f"Best loss achieved: {best_loss:.2f}") # Return both the final and best image (often the same) return input_img, best_img, best_loss # Save tensor as image def save_image(tensor, path): image = tensor.cpu().clone() image = image.squeeze(0) # Remove batch dimension image = transforms.ToPILImage()(image) image.save(path) return image # Main style transfer function def transfer_style(content_path, style_path, output_path, style_weight=1000000, content_weight=1, num_steps=300, layer_weights=None, progress_callback=None): """ Perform style transfer and save the result Args: content_path: Path to content image style_path: Path to style image output_path: Where to save the output image style_weight: Weight for style loss content_weight: Weight for content loss num_steps: Number of optimization steps layer_weights: Dictionary of weights for each style layer progress_callback: Function to call for progress updates Returns: Tuple of (output_path, best_loss) """ # Load images content_img = image_loader(content_path) style_img = image_loader(style_path) # Start with content image for faster convergence input_img = content_img.clone() # Load VGG19 for feature extraction cnn = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.to(device).eval() # Mean and std for normalization (from ImageNet) cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) # Run style transfer output, best_output, best_loss = run_style_transfer( cnn, cnn_normalization_mean, cnn_normalization_std, content_img, style_img, input_img, num_steps=num_steps, style_weight=style_weight, content_weight=content_weight, layer_weights=layer_weights, progress_callback=progress_callback ) # Save result and return path save_image(best_output, output_path) return output_path, best_loss