Spaces:
Sleeping
Sleeping
first commit
Browse files- .DS_Store +0 -0
- .gitattributes +2 -0
- .idea/.gitignore +3 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +7 -0
- .idea/modules.xml +8 -0
- .idea/style_transfer_app.iml +10 -0
- .idea/vcs.xml +6 -0
- __pycache__/model.cpython-313.pyc +0 -0
- __pycache__/utils.cpython-313.pyc +0 -0
- app.py +62 -0
- images/.DS_Store +0 -0
- images/input-image-1.jpg +3 -0
- images/input-image-2.jpeg +3 -0
- images/style-image-1.jpg +3 -0
- images/style-image-2.jpg +3 -0
- model.py +52 -0
- requirements.txt +4 -0
- utils.py +81 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
.idea/.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default ignored files
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<settings>
|
| 3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
| 4 |
+
<version value="1.0" />
|
| 5 |
+
</settings>
|
| 6 |
+
</component>
|
.idea/misc.xml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="Black">
|
| 4 |
+
<option name="sdkName" value="Python 3.13" />
|
| 5 |
+
</component>
|
| 6 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.13 (style_transfer_app)" project-jdk-type="Python SDK" />
|
| 7 |
+
</project>
|
.idea/modules.xml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectModuleManager">
|
| 4 |
+
<modules>
|
| 5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/style_transfer_app.iml" filepath="$PROJECT_DIR$/.idea/style_transfer_app.iml" />
|
| 6 |
+
</modules>
|
| 7 |
+
</component>
|
| 8 |
+
</project>
|
.idea/style_transfer_app.iml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<module type="PYTHON_MODULE" version="4">
|
| 3 |
+
<component name="NewModuleRootManager">
|
| 4 |
+
<content url="file://$MODULE_DIR$">
|
| 5 |
+
<excludeFolder url="file://$MODULE_DIR$/.venv" />
|
| 6 |
+
</content>
|
| 7 |
+
<orderEntry type="jdk" jdkName="Python 3.13 (style_transfer_app)" jdkType="Python SDK" />
|
| 8 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
| 9 |
+
</component>
|
| 10 |
+
</module>
|
.idea/vcs.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="VcsDirectoryMappings">
|
| 4 |
+
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
| 5 |
+
</component>
|
| 6 |
+
</project>
|
__pycache__/model.cpython-313.pyc
ADDED
|
Binary file (2.15 kB). View file
|
|
|
__pycache__/utils.cpython-313.pyc
ADDED
|
Binary file (3.33 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from model import generate_image
|
| 5 |
+
from utils import load_model, load_image, im_convert, device
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Load the frozen VGG19 feature extractor once at module import time
# so every request reuses the same model instance.
vgg = load_model(d=device)

# Upper bound (pixels) on the longest edge of uploaded images.
max_image_size = 400
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def generate(content, style, alpha_slider: float):
    """Run neural style transfer and return the stylized image.

    Args:
        content: PIL image whose content is preserved (from gr.Image(type="pil")).
        style: PIL image whose style is transferred onto ``content``.
        alpha_slider: Content weight (alpha) balancing content vs. style.

    Returns:
        A (H, W, 3) numpy array in [0, 1], displayable by gr.Image.
    """
    content_img = load_image(image=content, max_size=max_image_size).to(device)
    # The style image is resized to the content image's spatial dims so the
    # feature maps line up layer by layer.
    style_img = load_image(image=style, max_size=max_image_size, shape=content_img.shape[-2:]).to(device)

    # Initialize the target image as a clone of the original content image.
    # Move to the device FIRST, then mark it as a leaf requiring grad:
    # calling .to() after requires_grad_() can produce a non-leaf tensor
    # that the Adam optimizer inside generate_image cannot update.
    target_img = content_img.clone().to(device).requires_grad_(True)

    target = generate_image(model=vgg, content=content_img, style=style_img,
                            target=target_img, steps=2700, content_wt=alpha_slider)
    return im_convert(target)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def check_inputs(img1, img2):
    """Enable the submit button only if both images are uploaded."""
    both_uploaded = img1 is not None and img2 is not None
    # One update object handles both the enable and the disable case.
    return gr.update(interactive=both_uploaded)
|
| 29 |
+
|
| 30 |
+
with gr.Blocks() as demo:
    gr.Markdown("Transfer Image Style with the VGG19 model.")

    with gr.Row():
        content_image = gr.Image(type="pil", label="Original Image")
        style_image = gr.Image(type="pil", label="Style Reference Image")

    # Examples section.
    # NOTE: the second content example is a .jpeg in the repo
    # (images/input-image-2.jpeg); the previous ".jpg" path did not exist.
    gr.Examples(
        examples=[
            ["./images/input-image-1.jpg", "./images/style-image-1.jpg"],
            ["./images/input-image-2.jpeg", "./images/style-image-2.jpg"],
        ],
        inputs=[content_image, style_image],
    )

    alpha_slider = gr.Slider(0, 1, value=1, step=0.1, label="Blending Ratio")
    # gr.Button takes a single positional value; the stray second positional
    # argument ("Generate") was removed — depending on the Gradio version it
    # would either raise a TypeError or bind to an unrelated parameter.
    submit_button = gr.Button("Blend Images", variant="primary", interactive=False)

    output = gr.Image(label="Blended Image")

    submit_button.click(generate, inputs=[content_image, style_image, alpha_slider], outputs=output)

    # When images change, check if both are uploaded to enable the button.
    content_image.change(fn=check_inputs, inputs=[content_image, style_image], outputs=submit_button)
    style_image.change(fn=check_inputs, inputs=[content_image, style_image], outputs=submit_button)


# Launch the demo — guarded so importing this module does not start a server.
if __name__ == "__main__":
    demo.launch()
|
images/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
images/input-image-1.jpg
ADDED
|
Git LFS Details
|
images/input-image-2.jpeg
ADDED
|
Git LFS Details
|
images/style-image-1.jpg
ADDED
|
Git LFS Details
|
images/style-image-2.jpg
ADDED
|
Git LFS Details
|
model.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.optim as optim
|
| 3 |
+
from torch import nn
|
| 4 |
+
from utils import get_features, gram_matrix
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# weight early layers more heavily
# Per-layer style weights (Gatys et al. style transfer): earlier conv layers
# capture fine texture, later layers capture larger-scale style structure.
style_weights = {'conv1_1': 1.,
                 'conv2_1': 0.75,
                 'conv3_1': 0.2,
                 'conv4_1': 0.2,
                 'conv5_1': 0.2}

# the balance between style and content
content_weight = 1  # alpha
style_weight = 1e9  # beta
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def generate_image(model: nn.Module, content: torch.Tensor, style: torch.Tensor, target: torch.Tensor, steps = 2700, content_wt=content_weight):
    """Optimize ``target`` toward the content of ``content`` and the style of ``style``.

    Args:
        model: Frozen feature extractor (VGG19 ``.features``).
        content: (1, 3, H, W) content image tensor.
        style: (1, 3, H, W) style image tensor, same spatial size as ``content``.
        target: Leaf tensor with ``requires_grad=True``; optimized in place.
        steps: Number of Adam optimization steps.
        content_wt: Content-loss weight (alpha).

    Returns:
        The optimized ``target`` tensor.
    """
    content_features = get_features(content, model)
    style_features = get_features(style, model)

    # Pre-compute the Gram matrix of every style layer once, outside the loop.
    style_grams = {name: gram_matrix(feat) for name, feat in style_features.items()}

    optimizer = optim.Adam([target], lr=0.003)

    for _ in range(steps):
        target_features = get_features(target, model)

        # Content loss: MSE on the conv4_2 activations.
        content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)

        # Style loss: per-layer, weighted, size-normalized Gram-matrix MSE.
        style_loss = 0
        for name, weight in style_weights.items():
            feat = target_features[name]
            _, d, h, w = feat.shape
            layer_loss = weight * torch.mean((gram_matrix(feat) - style_grams[name]) ** 2)
            style_loss = style_loss + layer_loss / (d * h * w)

        total_loss = content_wt * content_loss + style_weight * style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    return target
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
transformers
|
| 3 |
+
torch
|
| 4 |
+
torchvision
|
utils.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.optim as optim
|
| 3 |
+
from torchvision import transforms, models
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
# Prefer CUDA, then Apple-Silicon MPS, then CPU.  torch.backends.mps.is_available()
# is the documented, portable check — torch.mps.is_available() only exists on
# newer PyTorch builds and raises AttributeError elsewhere.
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_model(d=device):
    """Return VGG19's convolutional feature stack, frozen, on device ``d``."""
    weights = models.VGG19_Weights.DEFAULT
    # Only the `.features` sub-network is used; the classifier head is dropped.
    vgg = models.vgg19(weights=weights).features

    # Freeze all weights — the network serves purely as a fixed feature
    # extractor; only the target image is ever optimized.
    # https://pytorch.org/docs/stable/generated/torch.Tensor.requires_grad_.html
    for p in vgg.parameters():
        p.requires_grad_(False)

    vgg.to(device=d)
    return vgg
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# max_size limits the image size to 400 pixel
|
| 24 |
+
# max_size limits the image size to 400 pixel
def load_image(image, max_size=400, shape=None):
    """Convert a PIL image into a normalized (1, 3, H, W) float tensor.

    Args:
        image: PIL image (e.g. from a gr.Image(type="pil") component).
        max_size: Upper bound on the longest edge, in pixels.
        shape: Optional (H, W) to force — used to match the content image.

    Returns:
        A (1, 3, H, W) tensor normalized with mean/std 0.5 per channel.
    """
    # Gradio can deliver RGBA or greyscale images; VGG expects exactly
    # 3 channels, so normalize the mode up front.
    image = image.convert('RGB')

    # If either the horizontal or vertical image size exceeds max_size,
    # cap the size at max_size.
    if max(image.size) > max_size:
        size = max_size
    else:
        size = max(image.size)

    if shape is not None:
        size = shape

    in_transform = transforms.Compose([
        transforms.Resize(size),  # Resize scales the smaller edge of the image to 'size'
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5),
                             (0.5, 0.5, 0.5))])

    # Add a batch dimension: (3, H, W) -> (1, 3, H, W).
    image = in_transform(image).unsqueeze(0)

    return image
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def im_convert(tensor):
    """Undo the 0.5/0.5 normalization and return a displayable (H, W, 3) array in [0, 1]."""
    img = tensor.to("cpu").clone().detach()
    # Drop the batch dimension and go channels-first -> channels-last.
    img = np.transpose(img.numpy().squeeze(), (1, 2, 0))
    # Invert Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)): x * std + mean.
    mean = np.array((0.5, 0.5, 0.5))
    img = img * mean + mean
    return np.clip(img, 0, 1)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_features(image, model):
    """Forward ``image`` through ``model``, collecting named intermediate activations."""
    # VGG19 feature-stack indices -> paper-style layer names.
    layers = {'0': 'conv1_1',   # Style Extraction
              '5': 'conv2_1',   # Style Extraction
              '10': 'conv3_1',  # Style Extraction
              '19': 'conv4_1',  # Style Extraction
              '21': 'conv4_2',  # Content Extraction
              '28': 'conv5_1'}  # Style Extraction

    features = {}
    x = image
    # Push the input through one layer at a time, capturing the output of
    # each layer of interest as we pass it.
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x

    return features
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Eliminate content feature and only maintain style features
|
| 77 |
+
# Eliminate content feature and only maintain style features
def gram_matrix(tensor):
    """Return the (d, d) Gram matrix of a (1, d, h, w) feature map."""
    _, depth, height, width = tensor.size()
    # Flatten each channel into a row vector: (d, h*w).
    flat = tensor.view(depth, height * width)
    # Channel-vs-channel correlations — spatial layout is discarded,
    # leaving only style statistics.
    return torch.mm(flat, flat.t())
|