"""Gradio app that restyles an uploaded image towards a fixed, pre-computed style.

At import time the script loads a frozen VGG19 feature extractor and a file of
target Gram matrices (``style_target_grams.pt``).  Each request then optimises
a copy of the uploaded image so that its Gram matrices at the configured style
layers move towards those targets.
"""
import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torchvision import models, transforms as T

# --- Configuration ---
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

imsize = 256  # Side length (pixels) of the square image fed to the network
beta = 1e5  # Style weight multiplier

# Style layers and the weight each one contributes to the style loss.
style_layers_names = ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']
style_weights = {
    'conv1_1': 1.0,
    'conv2_1': 0.75,
    'conv3_1': 0.2,
    'conv4_1': 0.2,
    'conv5_1': 0.2,
}

# Mapping from layer names to child indices (as strings) of `vgg19().features`.
layer_name_to_index = {
    'conv1_1': '0',
    'conv2_1': '5',
    'conv3_1': '10',
    'conv4_1': '19',
    'conv4_2': '21',
    'conv5_1': '28',
}

# Indices of the style layers.
style_layers_indices = {layer_name_to_index[name] for name in style_layers_names}

# index -> name for the layers captured during inference (style layers only).
layers_for_inference = {
    idx: name
    for name, idx in layer_name_to_index.items()
    if idx in style_layers_indices
}

# --- Load Model and Targets (loaded once when the app starts) ---
# Only the convolutional feature stack is needed; it is frozen and put in eval mode.
model = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.to(device).eval()
for param in model.parameters():
    param.requires_grad_(False)  # Freeze model parameters

# Load the saved target Gram matrices (indexed by style layer name below).
# NOTE(review): torch.load unpickles the file; only ship a trusted
# style_target_grams.pt, and consider passing weights_only=True if the
# deployed torch version supports it.
try:
    loaded_target_grams = torch.load('style_target_grams.pt', map_location=device)
    print("Style target grams loaded successfully.")
except FileNotFoundError:
    print("Error: style_target_grams.pt not found. Please ensure it's in the same directory.")
    # The app cannot work without the pre-computed targets, so stop here.
    raise SystemExit("Required file style_target_grams.pt not found.")
except Exception as e:
    print(f"Error loading style target grams: {e}")
    raise SystemExit(f"Error loading style target grams: {e}")


# --- Helper Functions ---
def image_loader(image: Image.Image, size=256, device=torch.device("cpu")):
    """Convert a PIL image into a normalised ``(1, 3, size, size)`` float tensor.

    Args:
        image: Source PIL image; it is converted to RGB first.
        size: Target side length; the image is resized and centre-cropped to a square.
        device: Device the resulting tensor is moved to.

    Returns:
        A float tensor with a leading batch dimension, normalised with the
        mean/std constants below.
    """
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    loader = T.Compose([
        T.Resize(size),
        T.CenterCrop(size),  # Ensure square shape
        T.ToTensor(),
        normalize,
    ])
    image = image.convert('RGB')  # Ensure exactly three channels
    image = loader(image).unsqueeze(0)  # Add batch dimension
    return image.to(device, torch.float)


def im_convert(tensor):
    """Convert a normalised ``(1, C, H, W)`` tensor to an ``(H, W, C)`` NumPy image.

    The normalisation applied by ``image_loader`` is undone and the result is
    clipped to ``[0, 1]``.

    Args:
        tensor: Image tensor with a leading batch dimension of size 1.

    Returns:
        A NumPy array of shape ``(H, W, C)`` with values in ``[0, 1]``.
    """
    image = tensor.detach().cpu().numpy().squeeze(0)  # Remove batch dimension
    image = image.transpose(1, 2, 0)  # C, H, W -> H, W, C
    # De-normalise.  No clipping happens before this step: clipping the
    # normalised values to +/-2.5 would cap the third channel at
    # 2.5 * 0.225 + 0.406 = 0.9685 and darken legitimately bright pixels.
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    return image.clip(0, 1)  # Clip values to be between 0 and 1


def gram_matrix(tensor):
    """Compute the normalised Gram matrix of a batch of feature maps.

    Args:
        tensor: Feature tensor of shape ``(b, c, h, w)``.

    Returns:
        A ``(b*c, b*c)`` tensor of feature inner products divided by
        ``b * c * h * w``.  For ``b == 1`` this is the usual ``(c, c)`` matrix.
    """
    b, c, h, w = tensor.size()
    # reshape (rather than view) also accepts non-contiguous input, and folding
    # the batch into the row dimension keeps batch sizes other than 1 from raising.
    features = tensor.reshape(b * c, h * w)
    gram = features.mm(features.t())
    return gram.div(b * c * h * w)  # Normalise by the number of elements


def get_features(image, model, layers):
    """Run ``image`` through ``model`` and collect the requested activations.

    Args:
        image: Input tensor fed to the first child module of ``model``.
        model: Module whose children are applied in order.
        layers: Mapping from child index (as a string, e.g. ``'0'``) to the
            name under which that child's output is stored.

    Returns:
        A dict mapping layer name to the captured activation tensor.
    """
    features = {}
    pending = set(layers)  # Indices whose output has not been captured yet
    x = image
    for index, module in enumerate(model.children()):
        if not pending:
            break  # Everything requested is captured; skip the deeper modules
        x = module(x)
        key = str(index)
        if key in layers:
            features[layers[key]] = x
            pending.discard(key)
    return features


# --- Main Inference Function for Gradio ---
def stylize_image(content_image: Image.Image):
    """Perform style transfer inference on a new content image.

    The image is optimised with Adam for a fixed number of steps against the
    weighted Gram-matrix (style) loss only.

    Args:
        content_image: A PIL Image object of the content image.

    Returns:
        A NumPy array representing the stylized image (suitable for Gradio
        display), or None if an error occurs (the error is printed).
    """
    print("Starting style transfer inference...")
    try:
        # The function may be called without an image (nothing uploaded).
        if content_image is None:
            raise ValueError("No content image provided.")

        # 1. Load and preprocess the new content image.
        new_content_img = image_loader(content_image, size=imsize, device=device)

        # 2. Start from a copy of the content image.  detach() makes it a leaf
        #    tensor, which the optimizer requires.
        generated_img = new_content_img.clone().detach().requires_grad_(True)

        # 3. Set up the optimizer for the generated image.
        lr = 0.002
        optimizer = optim.Adam([generated_img], lr=lr)

        # The targets do not change between steps, so resolve them once.
        target_grams = {
            layer_name: loaded_target_grams[layer_name].to(device)
            for layer_name in style_layers_names
        }

        # 4. Run the optimization loop.
        inference_steps = 100  # Number of optimization steps for inference
        for _ in range(inference_steps):
            generated_features = get_features(
                generated_img, model, layers=layers_for_inference
            )

            # Weighted sum of per-layer Gram-matrix MSE losses.
            current_style_loss = torch.tensor(0.0, device=device)
            for layer_name in style_layers_names:
                input_gram = gram_matrix(generated_features[layer_name])
                loss = nn.functional.mse_loss(input_gram, target_grams[layer_name])
                current_style_loss = current_style_loss + style_weights[layer_name] * loss

            # Total loss (only style loss in inference mode).
            total_loss = beta * current_style_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        print("Inference finished.")

        # 5. Convert the final tensor to a displayable image format.
        return im_convert(generated_img)

    except Exception as e:
        # Top-level boundary for the UI callback: report the problem and
        # return None instead of raising.
        print(f"An error occurred during style transfer: {e}")
        return None


# --- Gradio Interface ---
# Input: an image component for uploading the content image.
image_input = gr.Image(type="pil", label="Upload Content Image")

# Output: an image component to display the stylized result.
image_output = gr.Image(type="numpy", label="Stylized Image")

# Create the Gradio Interface.
iface = gr.Interface(
    fn=stylize_image,  # The function to run
    inputs=image_input,  # The input component
    outputs=image_output,  # The output component
    title="Neural Style Transfer (Fixed Style)",
    description="Upload a content image to apply a pre-trained style.",
    allow_flagging="never",  # Disable flagging unless you want to collect feedback
)

# Launch the app when the file is executed as a script.
if __name__ == "__main__":
    print("Gradio app starting...")
    iface.launch()