Kush26 committed on
Commit
a6aaf54
·
verified ·
1 Parent(s): 5a97d3b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +204 -0
app.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torchvision import models, transforms as T
5
+ from PIL import Image
6
+ import numpy as np
7
+ import gradio as gr
8
+ import os
9
+
10
# --- Configuration ---
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

imsize = 512  # Side length (pixels) passed to the image loader
beta = 1e5  # Style weight multiplier

# Per-layer weights applied to each style loss term.
style_weights = {
    'conv1_1': 1.0,
    'conv2_1': 0.75,
    'conv3_1': 0.2,
    'conv4_1': 0.2,
    'conv5_1': 0.2,
}
# Names of the layers that contribute to the style loss, in network order.
style_layers_names = ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']

# Layer name -> index (as a string) of the module inside the VGG19 feature stack.
layer_name_to_index = {
    'conv1_1': '0',
    'conv2_1': '5',
    'conv3_1': '10',
    'conv4_1': '19',
    'conv4_2': '21',
    'conv5_1': '28',
}

# Module indices of the style layers.
style_layers_indices = set(map(layer_name_to_index.get, style_layers_names))

# Module index -> layer name for every layer whose activations are extracted
# during inference (only the style layers are needed).
layers_for_inference = {
    layer_name_to_index[layer]: layer for layer in style_layers_names
}
30
+
31
+
32
# --- Load Model and Targets (Load once when app starts) ---
# Only the `features` stack of VGG19 is kept; it is moved to the selected
# device and put in eval mode. VGG19_Weights.DEFAULT selects the recommended
# pretrained weights.
model = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
model = model.to(device).eval()
# The network is a fixed feature extractor: no parameter needs gradients.
model.requires_grad_(False)

# Load the saved target Gram matrices. The app cannot work without them, so
# any failure terminates the process with an explanatory message.
# NOTE(review): torch.load unpickles the file; only ship a trusted
# style_target_grams.pt alongside the app.
try:
    loaded_target_grams = torch.load('style_target_grams.pt', map_location=device)
except FileNotFoundError:
    print("Error: style_target_grams.pt not found. Please ensure it's in the same directory.")
    raise SystemExit("Required file style_target_grams.pt not found.")
except Exception as err:
    print(f"Error loading style target grams: {err}")
    raise SystemExit(f"Error loading style target grams: {err}")
else:
    print("Style target grams loaded successfully.")
51
+
52
+
53
+ # --- Helper Functions ---
54
+
55
def image_loader(image: Image.Image, size=512, device=torch.device("cpu")):
    """Turn a PIL image into a normalized, batched float tensor.

    The image is converted to RGB, resized so its shorter side equals
    ``size``, center-cropped to ``size`` x ``size``, converted to a tensor,
    and normalized with the mean/std listed below.

    Args:
        image: PIL image to preprocess.
        size: Target side length in pixels.
        device: Device the resulting tensor is moved to.

    Returns:
        A float tensor with a leading batch dimension of 1, on ``device``.
    """
    preprocess = T.Compose([
        T.Resize(size),
        T.CenterCrop(size),  # Guarantees a square result
        T.ToTensor(),
        # VGG19 mean and std
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    rgb_image = image.convert('RGB')
    batched = preprocess(rgb_image).unsqueeze(0)  # Prepend the batch dimension
    return batched.to(device, torch.float)
71
+
72
def im_convert(tensor):
    """Convert a normalized image tensor to a NumPy image for display.

    Args:
        tensor: Tensor of shape (1, 3, H, W) that was normalized with the
            mean/std below. It may require grad and live on any device.

    Returns:
        A float NumPy array of shape (H, W, 3) with values in [0, 1].
    """
    mean = np.array((0.485, 0.456, 0.406))
    std = np.array((0.229, 0.224, 0.225))

    # Detach first so the device copy is not tracked by autograd.
    image = tensor.detach().to("cpu").numpy().squeeze(0)  # Remove batch dimension
    image = image.transpose(1, 2, 0)  # C, H, W -> H, W, C

    # De-normalize. The values are deliberately NOT clipped beforehand: a
    # symmetric clip such as [-2.5, 2.5] caps the third channel at
    # 2.5 * 0.225 + 0.406 = 0.9685, so full-intensity pixels could never be
    # reproduced. The final clip alone bounds the output.
    image = image * std + mean

    return image.clip(0, 1)
85
+
86
def gram_matrix(tensor):
    """Compute the normalized Gram matrix of one image's feature maps.

    Args:
        tensor: Feature tensor of shape (1, C, H, W).

    Returns:
        A (C, C) tensor equal to ``F @ F.T / (C * H * W)``, where ``F`` is the
        input flattened to shape (C, H * W).

    Raises:
        ValueError: If the batch dimension is not 1. The flattening below
            treats the input as a single sample.
    """
    b, c, h, w = tensor.size()
    if b != 1:
        raise ValueError(
            f"gram_matrix expects a batch of exactly one image, got batch size {b}"
        )
    # reshape (rather than view) also accepts non-contiguous inputs.
    features = tensor.reshape(c, h * w)
    gram = features.mm(features.t())
    return gram.div(c * h * w)  # Normalize by the number of elements
92
+
93
def get_features(image, model, layers):
    """Run ``image`` through ``model`` and collect selected activations.

    Args:
        image: Input passed to the first child module of ``model``.
        model: Module whose direct children are applied one after another.
        layers: Mapping from a child's position (as a string, e.g. ``'5'``)
            to the name under which its output is stored.

    Returns:
        A dict mapping each requested layer name to that child's output.
        Requested positions that do not exist in ``model`` are absent.
    """
    features = {}
    if not layers:
        return features

    remaining = len(layers)
    x = image
    for index, module in enumerate(model.children()):
        x = module(x)
        key = str(index)
        if key in layers:
            features[layers[key]] = x
            remaining -= 1
            if remaining == 0:
                # Every requested activation is collected; the deeper
                # modules would only cost compute.
                break
    return features
107
+
108
+
109
+ # --- Main Inference Function for Gradio ---
110
+
111
def stylize_image(content_image: Image.Image):
    """
    Performs style transfer inference on a new content image.

    The generated image starts as a copy of the preprocessed content image
    and its pixels are optimized with Adam so that the Gram matrices of its
    features match the preloaded style targets. Only the style loss is used.

    Args:
        content_image: A PIL Image object of the content image. ``None``
            (no image provided) is treated as an error.

    Returns:
        A NumPy array representing the stylized image (suitable for Gradio display).
        Returns None if an error occurs.
    """
    print("Starting style transfer inference...")

    # Fail with a clear message instead of an AttributeError further down.
    if content_image is None:
        print("An error occurred during style transfer: no content image was provided")
        return None

    try:
        # 1. Load and preprocess the new content image
        new_content_img = image_loader(content_image, size=imsize, device=device)

        # 2. Initialize the generated image (clone of content).
        # detach() before requires_grad_() guarantees a leaf tensor, which
        # the optimizer requires; the clone already lives on `device`.
        generated_img = new_content_img.clone().detach().requires_grad_(True)

        # 3. Setup optimizer for the generated image
        lr = 0.02
        optimizer = optim.Adam([generated_img], lr=lr)

        # Loop-invariant: place the target Gram matrices on the device once
        # rather than on every optimization step.
        target_grams = {
            layer_name: loaded_target_grams[layer_name].to(device)
            for layer_name in style_layers_names
        }

        # 4. Run optimization loop
        inference_steps = 500  # Number of optimization steps for inference

        for _ in range(inference_steps):
            generated_features = get_features(generated_img, model, layers=layers_for_inference)

            # Weighted sum of per-layer Gram-matrix MSE losses
            current_style_loss = torch.tensor(0.0, device=device)
            for layer_name in style_layers_names:
                input_gram = gram_matrix(generated_features[layer_name])
                layer_loss = nn.functional.mse_loss(input_gram, target_grams[layer_name])
                current_style_loss = current_style_loss + style_weights[layer_name] * layer_loss

            # Total loss (only style loss in inference mode)
            total_loss = beta * current_style_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        print("Inference finished.")

        # 5. Convert the final tensor to a displayable image format
        return im_convert(generated_img)

    except Exception as e:
        # Boundary handler for the Gradio callback: report the failure with
        # its traceback and signal it to the caller by returning None.
        import traceback

        print(f"An error occurred during style transfer: {e}")
        traceback.print_exc()
        return None
176
+
177
+
178
# --- Gradio Interface ---

# Input component: the uploaded content image, handed to the callback as a
# PIL image.
image_input = gr.Image(type="pil", label="Upload Content Image")

# Output component: displays the NumPy array returned by the callback.
image_output = gr.Image(type="numpy", label="Stylized Image")

# Wire the callback to its input and output components.
# Example images can be offered by passing `examples=[...]` with file paths.
iface = gr.Interface(
    stylize_image,
    image_input,
    image_output,
    title="Neural Style Transfer (Fixed Style)",
    description="Upload a content image to apply a pre-trained style.",
    allow_flagging="never",  # No feedback collection
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    print("Gradio app starting...")
    iface.launch()