import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms as T
from PIL import Image
import numpy as np
import gradio as gr
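# (For a Hugging Face Space, torch, torchvision, numpy, Pillow and gradio
# must be listed in the Space's requirements.txt so they are installed at build time.)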
# --- Configuration ---
# Check for CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

imsize = 256
beta = 1e5  # Style weight multiplier

# Define the style layers and their weights
style_layers_names = ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']
style_weights = {'conv1_1': 1.0, 'conv2_1': 0.75, 'conv3_1': 0.2, 'conv4_1': 0.2, 'conv5_1': 0.2}

# Mapping layer names to VGG19 feature module indices
layer_name_to_index = {
    'conv1_1': '0', 'conv2_1': '5', 'conv3_1': '10', 'conv4_1': '19', 'conv4_2': '21', 'conv5_1': '28'
}

# Indices for the style layers
style_layers_indices = {layer_name_to_index[name] for name in style_layers_names}
# Layers whose features are extracted during inference (only style layers are needed)
layers_for_inference = {idx: name for name, idx in layer_name_to_index.items() if idx in style_layers_indices}
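# Note: these string indices are the positions of the conv layers inside
# torchvision's vgg19().features nn.Sequential (e.g. index 5 is the first conv
# after the first max-pool, i.e. conv2_1). conv4_2 is mapped but unused below;
# it is the usual content-loss layer and was presumably used when the target
# Gram matrices were generated.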
# --- Load Model and Targets (load once when the app starts) ---
# Load the VGG19 feature extractor.
# VGG19_Weights.DEFAULT selects the recommended pretrained weights.
model = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.to(device).eval()
for param in model.parameters():
    param.requires_grad_(False)  # Freeze model parameters

# Load the saved target Gram matrices
try:
    loaded_target_grams = torch.load('style_target_grams.pt', map_location=device)
    print("Style target grams loaded successfully.")
except FileNotFoundError:
    print("Error: style_target_grams.pt not found. Please ensure it's in the same directory.")
    # You might want to add logic here to train/generate the grams if missing,
    # but for a simple inference Space, ensure the file is pre-uploaded.
    raise SystemExit("Required file style_target_grams.pt not found.")
except Exception as e:
    print(f"Error loading style target grams: {e}")
    raise SystemExit(f"Error loading style target grams: {e}")
# --- Helper Functions ---
def image_loader(image: Image.Image, size=256, device=torch.device("cpu")):
    """Loads a PIL Image, resizes, converts to tensor, and normalizes."""
    # ImageNet mean and std, as used to train VGG19
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    loader = T.Compose([
        T.Resize(size),
        T.CenterCrop(size),  # Ensure square shape
        T.ToTensor(),
        normalize,
    ])
    # image is already a PIL Image from Gradio
    image = image.convert('RGB')  # Ensure RGB
    image = loader(image).unsqueeze(0)  # Add batch dimension
    return image.to(device, torch.float)
def im_convert(tensor):
    """Converts a PyTorch tensor to a NumPy image for display."""
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze(0)   # Remove batch dimension
    image = image.transpose(1, 2, 0)   # Transpose C, H, W -> H, W, C
    # De-normalize. First clip extreme values in *normalized* space:
    # after ImageNet normalization, valid pixels fall roughly within [-2.5, 2.5].
    image = np.clip(image, -2.5, 2.5)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)           # Clip values to be between 0 and 1
    return image
def gram_matrix(tensor):
    """Calculates the Gram matrix of a (1, C, H, W) batch of feature maps."""
    b, c, h, w = tensor.size()
    features = tensor.view(b * c, h * w)  # Flatten spatial dims: (C, H*W) for batch size 1
    gram = features.mm(features.t())      # Gram matrix: features @ features^T
    return gram.div(c * h * w)            # Normalize by the number of elements
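# Example: for conv1_1 features of shape (1, 64, 256, 256), gram_matrix returns
# a 64x64 matrix normalized by 64*256*256, matching how the saved target grams
# are assumed to have been computed.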
def get_features(image, model, layers):
    """Extracts activations from the specified layers of a sequential model."""
    features = {}
    x = image
    # model is vgg19().features, an nn.Sequential, so its children are the
    # conv/ReLU/pool modules in order; their positions match the string
    # indices in `layers`.
    for i, module in enumerate(model.children()):
        x = module(x)
        name = str(i)
        if name in layers:
            features[layers[name]] = x
    return features
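# With imsize = 256, get_features(..., layers_for_inference) returns a dict like
# {'conv1_1': tensor of shape (1, 64, 256, 256), 'conv2_1': (1, 128, 128, 128), ...}
# (shapes shown for illustration).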
# --- Main Inference Function for Gradio ---
def stylize_image(content_image: Image.Image):
    """
    Performs style transfer inference on a new content image.

    Args:
        content_image: A PIL Image object of the content image.

    Returns:
        A NumPy array representing the stylized image (suitable for Gradio display).
        Returns None if an error occurs.
    """
    print("Starting style transfer inference...")
    try:
        # 1. Load and preprocess the new content image
        new_content_img = image_loader(content_image, size=imsize, device=device)

        # 2. Initialize the generated image as a clone of the content image.
        #    Move to the device *before* requires_grad_ so the optimized
        #    tensor is a leaf tensor.
        generated_img = new_content_img.clone().to(device).requires_grad_(True)

        # 3. Set up an optimizer over the generated image's pixels
        lr = 0.002
        optimizer = optim.Adam([generated_img], lr=lr)
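        # Note: Adam updates the pixels of generated_img directly; the frozen
        # VGG19 only supplies features and gradients, its weights never change.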
        # 4. Run optimization loop
        inference_steps = 100  # Number of optimization steps for inference
        for step in range(1, inference_steps + 1):
            # Get features for the generated image
            generated_features = get_features(generated_img, model, layers=layers_for_inference)

            # Calculate style loss
            current_style_loss = torch.tensor(0.0, device=device)  # Initialize loss tensor
            for layer_name in style_layers_names:
                # Ensure the target Gram matrix is on the correct device
                target_gram = loaded_target_grams[layer_name].to(device)
                input_feature = generated_features[layer_name]
                input_gram = gram_matrix(input_feature)
                loss = nn.functional.mse_loss(input_gram, target_gram)
                current_style_loss = current_style_loss + style_weights[layer_name] * loss

            # Total loss (only style loss in inference mode)
            total_loss = beta * current_style_loss

            # Optimization step
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Optional: print progress (useful for debugging, but can clutter logs in HF Spaces)
            # if step % 100 == 0:
            #     print(f"Step {step}/{inference_steps}, Loss: {total_loss.item():.4f}")
| print("Inference finished.") | |
| # 5. Convert the final tensor to a displayable image format | |
| stylized_np_img = im_convert(generated_img) | |
| return stylized_np_img | |
| except Exception as e: | |
| print(f"An error occurred during style transfer: {e}") | |
| # Return a placeholder or error message if possible, or just let Gradio handle the None return | |
| return None | |
# --- Gradio Interface ---
# Define the interface inputs and outputs.
# Input: an image component for uploading the content image
image_input = gr.Image(type="pil", label="Upload Content Image")
# Output: an image component to display the stylized result
image_output = gr.Image(type="numpy", label="Stylized Image")

# Create the Gradio Interface
iface = gr.Interface(
    fn=stylize_image,        # The function to run
    inputs=image_input,      # The input component
    outputs=image_output,    # The output component
    title="Neural Style Transfer (Fixed Style)",
    description="Upload a content image to apply a pre-trained style.",
    # Add example images if you have them in an 'examples' directory
    # examples=["examples/my_content_example.jpg"],
    allow_flagging="never"   # Disable flagging unless you want to collect feedback
)
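# Optional local smoke test (hypothetical path, adjust to an image you have):
#   from PIL import Image
#   out = stylize_image(Image.open("examples/my_content_example.jpg"))
#   assert out is not None and out.shape == (256, 256, 3)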
# Launch the app
if __name__ == "__main__":
    # This block is for local testing. Hugging Face Spaces runs the app
    # directly via `iface.launch()`.
    print("Gradio app starting...")
    iface.launch()