# StyleExplorer / style_transfer.py
# Uploaded by Adisri99 ("Upload 12 files", commit c4a0359, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
import copy
import time
import os
import io
# Pick the compute device once at import time: CUDA when a GPU is visible,
# otherwise fall back to the CPU. Every tensor/module below is moved here.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# Image loading and preprocessing
def image_loader(image_path, imsize=512):
    """Load an image and return it as a batched float tensor on ``device``.

    The image is converted to RGB, resized with ``transforms.Resize(imsize)``,
    center-cropped to ``imsize`` x ``imsize`` and converted with ``ToTensor``.

    Args:
        image_path: anything ``PIL.Image.open`` accepts (path or file object).
        imsize: side length of the square output image.

    Returns:
        Tensor of shape (1, 3, imsize, imsize), dtype ``torch.float``, on ``device``.
    """
    loader = transforms.Compose([
        transforms.Resize(imsize),      # Scale imported image
        transforms.CenterCrop(imsize),  # Ensure square size
        transforms.ToTensor(),          # Transform into torch tensor
        # Defensive grayscale -> RGB expansion. repeat(3, 1, 1) triples the
        # single channel (the former repeat(1, 1, 1) was a no-op). With the
        # convert('RGB') below this branch should not normally trigger.
        transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x),
    ])
    # The context manager closes the file handle; convert() returns a new,
    # fully loaded image that stays valid after the handle is closed.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')
    # Add batch dimension (1, 3, h, w)
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)
def load_image_from_bytes(image_bytes, imsize=512):
    """Decode an in-memory image and preprocess it like ``image_loader``.

    Wraps the raw bytes in a ``BytesIO`` and delegates to ``image_loader``,
    which accepts any file object ``PIL.Image.open`` can read. This keeps a
    single preprocessing pipeline instead of two copies that can drift apart.

    Args:
        image_bytes: encoded image data (e.g. the contents of a PNG/JPEG file).
        imsize: side length of the square output image.

    Returns:
        Tensor of shape (1, 3, imsize, imsize), dtype ``torch.float``, on ``device``.
    """
    return image_loader(io.BytesIO(image_bytes), imsize)
# Content Loss: Measures content similarity
class ContentLoss(nn.Module):
    """Transparent probe that records the MSE to a fixed content target.

    ``forward`` returns its input unchanged. As a side effect it stores
    ``F.mse_loss(input, target)`` in ``self.loss`` for the optimizer to read.

    Args:
        target: feature map of the content image at this layer.
    """

    def __init__(self, target):
        super().__init__()
        # Keep a detached copy so no gradient ever flows into the target.
        self.target = target.detach()

    def forward(self, input):
        mse = F.mse_loss(input, self.target)
        self.loss = mse
        return input
# Gram matrix calculation for style representation
def gram_matrix(input):
    """Return the normalized Gram matrix of a 4-D feature tensor.

    Batch and channel axes are flattened together, so an input of shape
    (B, C, H, W) yields a (B*C, B*C) matrix of inner products between the
    flattened feature maps.

    Args:
        input: tensor of shape (batch, channels, height, width).

    Returns:
        The Gram matrix divided by the total number of elements of ``input``.
    """
    batch_size, n_channels, height, width = input.size()
    # reshape (not view) so non-contiguous feature maps are accepted as well;
    # for contiguous input it returns the same view without copying.
    features = input.reshape(batch_size * n_channels, height * width)
    G = torch.mm(features, features.t())
    # Normalize by total number of elements
    return G.div(batch_size * n_channels * height * width)
# Style Loss: Measures style similarity using Gram matrices
class StyleLoss(nn.Module):
    """Transparent probe that records the Gram-matrix MSE to a style target.

    ``forward`` returns its input unchanged. As a side effect it stores the
    MSE between the input's Gram matrix and the target Gram matrix in
    ``self.loss``. ``self.weight`` is a per-layer multiplier that callers may
    overwrite; this class never uses it itself.

    Args:
        target_feature: feature map of the style image at this layer.
    """

    def __init__(self, target_feature):
        super().__init__()
        # Precompute the target Gram matrix once, detached from autograd.
        self.target = gram_matrix(target_feature).detach()
        self.weight = 1.0  # Default weight for this layer

    def forward(self, input):
        self.loss = F.mse_loss(gram_matrix(input), self.target)
        return input
# Normalization layer for VGG compatibility
class Normalization(nn.Module):
    """Standardize images channel-wise: ``(img - mean) / std``.

    Args:
        mean: 1-D tensor with one value per channel.
        std: 1-D tensor with one value per channel.

    The statistics are copied, detached, reshaped to (C, 1, 1) so they
    broadcast over (B, C, H, W) images, and moved to the module-level
    ``device``.
    """

    def __init__(self, mean, std):
        super().__init__()
        self.mean, self.std = (
            stat.clone().detach().view(-1, 1, 1).to(device)
            for stat in (mean, std)
        )

    def forward(self, img):
        centered = img - self.mean
        return centered / self.std
# Build model with content and style losses
def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img, content_img,
                               content_layers=('conv_4',),
                               style_layers=('conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'),
                               layer_weights=None):
    """Rebuild ``cnn`` as a Sequential with content/style loss probes inserted.

    Layers from ``cnn.children()`` are re-added under names derived from a
    running conv counter: ``conv_N``, ``relu_N``, ``pool_N``, ``bn_N``.
    Right after each layer whose name is in ``content_layers`` (or
    ``style_layers``), a ``ContentLoss`` (or ``StyleLoss``) probe is added.
    Its target is computed by running the partially built model on
    ``content_img`` (or ``style_img``). All modules after the last probe are
    dropped.

    Args:
        cnn: feature extractor whose children are Conv2d/ReLU/MaxPool2d/
            BatchNorm2d modules. These modules are shared with the returned
            model, not copied.
        normalization_mean, normalization_std: per-channel statistics passed
            to ``Normalization``, which is the model's first module.
        style_img, content_img: image batches used to compute the targets.
        content_layers, style_layers: layer names that receive probes.
            Defaults are tuples, so they can never be mutated across calls.
        layer_weights: optional ``{layer_name: weight}`` mapping stored on
            each ``StyleLoss.weight``. Missing names default to 1.0.

    Returns:
        Tuple ``(model, style_losses, content_losses)``, where the two lists
        hold the probe modules in network order.

    Raises:
        RuntimeError: if ``cnn`` contains an unsupported layer type.
    """
    normalization = Normalization(normalization_mean, normalization_std)

    # Set default layer weights if not provided
    if layer_weights is None:
        layer_weights = {layer: 1.0 for layer in style_layers}

    content_losses = []
    style_losses = []
    model = nn.Sequential(normalization)

    conv_index = 0  # Incremented for each conv layer; suffixes all layer names
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            conv_index += 1
            name = f'conv_{conv_index}'
        elif isinstance(layer, nn.ReLU):
            name = f'relu_{conv_index}'
            # Out-of-place ReLU: an in-place one would overwrite activations
            # that the preceding loss probe has already consumed.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = f'pool_{conv_index}'
        elif isinstance(layer, nn.BatchNorm2d):
            name = f'bn_{conv_index}'
        else:
            raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}')
        model.add_module(name, layer)

        if name in content_layers:
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module(f"content_loss_{conv_index}", content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            style_loss.weight = layer_weights.get(name, 1.0)
            model.add_module(f"style_loss_{conv_index}", style_loss)
            style_losses.append(style_loss)

    # Drop everything after the last probe: later layers cannot affect any
    # recorded loss. With no probes at all, only the normalization remains.
    last_probe = 0
    for idx in range(len(model) - 1, -1, -1):
        if isinstance(model[idx], (ContentLoss, StyleLoss)):
            last_probe = idx
            break
    model = model[:last_probe + 1]

    return model, style_losses, content_losses
# Optimization loop for style transfer
def run_style_transfer(cnn, normalization_mean, normalization_std,
                       content_img, style_img, input_img, num_steps=300,
                       style_weight=1000000, content_weight=1,
                       layer_weights=None, progress_callback=None):
    """Run the style transfer.

    Optimizes ``input_img`` in place with L-BFGS. The objective is
    ``style_weight * sum(weight_l * style_loss_l) + content_weight * sum(content_loss)``.
    Pixels are clamped to [0, 1] before every evaluation and once at the end.

    A "step" here counts one closure evaluation. ``num_steps`` is capped at
    400. Every 50 evaluations a progress line is printed, and
    ``progress_callback`` (if given) receives a dict with keys
    ``iteration``, ``style_loss``, ``content_loss`` and ``elapsed_time``.
    After at least 50 evaluations, optimization stops early if the most
    recent total loss is still above 1000.

    Returns:
        Tuple ``(input_img, best_img, best_loss)``.
        - ``best_img`` is a detached copy of the image with the lowest total
          loss seen.
        - If no finite loss was ever recorded (e.g. ``num_steps <= 0``),
          ``best_img`` is a detached copy of the final image and
          ``best_loss`` stays ``inf``.
    """
    num_steps = min(num_steps, 400)
    print('Building the style transfer model...')
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, normalization_mean, normalization_std,
        style_img, content_img,
        layer_weights=layer_weights
    )

    # Only the input image is optimized; freeze the network parameters.
    input_img.requires_grad_(True)
    model.eval()
    model.requires_grad_(False)
    optimizer = optim.LBFGS([input_img])

    best_img = None
    best_loss = float('inf')
    prev_loss = float('inf')
    current_step = 0
    start_time = time.time()

    def closure():
        """Evaluate the weighted loss, backpropagate, and track progress."""
        nonlocal current_step, best_loss, best_img, prev_loss

        # Keep pixel values in the valid range before each evaluation.
        with torch.no_grad():
            input_img.clamp_(0, 1)

        optimizer.zero_grad()
        model(input_img)  # populates .loss on every probe module

        # Start the sums from a tensor (not int 0) so .item() below still
        # works when one of the loss lists is empty.
        zero = input_img.new_zeros(())
        style_score = style_weight * sum(
            (sl.loss * sl.weight for sl in style_losses), zero)
        content_score = content_weight * sum(
            (cl.loss for cl in content_losses), zero)
        loss = style_score + content_score
        loss.backward()

        current_step += 1
        if current_step % 50 == 0:
            elapsed = time.time() - start_time
            print(f"Iteration: {current_step}, Style Loss: {style_score.item():.2f}, Content Loss: {content_score.item():.2f}, Total Loss: {loss.item():.2f}, Time: {elapsed:.1f}s")
            if progress_callback:
                progress_callback({
                    'iteration': current_step,
                    'style_loss': style_score.item(),
                    'content_loss': content_score.item(),
                    'elapsed_time': elapsed
                })

        # Remember the best image so far. detach() so the snapshot holds no
        # autograd link back to input_img.
        current_loss = loss.item()
        if current_loss < best_loss:
            best_loss = current_loss
            best_img = input_img.detach().clone()
        prev_loss = current_loss
        return loss

    # Run optimization with early stopping
    while current_step < num_steps:
        optimizer.step(closure)
        if current_step >= 50 and prev_loss > 1000:
            print(f"Stopping early at iteration {current_step} due to high loss: {prev_loss:.2f}")
            break

    # A final correction
    with torch.no_grad():
        input_img.clamp_(0, 1)

    # No evaluation ran, or every loss was NaN: fall back to the final image
    # so callers never receive None.
    if best_img is None:
        best_img = input_img.detach().clone()

    print(f"Total time: {time.time() - start_time:.1f}s")
    print(f"Best loss achieved: {best_loss:.2f}")
    # Return both the final and best image (often the same)
    return input_img, best_img, best_loss
# Save tensor as image
def save_image(tensor, path):
    """Write an image tensor to ``path`` and return the saved PIL image.

    Args:
        tensor: float image tensor, either (1, C, H, W) or (C, H, W), with
            values expected in [0, 1].
        path: destination file path; the format follows PIL's rules.

    Returns:
        The ``PIL.Image.Image`` that was written.
    """
    # detach(): tensors coming out of the optimizer may still require grad.
    # clamp(): float values are scaled by 255 and cast to 8-bit when
    # converted (per torchvision's ToPILImage), so out-of-range values would
    # wrap around instead of saturating. The out-of-place clamp also makes a
    # copy, so the caller's tensor is never modified.
    image = tensor.detach().cpu().clamp(0, 1)
    image = image.squeeze(0)  # Remove batch dimension
    pil_image = transforms.ToPILImage()(image)
    pil_image.save(path)
    return pil_image
# Main style transfer function
def transfer_style(content_path, style_path, output_path, style_weight=1000000,
                   content_weight=1, num_steps=300, layer_weights=None,
                   progress_callback=None, imsize=512):
    """
    Perform style transfer and save the result.

    Args:
        content_path: Path to content image
        style_path: Path to style image
        output_path: Where to save the output image
        style_weight: Weight for style loss
        content_weight: Weight for content loss
        num_steps: Number of optimization steps
        layer_weights: Dictionary of weights for each style layer
        progress_callback: Function to call for progress updates
        imsize: Side length of the square working/output image (default 512)

    Returns:
        Tuple of (output_path, best_loss)
    """
    # Load images
    content_img = image_loader(content_path, imsize)
    style_img = image_loader(style_path, imsize)

    # Start with content image for faster convergence
    input_img = content_img.clone()

    # Load VGG19 for feature extraction
    cnn = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.to(device).eval()

    # Mean and std for normalization (from ImageNet)
    cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

    final_output, best_output, best_loss = run_style_transfer(
        cnn,
        cnn_normalization_mean,
        cnn_normalization_std,
        content_img,
        style_img,
        input_img,
        num_steps=num_steps,
        style_weight=style_weight,
        content_weight=content_weight,
        layer_weights=layer_weights,
        progress_callback=progress_callback
    )

    # If no best snapshot was recorded (e.g. num_steps <= 0 or a NaN loss),
    # save the final optimized image instead of crashing on None.
    result = best_output if best_output is not None else final_output
    save_image(result, output_path)
    return output_path, best_loss