# Source: Hugging Face Space by Kush26 — "Update app.py", commit ab248ac (verified)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms as T
from PIL import Image
import numpy as np
import gradio as gr
import os
# --- Configuration ---
# Prefer the GPU when one is available, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
imsize = 256
beta = 1e5  # Style weight multiplier
# Style layers used for the Gram-matrix loss, with per-layer weights.
style_layers_names = ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']
style_weights = {
    'conv1_1': 1.0,
    'conv2_1': 0.75,
    'conv3_1': 0.2,
    'conv4_1': 0.2,
    'conv5_1': 0.2,
}
# Index of each named conv layer inside VGG19's `features` module.
layer_name_to_index = {
    'conv1_1': '0', 'conv2_1': '5', 'conv3_1': '10',
    'conv4_1': '19', 'conv4_2': '21', 'conv5_1': '28',
}
# Feature-module indices that belong to the style layers.
style_layers_indices = {layer_name_to_index[name] for name in style_layers_names}
# Reverse lookup (index -> layer name) restricted to the style layers; these
# are the only activations needed at inference time.
layers_for_inference = {
    idx: name
    for name, idx in layer_name_to_index.items()
    if idx in style_layers_indices
}
# --- Load Model and Targets (Load once when app starts) ---
# Load the VGG model
# Use VGG19_Weights.DEFAULT for recommended weights
# Only the convolutional `features` stack is kept; eval() switches off
# train-time behavior (e.g. dropout/batch-norm updates, none of which VGG19
# features use for inference, but eval() is the conventional safeguard).
model = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.to(device).eval()
for param in model.parameters():
    param.requires_grad_(False) # Freeze model parameters
# Load the saved target Gram matrices
# NOTE(review): torch.load defaults to pickle-based deserialization; since the
# file is bundled with the Space this is acceptable, but weights_only=True
# would be safer if the torch version in use supports it — verify.
try:
    # Keyed by style-layer name (e.g. 'conv1_1'); map_location places the
    # tensors directly on the active device.
    loaded_target_grams = torch.load('style_target_grams.pt', map_location=device)
    print("Style target grams loaded successfully.")
except FileNotFoundError:
    print("Error: style_target_grams.pt not found. Please ensure it's in the same directory.")
    # You might want to add logic here to train/generate the grams if missing,
    # but for a simple inference space, ensure the file is pre-uploaded.
    raise SystemExit("Required file style_target_grams.pt not found.")
except Exception as e:
    # Any other load failure (corrupt file, version mismatch) also aborts.
    print(f"Error loading style target grams: {e}")
    raise SystemExit(f"Error loading style target grams: {e}")
# --- Helper Functions ---
def image_loader(image: Image.Image, size=256, device=torch.device("cpu")):
    """Preprocess a PIL image into a normalized 4-D tensor for VGG19.

    Args:
        image: Input PIL image (any mode; converted to RGB first).
        size: Target side length; the image is resized on its shorter side
            and then center-cropped to a ``size`` x ``size`` square.
        device: Device the returned tensor is moved to.

    Returns:
        Float tensor of shape (1, 3, size, size), normalized with the
        ImageNet mean/std that VGG19 was trained with.
    """
    pipeline = T.Compose([
        T.Resize(size),
        T.CenterCrop(size),  # guarantee a square output
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
    ])
    rgb = image.convert('RGB')               # Gradio hands us a PIL image
    batched = pipeline(rgb).unsqueeze(0)     # add the batch dimension
    return batched.to(device, torch.float)
def im_convert(tensor):
    """Convert a normalized (1, C, H, W) tensor into an (H, W, C) NumPy
    image with values in [0, 1], suitable for display.

    The input is assumed to be normalized with the ImageNet mean/std used
    by ``image_loader``; this function reverses that normalization.
    """
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze(0)   # remove batch dimension
    image = image.transpose(1, 2, 0)   # C, H, W -> H, W, C
    # De-normalize with the ImageNet std/mean.
    # Fix: the former pre-clip to [-2.5, 2.5] was removed. It capped bright
    # values in some channels below 1.0 (blue: 2.5*0.225 + 0.406 ≈ 0.97)
    # before the final clip, dimming highlights; the [0, 1] clip below is
    # sufficient on its own.
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)  # keep values displayable
    return image
def gram_matrix(tensor):
    """Return the normalized Gram matrix of a single feature map.

    Expects a tensor of shape (1, C, H, W) (batch size one); the result is
    a (C, C) matrix of channel correlations divided by C*H*W.
    """
    _, channels, height, width = tensor.size()
    flattened = tensor.view(channels, height * width)        # (C, H*W)
    gram = torch.mm(flattened, flattened.t())                # (C, C)
    return gram / (channels * height * width)                # normalize
def get_features(image, model, layers):
    """Run `image` through `model`'s children, collecting activations.

    Args:
        image: Input tensor fed through the model.
        model: Module whose direct children are applied in order.
        layers: Mapping from stringified child index to a friendly name;
            only activations at these indices are kept.

    Returns:
        Dict mapping each friendly layer name to its activation tensor.
    """
    collected = {}
    activation = image
    for index, submodule in enumerate(model.children()):
        activation = submodule(activation)
        key = str(index)
        if key in layers:
            collected[layers[key]] = activation
    return collected
# --- Main Inference Function for Gradio ---
def stylize_image(content_image: Image.Image):
    """
    Perform style-transfer inference on a new content image.

    The generated image is initialized from the content image and its
    pixels are optimized so that the Gram matrices of its VGG19
    style-layer activations match the pre-computed targets in
    ``loaded_target_grams``.

    Args:
        content_image: A PIL Image object of the content image, or None.

    Returns:
        A NumPy array (H, W, 3) in [0, 1] with the stylized image
        (suitable for Gradio display), or None if no image was provided
        or an error occurs.
    """
    if content_image is None:
        # Gradio passes None when the input is cleared; bail out early
        # instead of falling into the exception handler below.
        return None
    print("Starting style transfer inference...")
    try:
        # 1. Load and preprocess the new content image.
        new_content_img = image_loader(content_image, size=imsize, device=device)
        # 2. Initialize the generated image as a clone of the content.
        # Fix: move to the device *before* enabling gradients — calling
        # .to() after requires_grad_() can hand the optimizer a non-leaf
        # copy whenever .to() actually transfers, which Adam rejects.
        generated_img = new_content_img.clone().to(device).requires_grad_(True)
        # 3. Optimize the pixels of the generated image directly.
        optimizer = optim.Adam([generated_img], lr=0.002)
        # 4. Optimization loop (style loss only at inference time).
        inference_steps = 100  # number of optimization steps per request
        for step in range(1, inference_steps + 1):
            generated_features = get_features(generated_img, model, layers=layers_for_inference)
            # Accumulate the weighted style loss across the style layers.
            current_style_loss = torch.tensor(0.0, device=device)
            for layer_name in style_layers_names:
                # Targets were saved once; ensure they sit on the active device.
                target_gram = loaded_target_grams[layer_name].to(device)
                input_gram = gram_matrix(generated_features[layer_name])
                layer_loss = nn.functional.mse_loss(input_gram, target_gram)
                current_style_loss = current_style_loss + style_weights[layer_name] * layer_loss
            # Total loss (only style loss in inference mode).
            total_loss = beta * current_style_loss
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            # Optional progress logging (kept quiet for HF Spaces logs):
            # if step % 100 == 0:
            #     print(f"Step {step}/{inference_steps}, Loss: {total_loss.item():.4f}")
        print("Inference finished.")
        # 5. Convert the final tensor to a displayable image format.
        return im_convert(generated_img)
    except Exception as e:
        # Boundary handler: log and return None so Gradio shows an empty
        # output instead of crashing the app.
        print(f"An error occurred during style transfer: {e}")
        return None
# --- Gradio Interface ---
# Content image upload (PIL in) and stylized result display (NumPy out).
image_input = gr.Image(type="pil", label="Upload Content Image")
image_output = gr.Image(type="numpy", label="Stylized Image")
# Wire the inference function into a simple one-in/one-out interface.
iface = gr.Interface(
    fn=stylize_image,
    inputs=image_input,
    outputs=image_output,
    title="Neural Style Transfer (Fixed Style)",
    description="Upload a content image to apply a pre-trained style.",
    # examples=["examples/my_content_example.jpg"],  # enable if bundled
    allow_flagging="never"  # no feedback collection
)
# Entry point for local runs; Hugging Face Spaces imports this module and
# serves the interface itself via `iface.launch()`.
if __name__ == "__main__":
    print("Gradio app starting...")
    iface.launch()