maomao88 committed on
Commit
acd1f17
·
1 Parent(s): 0223cf5

first commit

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
37
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="Python 3.13" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.13 (style_transfer_app)" project-jdk-type="Python SDK" />
7
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/style_transfer_app.iml" filepath="$PROJECT_DIR$/.idea/style_transfer_app.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/style_transfer_app.iml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$">
5
+ <excludeFolder url="file://$MODULE_DIR$/.venv" />
6
+ </content>
7
+ <orderEntry type="jdk" jdkName="Python 3.13 (style_transfer_app)" jdkType="Python SDK" />
8
+ <orderEntry type="sourceFolder" forTests="false" />
9
+ </component>
10
+ </module>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="$PROJECT_DIR$" vcs="Git" />
5
+ </component>
6
+ </project>
__pycache__/model.cpython-313.pyc ADDED
Binary file (2.15 kB). View file
 
__pycache__/utils.cpython-313.pyc ADDED
Binary file (3.33 kB). View file
 
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr
import torch

from model import generate_image
from utils import load_model, load_image, im_convert, device

# Load the frozen VGG19 feature extractor once at import time so every
# request served by the Gradio app reuses the same model instance.
vgg = load_model(d=device)

# Upper bound (in pixels) applied when preprocessing uploaded images.
max_image_size = 400
13
def generate(content: torch.Tensor, style: torch.Tensor, alpha_slider: float):
    """Run neural style transfer and return a displayable image array.

    alpha_slider is forwarded as the content weight, controlling how
    strongly the result preserves the original image's content.
    """
    # Preprocess both inputs; the style image is resized to the content's H x W.
    src = load_image(image=content, max_size=max_image_size).to(device)
    ref = load_image(image=style, max_size=max_image_size, shape=src.shape[-2:]).to(device)

    # Optimization starts from a trainable clone of the content image.
    canvas = src.clone().requires_grad_(True).to(device)

    result = generate_image(
        model=vgg,
        content=src,
        style=ref,
        target=canvas,
        steps=2700,
        content_wt=alpha_slider,
    )
    return im_convert(result)
22
+
23
+
24
def check_inputs(img1, img2):
    """Toggle the submit button: interactive only when both images are uploaded."""
    both_present = img1 is not None and img2 is not None
    # gr.update patches the existing Button component instead of recreating it.
    return gr.update(interactive=both_present)
29
+
30
with gr.Blocks() as demo:
    gr.Markdown("Transfer Image Style with the VGG19 model.")

    with gr.Row():
        content_image = gr.Image(type="pil", label="Original Image")
        style_image = gr.Image(type="pil", label="Style Reference Image")

    # Clickable example pairs that populate both image inputs.
    gr.Examples(
        examples=[
            ["./images/input-image-1.jpg", "./images/style-image-1.jpg"],
            ["./images/input-image-2.jpg", "./images/style-image-2.jpg"]
        ],
        inputs=[content_image, style_image]
    )

    alpha_slider = gr.Slider(0, 1, value=1, step=0.1, label="Blending Ratio")
    # FIX: the original passed a stray second positional argument ("Generate")
    # to gr.Button; Button's parameters after `value` are keyword-only, so
    # that call raises a TypeError. The button starts disabled until both
    # images are uploaded (see check_inputs below).
    submit_button = gr.Button("Blend Images", variant="primary", interactive=False)

    output = gr.Image(label="Blended Image")

    submit_button.click(generate, inputs=[content_image, style_image, alpha_slider], outputs=output)

    # When images change, check if both are uploaded to enable the button.
    content_image.change(fn=check_inputs, inputs=[content_image, style_image], outputs=submit_button)
    style_image.change(fn=check_inputs, inputs=[content_image, style_image], outputs=submit_button)


# Launch the demo!
demo.launch()
images/.DS_Store ADDED
Binary file (6.15 kB). View file
 
images/input-image-1.jpg ADDED

Git LFS Details

  • SHA256: 228a4edc34917fc6f70985c6b672eff6801d2b7f25efb7676b1d1cae01c0f42a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.12 MB
images/input-image-2.jpeg ADDED

Git LFS Details

  • SHA256: 3518767cb2b9832a3698687a87a61cfd209970b6807449b78d653764a96c7626
  • Pointer size: 131 Bytes
  • Size of remote file: 161 kB
images/style-image-1.jpg ADDED

Git LFS Details

  • SHA256: 363189897b489a6377ace527ac4e8c6bf50c83e65686f2d89d1a8550ec290ac1
  • Pointer size: 131 Bytes
  • Size of remote file: 468 kB
images/style-image-2.jpg ADDED

Git LFS Details

  • SHA256: 63ce122df5b71f6c9353426f96856e6452e2ebb99edc74a856575405590ab57e
  • Pointer size: 131 Bytes
  • Size of remote file: 252 kB
model.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torch.optim as optim
from torch import nn
from utils import get_features, gram_matrix

# Per-layer style-loss weights; weight early layers more heavily so the
# transfer emphasizes coarse textures and colors over higher-level structure.
style_weights = {'conv1_1': 1.,
                 'conv2_1': 0.75,
                 'conv3_1': 0.2,
                 'conv4_1': 0.2,
                 'conv5_1': 0.2}

# The balance between style and content in the total loss.
content_weight = 1  # alpha: weight of the content-preservation term
style_weight = 1e9  # beta: weight of the style term
17
+
18
+
19
+
20
def generate_image(model: nn.Module, content: torch.Tensor, style: torch.Tensor, target: torch.Tensor, steps=2700, content_wt=content_weight):
    """Optimize *target* in place so it combines content's layout with style's textures.

    Runs Adam for *steps* iterations on the pixels of *target*, minimizing
    content_wt * content_loss + style_weight * style_loss, and returns the
    optimized tensor.
    """
    content_feats = get_features(content, model)
    style_feats = get_features(style, model)

    # Gram matrices of the style features are the fixed style targets.
    style_grams = {name: gram_matrix(style_feats[name]) for name in style_feats}

    optimizer = optim.Adam([target], lr=0.003)

    for _ in range(1, steps + 1):
        feats = get_features(target, model)

        # Content loss: MSE against the content representation at conv4_2.
        content_loss = torch.mean((feats['conv4_2'] - content_feats['conv4_2']) ** 2)

        # Style loss: weighted gram-matrix MSE per layer, normalized by the
        # feature-map volume so deep/shallow layers contribute comparably.
        style_loss = 0
        for name, weight in style_weights.items():
            feat = feats[name]
            layer_loss = weight * torch.mean((gram_matrix(feat) - style_grams[name]) ** 2)
            _, depth, height, width = feat.shape
            style_loss += layer_loss / (depth * height * width)

        total_loss = content_wt * content_loss + style_weight * style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    return target
50
+
51
+
52
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ transformers
3
+ torch
4
+ torchvision
utils.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torch.optim as optim
from torchvision import transforms, models
from PIL import Image
import numpy as np

# Pick the best available accelerator: CUDA first, then Apple-silicon MPS,
# falling back to CPU.
# FIX: the original called torch.mps.is_available(), which is not the
# documented availability check and is absent on older PyTorch releases
# (AttributeError); torch.backends.mps.is_available() is the supported API.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
8
+
9
+
10
+
11
def load_model(d=device):
    """Return VGG19's convolutional feature extractor, frozen and moved to *d*.

    Only the `.features` sub-network is kept; the classifier head is not
    needed for style transfer.
    """
    backbone = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features

    # Freeze every weight: only the target image is optimized, never the net.
    # https://pytorch.org/docs/stable/generated/torch.Tensor.requires_grad_.html
    for p in backbone.parameters():
        p.requires_grad_(False)

    # Module.to returns the (mutated) module itself.
    return backbone.to(device=d)
21
+
22
+
23
# Preprocess a PIL image into a normalized 4-D tensor for VGG.
def load_image(image, max_size=400, shape=None):
    """Resize, tensorize, and normalize *image*; returns shape (1, C, H, W).

    NOTE(review): transforms.Resize scales the *smaller* edge to `size`, so
    the longer edge can still exceed max_size — confirm this matches the
    intent behind the name. When *shape* is given it overrides the computed
    size (used to match the style image to the content image's H x W).
    """
    # Cap the size at max_size; an explicit shape wins when provided.
    size = shape if shape is not None else min(max(image.size), max_size)

    preprocess = transforms.Compose([
        transforms.Resize(size),  # scales the smaller edge of the image to `size`
        transforms.ToTensor(),
        # Map pixel values from [0, 1] to [-1, 1]; im_convert undoes this.
        transforms.Normalize((0.5, 0.5, 0.5),
                             (0.5, 0.5, 0.5))])

    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    return preprocess(image).unsqueeze(0)
45
+
46
+
47
+ def im_convert(tensor):
48
+ image = tensor.to("cpu").clone().detach()
49
+ image = image.numpy().squeeze()
50
+ image = image.transpose(1,2,0)
51
+ image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5))
52
+ image = image.clip(0, 1)
53
+
54
+ return image
55
+
56
+
57
def get_features(image, model):
    """Run *image* through *model* layer by layer, capturing selected activations.

    Returns a dict mapping readable layer names to feature maps; conv4_2 is
    used for content extraction, all others for style extraction.
    """
    # torchvision's VGG `features` module indexes its layers by stringified
    # position, so the keys below are Sequential indices.
    wanted = {'0': 'conv1_1',    # style
              '5': 'conv2_1',    # style
              '10': 'conv3_1',   # style
              '19': 'conv4_1',   # style
              '21': 'conv4_2',   # content
              '28': 'conv5_1'}   # style

    captured = {}
    out = image
    for idx, layer in model._modules.items():
        # Forward through this layer; its output feeds the next one.
        out = layer(out)
        if idx in wanted:
            captured[wanted[idx]] = out

    return captured
74
+
75
+
76
# Eliminate content feature and only maintain style features
def gram_matrix(tensor):
    """Return the Gram matrix (channel-channel correlations) of a feature map.

    NOTE(review): assumes a batch dimension of size 1 — `view(d, h * w)`
    raises for larger batches; confirm callers only pass single images.
    """
    _, d, h, w = tensor.size()  # d = channels (depth), h = height, w = width
    flat = tensor.view(d, h * w)  # one row per channel
    return torch.mm(flat, flat.t())