Spaces:
Build error
Build error
Commit
·
4158574
1
Parent(s):
4ac8bc1
adding the gradio app code
Browse files- app.py +152 -0
- models/__init__.py +0 -0
- models/components/__init__.py +0 -0
- models/components/photo_wct.pth +3 -0
- models/models.py +297 -0
- requirements.txt +18 -0
- utils/__init__.py +12 -0
- utils/photo_smooth.py +101 -0
- utils/photo_wct.py +171 -0
- utils/shared_utils.py +136 -0
- utils/smooth_filter.py +405 -0
app.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import numpy as np
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import utils.shared_utils as st
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import autocast
|
| 9 |
+
import torchvision.transforms as T
|
| 10 |
+
from contextlib import nullcontext
|
| 11 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 12 |
+
context = autocast if device == "cuda" else nullcontext
|
| 13 |
+
# Apply the transformations needed
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def select_input(input_img, webcm_img):
    """Pick the foreground source: the uploaded image when present, else the webcam capture."""
    return webcm_img if input_img is None else input_img
|
| 24 |
+
|
| 25 |
+
def infer(prompt, samples):
    """Generate `samples` candidate background images from a text prompt.

    Calls the Stable Diffusion API once per requested image, under CUDA
    autocast when a GPU is available (plain nullcontext otherwise — see the
    module-level `context`/`device`).  Returns the list of generated images.
    """
    # Fix: the original also built a `selections` list of "Img_NN" labels
    # that was never used anywhere; dropped.
    images = []
    with context(device):
        for _ in range(samples):
            back_img = st.stableDiffusionAPICall(prompt)
            images.append(back_img)
    return images
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def change_bg_option(choice):
    """Reconfigure the background-input area based on the radio choice.

    NOTE(review): the first branch returns a brand-new gr.Image while the
    other branches return gr.update(...) — presumably a visibility/config
    update was intended in all three; confirm against the gradio version
    this Space runs (requirements.txt leaves gradio unpinned).
    Also NOTE(review): nothing in the UI below wires this handler to an
    event — it appears to be dead code as committed.
    """
    if choice == "I have an Image":
        return gr.Image(shape=(800, 800))

    elif choice == "Generate one for me":
        return gr.update(lines=8, visible=True, value="Please enter a text prompt")
    else:
        return gr.update(visible=False)
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# TEXT
# User-facing copy for the UI.  Fixes over the original: "bacground" typo,
# "Given than" -> "Given that", "transfer" -> "transfers", and the broken
# closing tags in "<b>Enjoy!<b><p>".
title = "FSDL- One-Shot, Green-Screen, Composition-Transfer"
DEFAULT_TEXT = "Photorealistic scenery of bookshelf in a room"
description = """
<center><a href="https://docs.google.com/document/d/1fde8XKIMT1nNU72859ytd2c58LFBxepS3od9KFBrJbM/edit?usp=sharing">[PAPER]</a> <a href="https://github.com/snknitin/FSDL-Project/blob/main/src/utils/shared_utils.py">[CODE]</a></center>
<details>
<summary><b>Instructions</b></summary>
<p style="margin-top: -3px;">With this app, you can generate a suitable background image to overlay your portrait!<br />You have several ways to set how your final auto-edited image will look:<br /></p>
<ul style="margin-top: -20px;margin-bottom: -15px;">
<li style="margin-bottom: -10px;margin-left: 20px;">Use the "<i>Inputs</i>" tab to either upload an image from your device or allow the use of your webcam to capture</li>
<li style="margin-left: 20px;">Use the "<i>Background Image Inputs</i>" to upload your own background</li>
<li style="margin-left: 20px;">Use the "<i>Text prompt</i>" tab to generate a satisfactory background image.</li>
</ul>
<p>After customization, just hit "<i>Edit</i>" and wait a few seconds.<br />The final image will be available for download <br /> <b>Enjoy!</b></p>
</details>
"""

running = """

### Instructions for running the 3 S's in sequence

* **Superimpose** - This button allows you to isolate the foreground from your image and overlay it on the background. Remove background using alpha matting
* **Style-Transfer** - This button transfers the style from your original image to re-map your new background realistically. Uses Nvidia FastPhotoStyle
* **Smoothing** - Given that image resolutions and clarity can be an issue, this smoothing button makes your final image crisp after the stylization transfer. Fair warning - this last process can take 5-10 mins
"""
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# ---------------------------------------------------------------------------
# Gradio UI.  Layout: (1) foreground input (upload or webcam), (2) background
# input (upload or Stable-Diffusion text prompt), (3) staging row holding the
# selected foreground/background, (4) the three-step pipeline
# superimpose -> composition transfer -> smoothing.
# NOTE(review): gr.Row(scale=...), the .style(...) chains and status_tracker
# are older-Gradio APIs -- confirm against the gradio version installed for
# this Space (requirements.txt leaves gradio unpinned).
# ---------------------------------------------------------------------------
demo = gr.Blocks()

with demo:
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
    with gr.Box():
        gr.Markdown(description)
    # First row - Inputs
    with gr.Row(scale=1):
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Upload "):
                    input_img = gr.Image(shape=(800, 800), interactive=True, label="You")
                with gr.TabItem("Webcam Capture"):
                    webcm_img = gr.Image(source="webcam", streaming=True, shape=(800, 800), interactive=True)
            inp_select_btn = gr.Button("Select")

        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Upload"):
                    bgm_img = gr.Image(shape=(800, 800), type="pil", interactive=True, label="The Background")
                    bgm_select_btn = gr.Button("Select")

                with gr.TabItem("Generate via Text Prompt"):
                    with gr.Box():
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            text = gr.Textbox(lines=7,
                                              placeholder="Enter your prompt to generate a background image... something like - Photorealistic scenery of bookshelf in a room")

                            samples = gr.Slider(label="Number of Images", minimum=1, maximum=5, value=2, step=1)
                            btn = gr.Button("Generate images",variant="primary").style(
                                margin=False,
                                rounded=(False, True, True, False),
                            )

                    gallery = gr.Gallery(label="Generated images", show_label=True).style(grid=(1, 3), height="auto")
                    # image_options = gr.Radio(label="Pick", interactive=True, choices=None, type="value")
                    # Both pressing Enter in the textbox and clicking the
                    # button trigger background generation.
                    text.submit(infer, inputs=[text, samples], outputs=gallery)
                    btn.click(infer, inputs=[text, samples], outputs=gallery, show_progress=True, status_tracker=None)

    # Second Row - Backgrounds
    with gr.Row(scale=1):
        with gr.Column():
            final_input_img = gr.Image(shape=(800, 800), type="pil", label="Foreground")

        with gr.Column():
            final_back_img = gr.Image(shape=(800, 800), type="pil", label="Background", interactive=True)

    # Copy the chosen background / foreground into the staging images above.
    bgm_select_btn.click(fn=lambda x: x, inputs=bgm_img, outputs=final_back_img)

    inp_select_btn.click(select_input, [input_img, webcm_img], final_input_img)

    with gr.Row(scale=1):
        with gr.Box():
            gr.Markdown(running)

    with gr.Row(scale=1):

        with gr.Column(scale=1):
            supimp_btn = gr.Button("SuperImpose")
            overlay_img = gr.Image(shape=(800, 800), label="Overlay", type="pil")


        with gr.Column(scale=1):
            style_btn = gr.Button("Composition-Transfer",variant="primary")
            style_img = gr.Image(shape=(800, 800),label="Style-Transfer Image",type="pil")

        with gr.Column(scale=1):
            submit_btn = gr.Button("Smoothen",variant="primary")
            output_img = gr.Image(shape=(800, 800),label="FinalSmoothened Image",type="pil")

    # Pipeline wiring: each stage consumes the previous stage's output image.
    supimp_btn.click(fn=st.superimpose, inputs=[final_input_img, final_back_img], outputs=[overlay_img])
    style_btn.click(fn=st.style_transfer, inputs=[overlay_img,final_input_img], outputs=[style_img])
    submit_btn.click(fn=st.smoother, inputs=[style_img,overlay_img], outputs=[output_img])

demo.queue()
demo.launch()
|
| 152 |
+
|
models/__init__.py
ADDED
|
File without changes
|
models/components/__init__.py
ADDED
|
File without changes
|
models/components/photo_wct.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bedc114a83833de79e92b7166b37bc522db71a30bbfa13d0c4f36387789c8af5
|
| 3 |
+
size 33410469
|
models/models.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
| 4 |
+
"""
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class VGGEncoder(nn.Module):
    """Truncated VGG-19 feature encoder used by PhotoWCT.

    ``level`` (1-4) selects how deep the network goes; layers beyond the
    requested level are simply never constructed (note the early ``return``
    statements in ``__init__``).  Attribute names must stay exactly as-is:
    they match the parameter keys of the pretrained photo_wct.pth
    checkpoint.  Max-pool layers return their indices so the paired
    VGGDecoder can unpool to the exact input geometry.
    """
    def __init__(self, level):
        super(VGGEncoder, self).__init__()
        self.level = level

        # 224 x 224
        # 1x1 conv applied before VGG proper (weights come from checkpoint).
        self.conv0 = nn.Conv2d(3, 3, 1, 1, 0)

        self.pad1_1 = nn.ReflectionPad2d((1, 1, 1, 1))
        # 226 x 226
        self.conv1_1 = nn.Conv2d(3, 64, 3, 1, 0)
        self.relu1_1 = nn.ReLU(inplace=True)
        # 224 x 224

        if level < 2: return

        self.pad1_2 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu1_2 = nn.ReLU(inplace=True)
        # 224 x 224
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        # 112 x 112

        self.pad2_1 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 0)
        self.relu2_1 = nn.ReLU(inplace=True)
        # 112 x 112

        if level < 3: return

        self.pad2_2 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu2_2 = nn.ReLU(inplace=True)
        # 112 x 112

        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        # 56 x 56

        self.pad3_1 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 0)
        self.relu3_1 = nn.ReLU(inplace=True)
        # 56 x 56

        if level < 4: return

        self.pad3_2 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu3_2 = nn.ReLU(inplace=True)
        # 56 x 56

        self.pad3_3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu3_3 = nn.ReLU(inplace=True)
        # 56 x 56

        self.pad3_4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3_4 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu3_4 = nn.ReLU(inplace=True)
        # 56 x 56

        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        # 28 x 28

        self.pad4_1 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 0)
        self.relu4_1 = nn.ReLU(inplace=True)
        # 28 x 28

    def forward(self, x):
        """Encode ``x`` up to ``self.level``.

        Returns the feature map alone at level 1; deeper levels also
        return every (pool_idx, pool.size()) pair produced so far, which
        VGGDecoder.forward consumes for unpooling.
        """
        out = self.conv0(x)

        out = self.pad1_1(out)
        out = self.conv1_1(out)
        out = self.relu1_1(out)

        if self.level < 2:
            return out

        out = self.pad1_2(out)
        out = self.conv1_2(out)
        pool1 = self.relu1_2(out)

        out, pool1_idx = self.maxpool1(pool1)

        out = self.pad2_1(out)
        out = self.conv2_1(out)
        out = self.relu2_1(out)

        if self.level < 3:
            return out, pool1_idx, pool1.size()

        out = self.pad2_2(out)
        out = self.conv2_2(out)
        pool2 = self.relu2_2(out)

        out, pool2_idx = self.maxpool2(pool2)

        out = self.pad3_1(out)
        out = self.conv3_1(out)
        out = self.relu3_1(out)

        if self.level < 4:
            return out, pool1_idx, pool1.size(), pool2_idx, pool2.size()

        out = self.pad3_2(out)
        out = self.conv3_2(out)
        out = self.relu3_2(out)

        out = self.pad3_3(out)
        out = self.conv3_3(out)
        out = self.relu3_3(out)

        out = self.pad3_4(out)
        out = self.conv3_4(out)
        pool3 = self.relu3_4(out)
        out, pool3_idx = self.maxpool3(pool3)

        out = self.pad4_1(out)
        out = self.conv4_1(out)
        out = self.relu4_1(out)

        return out, pool1_idx, pool1.size(), pool2_idx, pool2.size(), pool3_idx, pool3.size()

    def forward_multiple(self, x):
        """Encode ``x`` and also return the intermediate relu*_1 maps.

        Used once per style image so PhotoWCT can reuse features of every
        level from a single pass (returned deepest-first).
        """
        out = self.conv0(x)

        out = self.pad1_1(out)
        out = self.conv1_1(out)
        out = self.relu1_1(out)

        if self.level < 2: return out

        out1 = out

        out = self.pad1_2(out)
        out = self.conv1_2(out)
        pool1 = self.relu1_2(out)

        out, pool1_idx = self.maxpool1(pool1)

        out = self.pad2_1(out)
        out = self.conv2_1(out)
        out = self.relu2_1(out)

        if self.level < 3: return out, out1

        out2 = out

        out = self.pad2_2(out)
        out = self.conv2_2(out)
        pool2 = self.relu2_2(out)

        out, pool2_idx = self.maxpool2(pool2)

        out = self.pad3_1(out)
        out = self.conv3_1(out)
        out = self.relu3_1(out)

        if self.level < 4: return out, out2, out1

        out3 = out

        out = self.pad3_2(out)
        out = self.conv3_2(out)
        out = self.relu3_2(out)

        out = self.pad3_3(out)
        out = self.conv3_3(out)
        out = self.relu3_3(out)

        out = self.pad3_4(out)
        out = self.conv3_4(out)
        pool3 = self.relu3_4(out)
        out, pool3_idx = self.maxpool3(pool3)

        out = self.pad4_1(out)
        out = self.conv4_1(out)
        out = self.relu4_1(out)

        return out, out3, out2, out1
|
| 189 |
+
|
| 190 |
+
class VGGDecoder(nn.Module):
    """Mirror-image decoder for VGGEncoder.

    ``level`` selects where decoding starts: the guards are ``if level > k``,
    so a level-4 decoder builds every stage down to the output image while a
    level-1 decoder only builds the final conv.  Attribute names must match
    the parameter keys of the pretrained photo_wct.pth checkpoint.
    Unpooling consumes the indices/sizes captured by the encoder.
    """
    def __init__(self, level):
        super(VGGDecoder, self).__init__()
        self.level = level

        if level > 3:
            self.pad4_1 = nn.ReflectionPad2d((1, 1, 1, 1))
            self.conv4_1 = nn.Conv2d(512, 256, 3, 1, 0)
            self.relu4_1 = nn.ReLU(inplace=True)
            # 28 x 28

            self.unpool3 = nn.MaxUnpool2d(kernel_size=2, stride=2)
            # 56 x 56

            self.pad3_4 = nn.ReflectionPad2d((1, 1, 1, 1))
            self.conv3_4 = nn.Conv2d(256, 256, 3, 1, 0)
            self.relu3_4 = nn.ReLU(inplace=True)
            # 56 x 56

            self.pad3_3 = nn.ReflectionPad2d((1, 1, 1, 1))
            self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 0)
            self.relu3_3 = nn.ReLU(inplace=True)
            # 56 x 56

            self.pad3_2 = nn.ReflectionPad2d((1, 1, 1, 1))
            self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 0)
            self.relu3_2 = nn.ReLU(inplace=True)
            # 56 x 56

        if level > 2:
            self.pad3_1 = nn.ReflectionPad2d((1, 1, 1, 1))
            self.conv3_1 = nn.Conv2d(256, 128, 3, 1, 0)
            self.relu3_1 = nn.ReLU(inplace=True)
            # 56 x 56

            self.unpool2 = nn.MaxUnpool2d(kernel_size=2, stride=2)
            # 112 x 112

            self.pad2_2 = nn.ReflectionPad2d((1, 1, 1, 1))
            self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 0)
            self.relu2_2 = nn.ReLU(inplace=True)
            # 112 x 112

        if level > 1:
            self.pad2_1 = nn.ReflectionPad2d((1, 1, 1, 1))
            self.conv2_1 = nn.Conv2d(128, 64, 3, 1, 0)
            self.relu2_1 = nn.ReLU(inplace=True)
            # 112 x 112

            self.unpool1 = nn.MaxUnpool2d(kernel_size=2, stride=2)
            # 224 x 224

            self.pad1_2 = nn.ReflectionPad2d((1, 1, 1, 1))
            self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 0)
            self.relu1_2 = nn.ReLU(inplace=True)
            # 224 x 224

        if level > 0:
            self.pad1_1 = nn.ReflectionPad2d((1, 1, 1, 1))
            # Final projection back to 3 image channels (no ReLU).
            self.conv1_1 = nn.Conv2d(64, 3, 3, 1, 0)

    def forward(self, x, pool1_idx=None, pool1_size=None, pool2_idx=None, pool2_size=None, pool3_idx=None,
                pool3_size=None):
        """Decode features ``x`` back toward image space.

        The pool*_idx / pool*_size arguments are the values returned by the
        matching VGGEncoder.forward call; only the ones for stages at or
        below ``self.level`` are used.
        """
        out = x

        if self.level > 3:
            out = self.pad4_1(out)
            out = self.conv4_1(out)
            out = self.relu4_1(out)
            out = self.unpool3(out, pool3_idx, output_size=pool3_size)

            out = self.pad3_4(out)
            out = self.conv3_4(out)
            out = self.relu3_4(out)

            out = self.pad3_3(out)
            out = self.conv3_3(out)
            out = self.relu3_3(out)

            out = self.pad3_2(out)
            out = self.conv3_2(out)
            out = self.relu3_2(out)

        if self.level > 2:
            out = self.pad3_1(out)
            out = self.conv3_1(out)
            out = self.relu3_1(out)
            out = self.unpool2(out, pool2_idx, output_size=pool2_size)

            out = self.pad2_2(out)
            out = self.conv2_2(out)
            out = self.relu2_2(out)

        if self.level > 1:
            out = self.pad2_1(out)
            out = self.conv2_1(out)
            out = self.relu2_1(out)
            out = self.unpool1(out, pool1_idx, output_size=pool1_size)

            out = self.pad1_2(out)
            out = self.conv1_2(out)
            out = self.relu1_2(out)

        if self.level > 0:
            out = self.pad1_1(out)
            out = self.conv1_1(out)

        return out
|
requirements.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu116
|
| 2 |
+
torch
|
| 3 |
+
diffusers
|
| 4 |
+
transformers
|
| 5 |
+
scipy
|
| 6 |
+
ftfy
|
| 7 |
+
gradio
|
| 8 |
+
torchvision
|
| 9 |
+
scikit-image
|
| 10 |
+
rembg
|
| 11 |
+
replicate
|
| 12 |
+
requests
|
| 13 |
+
Pillow
|
| 14 |
+
numpy
|
| 15 |
+
scipy
|
| 16 |
+
pyrootutils
|
| 17 |
+
pynvrtc
|
| 18 |
+
cupy
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NOTE(review): this package __init__ originally imported helpers from
# `src.utils.*`, but no `src/` package is shipped with this Space (the code
# lives under `utils/`), so `import utils` — and therefore `import
# utils.shared_utils` in app.py — crashed at startup.  This is the likely
# cause of the "Build error".  Import best-effort so the package stays
# usable without the optional training-time helpers.
try:
    from src.utils.pylogger import get_pylogger
    from src.utils.rich_utils import enforce_tags, print_config_tree
    from src.utils.utils import (
        close_loggers,
        extras,
        get_metric_value,
        instantiate_callbacks,
        instantiate_loggers,
        log_hyperparameters,
        save_file,
        task_wrapper,
    )
except ImportError:  # optional helpers absent from this deployment
    pass
|
utils/photo_smooth.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import division
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import scipy.misc
|
| 8 |
+
import scipy._lib
|
| 9 |
+
import numpy as np
|
| 10 |
+
import scipy.sparse
|
| 11 |
+
import scipy.sparse.linalg as linalg
|
| 12 |
+
from numpy.lib.stride_tricks import as_strided
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Propagator(nn.Module):
    """Photorealistic smoothing post-process (NVIDIA FastPhotoStyle).

    Propagates the colors of the stylized image across regions of the
    original content image using a closed-form-matting affinity, removing
    local stylization artifacts.
    """
    def __init__(self, beta=0.9999):
        super(Propagator, self).__init__()
        # beta balances smoothness along the content affinity (beta)
        # against fidelity to the stylized input (1 - beta).
        self.beta = beta

    def process(self, initImg, contentImg):
        """Smooth `initImg` (stylized) guided by `contentImg` (original).

        Both arguments may be file paths or image-like arrays / PIL images.
        Returns a PIL.Image.
        """
        # Fix: scipy.misc.imread and the scipy.asarray alias were removed
        # from modern SciPy (requirements.txt pins no version), so load
        # with PIL / NumPy instead.
        if isinstance(contentImg, str):
            content = np.asarray(Image.open(contentImg).convert('RGB'))
        else:
            content = np.array(contentImg)

        if isinstance(initImg, str):
            B = np.asarray(Image.open(initImg).convert('RGB')).astype(np.float64) / 255
        else:
            B = np.asarray(initImg).astype(np.float64) / 255

        h1, w1, k = B.shape
        h = h1 - 4
        w = w1 - 4
        # Crop a centered (h, w) window, then re-pad by 2 on each side so
        # B and the resized content share the same (h1, w1) grid.
        B = B[int((h1 - h) / 2):int((h1 - h) / 2 + h), int((w1 - w) / 2):int((w1 - w) / 2 + w), :]
        # Fix: PIL's resize takes (width, height); the removed
        # scipy.misc.imresize took (height, width), so pass (w, h) here —
        # the original port passed (h, w) and transposed non-square images.
        content = np.asarray(Image.fromarray(np.array(content)).resize((w, h), Image.BICUBIC))
        B = self.__replication_padding(B, 2)
        content = self.__replication_padding(content, 2)
        content = content.astype(np.float64) / 255
        B = np.reshape(B, (h1 * w1, k))
        # Normalized affinity S = D^-1/2 W D^-1/2, then solve
        # (I - beta * S) V = (1 - beta) * B per color channel.
        W = self.__compute_laplacian(content)
        W = W.tocsc()
        dd = W.sum(0)
        dd = np.sqrt(np.power(dd, -1))
        dd = dd.A.squeeze()
        D = scipy.sparse.csc_matrix((dd, (np.arange(0, w1 * h1), np.arange(0, w1 * h1))))
        S = D.dot(W).dot(D)
        A = scipy.sparse.identity(w1 * h1) - self.beta * S
        A = A.tocsc()
        solver = linalg.factorized(A)
        V = np.zeros((h1 * w1, k))
        V[:, 0] = solver(B[:, 0])
        V[:, 1] = solver(B[:, 1])
        V[:, 2] = solver(B[:, 2])
        V = V * (1 - self.beta)
        V = V.reshape(h1, w1, k)
        # Drop the 2-pixel replication border again.
        V = V[2:2 + h, 2:2 + w, :]

        img = Image.fromarray(np.uint8(np.clip(V * 255., 0, 255.)))
        return img

    # Returns sparse matting affinity ("laplacian") matrix.
    # The implementation is heavily borrowed from
    # https://github.com/MarcoForte/closed-form-matting/blob/master/closed_form_matting.py
    # We thank Marco Forte for sharing his code.
    def __compute_laplacian(self, img, eps=10**(-7), win_rad=1):
        win_size = (win_rad*2+1)**2
        h, w, d = img.shape
        c_h, c_w = h - 2*win_rad, w - 2*win_rad
        win_diam = win_rad*2+1
        indsM = np.arange(h*w).reshape((h, w))
        ravelImg = img.reshape(h*w, d)
        win_inds = self.__rolling_block(indsM, block=(win_diam, win_diam))
        win_inds = win_inds.reshape(c_h, c_w, win_size)
        winI = ravelImg[win_inds]
        win_mu = np.mean(winI, axis=2, keepdims=True)
        win_var = np.einsum('...ji,...jk ->...ik', winI, winI)/win_size - np.einsum('...ji,...jk ->...ik', win_mu, win_mu)
        inv = np.linalg.inv(win_var + (eps/win_size)*np.eye(3))
        X = np.einsum('...ij,...jk->...ik', winI - win_mu, inv)
        vals = (1/win_size)*(1 + np.einsum('...ij,...kj->...ik', X, winI - win_mu))
        nz_indsCol = np.tile(win_inds, win_size).ravel()
        nz_indsRow = np.repeat(win_inds, win_size).ravel()
        nz_indsVal = vals.ravel()
        L = scipy.sparse.coo_matrix((nz_indsVal, (nz_indsRow, nz_indsCol)), shape=(h*w, h*w))
        return L

    def __replication_padding(self, arr, pad):
        """Edge-replicate pad `arr` (H, W, C) by `pad` pixels on each side."""
        h, w, c = arr.shape
        ans = np.zeros((h + pad*2, w + pad*2, c))
        for i in range(c):
            ans[:, :, i] = np.pad(arr[:, :, i], pad_width=(pad, pad), mode='edge')
        return ans

    def __rolling_block(self, A, block=(3, 3)):
        """Zero-copy sliding-window view of 2-D `A` with the given block shape."""
        shape = (A.shape[0] - block[0] + 1, A.shape[1] - block[1] + 1) + block
        strides = (A.strides[0], A.strides[1]) + A.strides
        return as_strided(A, shape=shape, strides=strides)
|
utils/photo_wct.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from models.models import VGGEncoder, VGGDecoder
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class PhotoWCT(nn.Module):
|
| 14 |
+
    def __init__(self):
        """Build encoder/decoder pairs for VGG levels 1-4.

        Attribute names (e1..e4, d1..d4) must match the keys of the
        pretrained photo_wct.pth checkpoint, so do not rename them.
        """
        super(PhotoWCT, self).__init__()
        self.e1 = VGGEncoder(1)
        self.d1 = VGGDecoder(1)
        self.e2 = VGGEncoder(2)
        self.d2 = VGGDecoder(2)
        self.e3 = VGGEncoder(3)
        self.d3 = VGGDecoder(3)
        self.e4 = VGGEncoder(4)
        self.d4 = VGGDecoder(4)
+
|
| 25 |
+
    def transform(self, cont_img, styl_img, cont_seg, styl_seg):
        """Stylize ``cont_img`` with the photo statistics of ``styl_img``.

        Runs four coarse-to-fine WCT passes: encode both images at VGG
        level k, whiten/color the content features toward the style
        features (optionally restricted to matching segmentation labels
        from ``cont_seg``/``styl_seg``), decode back to image space, and
        feed the result to the next (finer) level.  Returns the level-1
        reconstruction.
        """
        self.__compute_label_info(cont_seg, styl_seg)

        # Style features for all four levels from a single encoder pass.
        sF4, sF3, sF2, sF1 = self.e4.forward_multiple(styl_img)

        # Level 4: deepest features plus the pooling indices/sizes the
        # decoder needs for unpooling.
        cF4, cpool_idx, cpool1, cpool_idx2, cpool2, cpool_idx3, cpool3 = self.e4(cont_img)
        sF4 = sF4.data.squeeze(0)
        cF4 = cF4.data.squeeze(0)
        csF4 = self.__feature_wct(cF4, sF4, cont_seg, styl_seg)
        Im4 = self.d4(csF4, cpool_idx, cpool1, cpool_idx2, cpool2, cpool_idx3, cpool3)

        # Level 3 re-encodes the level-4 output, and so on down to level 1.
        cF3, cpool_idx, cpool1, cpool_idx2, cpool2 = self.e3(Im4)
        sF3 = sF3.data.squeeze(0)
        cF3 = cF3.data.squeeze(0)
        csF3 = self.__feature_wct(cF3, sF3, cont_seg, styl_seg)
        Im3 = self.d3(csF3, cpool_idx, cpool1, cpool_idx2, cpool2)

        cF2, cpool_idx, cpool = self.e2(Im3)
        sF2 = sF2.data.squeeze(0)
        cF2 = cF2.data.squeeze(0)
        csF2 = self.__feature_wct(cF2, sF2, cont_seg, styl_seg)
        Im2 = self.d2(csF2, cpool_idx, cpool)

        cF1 = self.e1(Im2)
        sF1 = sF1.data.squeeze(0)
        cF1 = cF1.data.squeeze(0)
        csF1 = self.__feature_wct(cF1, sF1, cont_seg, styl_seg)
        Im1 = self.d1(csF1)
        return Im1
|
| 56 |
+
def __compute_label_info(self, cont_seg, styl_seg):
|
| 57 |
+
if cont_seg.size == False or styl_seg.size == False:
|
| 58 |
+
return
|
| 59 |
+
max_label = np.max(cont_seg) + 1
|
| 60 |
+
self.label_set = np.unique(cont_seg)
|
| 61 |
+
self.label_indicator = np.zeros(max_label)
|
| 62 |
+
for l in self.label_set:
|
| 63 |
+
# if l==0:
|
| 64 |
+
# continue
|
| 65 |
+
is_valid = lambda a, b: a > 10 and b > 10 and a / b < 100 and b / a < 100
|
| 66 |
+
o_cont_mask = np.where(cont_seg.reshape(cont_seg.shape[0] * cont_seg.shape[1]) == l)
|
| 67 |
+
o_styl_mask = np.where(styl_seg.reshape(styl_seg.shape[0] * styl_seg.shape[1]) == l)
|
| 68 |
+
self.label_indicator[l] = is_valid(o_cont_mask[0].size, o_styl_mask[0].size)
|
| 69 |
+
|
| 70 |
+
def __feature_wct(self, cont_feat, styl_feat, cont_seg, styl_seg):
    """Apply the whiten-and-color transform globally or per segmentation label.

    :param cont_feat: content features, tensor of shape (C, H, W)
    :param styl_feat: style features, tensor of shape (C, H', W')
    :param cont_seg: content segmentation map (numpy array, possibly empty)
    :param styl_seg: style segmentation map (numpy array, possibly empty)
    :return: transformed features as a float tensor of shape (1, C, H, W)
    """
    cont_c, cont_h, cont_w = cont_feat.size(0), cont_feat.size(1), cont_feat.size(2)
    styl_c, styl_h, styl_w = styl_feat.size(0), styl_feat.size(1), styl_feat.size(2)
    # Flatten spatial dimensions: features become (C, H*W) matrices.
    cont_feat_view = cont_feat.view(cont_c, -1).clone()
    styl_feat_view = styl_feat.view(styl_c, -1).clone()

    # `.size == False` is numerically the same as `.size == 0`: an empty
    # segmentation map means no region guidance, so run one global WCT.
    if cont_seg.size == False or styl_seg.size == False:
        target_feature = self.__wct_core(cont_feat_view, styl_feat_view)
    else:
        target_feature = cont_feat.view(cont_c, -1).clone()
        # Resize the label maps to the feature-map resolution; NEAREST keeps
        # the label values discrete (no interpolated pseudo-labels).
        if len(cont_seg.shape) == 2:
            t_cont_seg = np.asarray(Image.fromarray(cont_seg).resize((cont_w, cont_h), Image.NEAREST))
        else:
            t_cont_seg = np.asarray(Image.fromarray(cont_seg, mode='RGB').resize((cont_w, cont_h), Image.NEAREST))
        if len(styl_seg.shape) == 2:
            t_styl_seg = np.asarray(Image.fromarray(styl_seg).resize((styl_w, styl_h), Image.NEAREST))
        else:
            t_styl_seg = np.asarray(Image.fromarray(styl_seg, mode='RGB').resize((styl_w, styl_h), Image.NEAREST))

        for l in self.label_set:
            # Labels flagged unusable by __compute_label_info are skipped.
            if self.label_indicator[l] == 0:
                continue
            cont_mask = np.where(t_cont_seg.reshape(t_cont_seg.shape[0] * t_cont_seg.shape[1]) == l)
            styl_mask = np.where(t_styl_seg.reshape(t_styl_seg.shape[0] * t_styl_seg.shape[1]) == l)
            # The resized maps may have lost all pixels of this label.
            if cont_mask[0].size <= 0 or styl_mask[0].size <= 0:
                continue

            cont_indi = torch.LongTensor(cont_mask[0])
            styl_indi = torch.LongTensor(styl_mask[0])
            if self.is_cuda:
                cont_indi = cont_indi.cuda(0)
                styl_indi = styl_indi.cuda(0)

            # Gather only the feature columns (pixels) belonging to this label
            # and run WCT on that sub-matrix.
            cFFG = torch.index_select(cont_feat_view, 1, cont_indi)
            sFFG = torch.index_select(styl_feat_view, 1, styl_indi)
            tmp_target_feature = self.__wct_core(cFFG, sFFG)
            # NOTE(review): lexicographic version compare ("0.10" < "0.4.0")
            # is fragile — confirm intended torch versions.
            if torch.__version__ >= "0.4.0":
                # This seems to be a bug in PyTorch 0.4.0 to me.
                # Workaround: index_copy_ along dim 0 of the transposed matrix
                # instead of dim 1 of the original.
                new_target_feature = torch.transpose(target_feature, 1, 0)
                new_target_feature.index_copy_(0, cont_indi, \
                        torch.transpose(tmp_target_feature, 1, 0))
                target_feature = torch.transpose(new_target_feature, 1, 0)
            else:
                target_feature.index_copy_(1, cont_indi, tmp_target_feature)

    # Restore the (C, H, W) shape and add the batch dimension back.
    target_feature = target_feature.view_as(cont_feat)
    ccsF = target_feature.float().unsqueeze(0)
    return ccsF
|
| 121 |
+
|
| 122 |
+
def __wct_core(self, cont_feat, styl_feat):
    """Whiten the content features, then re-color them with style statistics.

    Both arguments are 2-D tensors of shape (channels, pixels). The return
    value has the shape of ``cont_feat`` and carries the style covariance
    plus the style mean.
    """
    def effective_rank(eigvals, dim):
        # Index (plus one) of the last eigenvalue above the numerical floor,
        # scanning from the smallest eigenvalue upward.
        rank = dim
        for idx in range(dim - 1, -1, -1):
            if eigvals[idx] >= 0.00001:
                rank = idx + 1
                break
        return rank

    n_channels, n_pixels = cont_feat.size(0), cont_feat.size(1)
    centered_content = cont_feat - torch.mean(cont_feat, 1).unsqueeze(1).expand_as(cont_feat)

    # Regularize the content covariance so the inverse square root exists.
    regularizer = torch.eye(n_channels)
    if self.is_cuda:
        regularizer = regularizer.cuda()

    content_cov = torch.mm(centered_content, centered_content.t()).div(n_pixels - 1) + regularizer
    _, c_eig, c_basis = torch.svd(content_cov, some=False)
    k_c = effective_rank(c_eig, n_channels)

    s_channels, s_pixels = styl_feat.size(0), styl_feat.size(1)
    style_mean = torch.mean(styl_feat, 1)
    centered_style = styl_feat - style_mean.unsqueeze(1).expand_as(styl_feat)
    style_cov = torch.mm(centered_style, centered_style.t()).div(s_pixels - 1)
    _, s_eig, s_basis = torch.svd(style_cov, some=False)
    k_s = effective_rank(s_eig, s_channels)

    # Whitening: project onto the content eigenbasis, scale by lambda^-1/2,
    # and project back before applying to the centered content.
    inv_sqrt = (c_eig[0:k_c]).pow(-0.5)
    step1 = torch.mm(c_basis[:, 0:k_c], torch.diag(inv_sqrt))
    step2 = torch.mm(step1, c_basis[:, 0:k_c].t())
    whitened = torch.mm(step2, centered_content)

    # Coloring: the same construction with the style eigenbasis and lambda^+1/2.
    sqrt = (s_eig[0:k_s]).pow(0.5)
    colored = torch.mm(torch.mm(torch.mm(s_basis[:, 0:k_s], torch.diag(sqrt)), s_basis[:, 0:k_s].t()), whitened)
    return colored + style_mean.unsqueeze(1).expand_as(colored)
|
| 165 |
+
|
| 166 |
+
@property
def is_cuda(self):
    """True when the module's parameters currently live on a CUDA device."""
    first_param = next(self.parameters())
    return first_param.is_cuda
|
| 169 |
+
|
| 170 |
+
def forward(self, *input):
    """No-op stub: stylization is driven through dedicated entry points,
    not the standard ``nn.Module`` forward call."""
    return None
|
utils/shared_utils.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared helpers: background generation, matting overlay, and style transfer."""
import io
import os
from pathlib import Path

import numpy as np
import pyrootutils
import replicate
import requests
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.utils as utils
from PIL import Image
from rembg import remove
from torch import autocast

from utils.photo_smooth import Propagator
from utils.photo_wct import PhotoWCT

# Project root and compute device are resolved once at import time.
root = pyrootutils.setup_root(Path.cwd(), pythonpath=True)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained stylization (WCT) and smoothing models eagerly so the
# first request does not pay the deserialization cost.
p_wct = PhotoWCT()
p_wct.load_state_dict(torch.load(root / "models/components/photo_wct.pth"))
p_pro = Propagator()
stylization_module = p_wct
smoothing_module = p_pro
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Dependencies - to be installed:
#   pip install replicate
# Authentication: supply the Replicate API token via an environment variable:
#   export REPLICATE_API_TOKEN=<your-token>
# SECURITY: a live API token was previously committed in this comment block;
# treat it as leaked — revoke/rotate it and never hard-code tokens in source.
|
| 36 |
+
import replicate
|
| 37 |
+
import os
|
| 38 |
+
import requests
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def stableDiffusionAPICall(text_prompt):
    """Generate a background image from a text prompt via Replicate's Stable Diffusion.

    :param text_prompt: natural-language description of the desired scene,
                        e.g. 'photorealistic, elf fighting Sauron'
    :return: PIL.Image holding the first generated result
    """
    # SECURITY: a real API token was hard-coded here and unconditionally
    # clobbered the environment. Prefer REPLICATE_API_TOKEN from the
    # environment; the fallback below only applies when it is unset and
    # must be revoked/removed before any public release.
    os.environ.setdefault('REPLICATE_API_TOKEN', 'a9f4c06cb9808f42b29637bb60b7b88f106ad5b8')
    model = replicate.models.get("stability-ai/stable-diffusion")
    # predict() returns a list of result URLs; take the first image.
    gen_bg_img = model.predict(prompt=text_prompt)[0]
    img_data = requests.get(gen_bg_img).content
    # Decode the downloaded bytes into a PIL image without touching disk.
    img = Image.open(io.BytesIO(img_data))
    return img
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def memory_limit_image_resize(cont_img):
    """Clamp a PIL image into a memory-safe size range, mutating it in place.

    :param cont_img: PIL.Image, resized in place via ``thumbnail``
    :return: (width, height) of the image after any resizing
    """
    # prevent too small or too big images
    MINSIZE=400
    MAXSIZE=800
    orig_width = cont_img.width
    orig_height = cont_img.height
    if max(cont_img.width,cont_img.height) < MINSIZE:
        # NOTE(review): Image.thumbnail only ever SHRINKS an image, so this
        # "upscale to MINSIZE" branch is a no-op for small inputs — confirm
        # whether resize() was intended here.
        if cont_img.width > cont_img.height:
            cont_img.thumbnail((int(cont_img.width*1.0/cont_img.height*MINSIZE), MINSIZE), Image.BICUBIC)
        else:
            cont_img.thumbnail((MINSIZE, int(cont_img.height*1.0/cont_img.width*MINSIZE)), Image.BICUBIC)
    if min(cont_img.width,cont_img.height) > MAXSIZE:
        # NOTE(review): guarding on min(...) means a very elongated image
        # (e.g. 3000x600) is never downscaled; max(...) may have been intended.
        if cont_img.width > cont_img.height:
            cont_img.thumbnail((MAXSIZE, int(cont_img.height*1.0/cont_img.width*MAXSIZE)), Image.BICUBIC)
        else:
            cont_img.thumbnail(((int(cont_img.width*1.0/cont_img.height*MAXSIZE), MAXSIZE)), Image.BICUBIC)
    print("Resize image: (%d,%d)->(%d,%d)" % (orig_width, orig_height, cont_img.width, cont_img.height))
    return cont_img.width, cont_img.height
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def superimpose(input_img, back_img):
    """Cut the foreground subject out of `input_img` (rembg matting) and
    paste it onto `back_img`, which is modified in place and returned."""
    cutout = remove(input_img)
    # The cutout's alpha channel doubles as the paste mask.
    back_img.paste(cutout, (0, 0), cutout)
    return back_img
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def style_transfer(cont_img, styl_img):
    """Photorealistic style transfer of `styl_img`'s colors onto `cont_img`.

    Both arguments are PIL images; they are resized in place to a safe size,
    stylized with the global (unsegmented) WCT model, and the result is
    upsampled back to the content image's original size.

    :param cont_img: content PIL image (mutated by memory_limit_image_resize)
    :param styl_img: style PIL image (mutated by memory_limit_image_resize)
    :return: stylized PIL image at the content image's resized dimensions
    """
    with torch.no_grad():
        new_cw, new_ch = memory_limit_image_resize(cont_img)
        # Resize the style image in place too; its dimensions are not needed.
        memory_limit_image_resize(styl_img)
        cont_pilimg = cont_img.copy()
        cw = cont_pilimg.width
        ch = cont_pilimg.height
        cont_img = transforms.ToTensor()(cont_img).unsqueeze(0)
        styl_img = transforms.ToTensor()(styl_img).unsqueeze(0)

        # Empty segmentation maps select the global WCT path.
        cont_seg = []
        styl_seg = []

        if device == 'cuda':
            cont_img = cont_img.to(device)
            styl_img = styl_img.to(device)
            stylization_module.to(device)
        cont_seg = np.asarray(cont_seg)
        styl_seg = np.asarray(styl_seg)

        stylized_img = stylization_module.transform(cont_img, styl_img, cont_seg, styl_seg)
        if ch != new_ch or cw != new_cw:
            # interpolate() supersedes the deprecated nn.functional.upsample.
            stylized_img = nn.functional.interpolate(stylized_img, size=(ch, cw), mode='bilinear')
        # Convert the (1, C, H, W) tensor back to a uint8 PIL image.
        grid = utils.make_grid(stylized_img.data, nrow=1, padding=0)
        ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
        stylized_img = Image.fromarray(ndarr)
        #final_img = smooth_filter(stylized_img, cont_pilimg, f_radius=15, f_edge=1e-1)
        return stylized_img
|
| 115 |
+
|
| 116 |
+
def smoother(stylized_img, over_img):
    """Run the photorealistic smoothing pass on a stylized image, guided by
    `over_img`, and return the smoothed result."""
    return smoothing_module.process(stylized_img, over_img)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
if __name__ == "__main__":
|
| 122 |
+
root = pyrootutils.setup_root(__file__, pythonpath=True)
|
| 123 |
+
fg_path = root/"notebooks/profile_new.png"
|
| 124 |
+
bg_path = root/"notebooks/back_img.png"
|
| 125 |
+
ckpt_path = root/"src/models/MODNet/pretrained/modnet_photographic_portrait_matting.ckpt"
|
| 126 |
+
|
| 127 |
+
#stableDiffusionAPICall("Photorealistic scenery of a concert")
|
| 128 |
+
fg_img = Image.open(fg_path).resize((800,800))
|
| 129 |
+
bg_img = Image.open(bg_path).resize((800,800))
|
| 130 |
+
#img = combined_display(fg_img, bg_img,ckpt_path)
|
| 131 |
+
img = superimpose(fg_img,bg_img)
|
| 132 |
+
img.save(root/"notebooks/overlay.png")
|
| 133 |
+
# bg_img.paste(img, (0, 0), img)
|
| 134 |
+
# bg_img.save(root/"notebooks/check.png")
|
| 135 |
+
|
| 136 |
+
|
utils/smooth_filter.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
| 4 |
+
"""
|
| 5 |
+
src = '''
|
| 6 |
+
#include "/usr/local/cuda/include/math_functions.h"
|
| 7 |
+
#define TB 256
|
| 8 |
+
#define EPS 1e-7
|
| 9 |
+
|
| 10 |
+
__device__ bool InverseMat4x4(double m_in[4][4], double inv_out[4][4]) {
|
| 11 |
+
double m[16], inv[16];
|
| 12 |
+
for (int i = 0; i < 4; i++) {
|
| 13 |
+
for (int j = 0; j < 4; j++) {
|
| 14 |
+
m[i * 4 + j] = m_in[i][j];
|
| 15 |
+
}
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
inv[0] = m[5] * m[10] * m[15] -
|
| 19 |
+
m[5] * m[11] * m[14] -
|
| 20 |
+
m[9] * m[6] * m[15] +
|
| 21 |
+
m[9] * m[7] * m[14] +
|
| 22 |
+
m[13] * m[6] * m[11] -
|
| 23 |
+
m[13] * m[7] * m[10];
|
| 24 |
+
|
| 25 |
+
inv[4] = -m[4] * m[10] * m[15] +
|
| 26 |
+
m[4] * m[11] * m[14] +
|
| 27 |
+
m[8] * m[6] * m[15] -
|
| 28 |
+
m[8] * m[7] * m[14] -
|
| 29 |
+
m[12] * m[6] * m[11] +
|
| 30 |
+
m[12] * m[7] * m[10];
|
| 31 |
+
|
| 32 |
+
inv[8] = m[4] * m[9] * m[15] -
|
| 33 |
+
m[4] * m[11] * m[13] -
|
| 34 |
+
m[8] * m[5] * m[15] +
|
| 35 |
+
m[8] * m[7] * m[13] +
|
| 36 |
+
m[12] * m[5] * m[11] -
|
| 37 |
+
m[12] * m[7] * m[9];
|
| 38 |
+
|
| 39 |
+
inv[12] = -m[4] * m[9] * m[14] +
|
| 40 |
+
m[4] * m[10] * m[13] +
|
| 41 |
+
m[8] * m[5] * m[14] -
|
| 42 |
+
m[8] * m[6] * m[13] -
|
| 43 |
+
m[12] * m[5] * m[10] +
|
| 44 |
+
m[12] * m[6] * m[9];
|
| 45 |
+
|
| 46 |
+
inv[1] = -m[1] * m[10] * m[15] +
|
| 47 |
+
m[1] * m[11] * m[14] +
|
| 48 |
+
m[9] * m[2] * m[15] -
|
| 49 |
+
m[9] * m[3] * m[14] -
|
| 50 |
+
m[13] * m[2] * m[11] +
|
| 51 |
+
m[13] * m[3] * m[10];
|
| 52 |
+
|
| 53 |
+
inv[5] = m[0] * m[10] * m[15] -
|
| 54 |
+
m[0] * m[11] * m[14] -
|
| 55 |
+
m[8] * m[2] * m[15] +
|
| 56 |
+
m[8] * m[3] * m[14] +
|
| 57 |
+
m[12] * m[2] * m[11] -
|
| 58 |
+
m[12] * m[3] * m[10];
|
| 59 |
+
|
| 60 |
+
inv[9] = -m[0] * m[9] * m[15] +
|
| 61 |
+
m[0] * m[11] * m[13] +
|
| 62 |
+
m[8] * m[1] * m[15] -
|
| 63 |
+
m[8] * m[3] * m[13] -
|
| 64 |
+
m[12] * m[1] * m[11] +
|
| 65 |
+
m[12] * m[3] * m[9];
|
| 66 |
+
|
| 67 |
+
inv[13] = m[0] * m[9] * m[14] -
|
| 68 |
+
m[0] * m[10] * m[13] -
|
| 69 |
+
m[8] * m[1] * m[14] +
|
| 70 |
+
m[8] * m[2] * m[13] +
|
| 71 |
+
m[12] * m[1] * m[10] -
|
| 72 |
+
m[12] * m[2] * m[9];
|
| 73 |
+
|
| 74 |
+
inv[2] = m[1] * m[6] * m[15] -
|
| 75 |
+
m[1] * m[7] * m[14] -
|
| 76 |
+
m[5] * m[2] * m[15] +
|
| 77 |
+
m[5] * m[3] * m[14] +
|
| 78 |
+
m[13] * m[2] * m[7] -
|
| 79 |
+
m[13] * m[3] * m[6];
|
| 80 |
+
|
| 81 |
+
inv[6] = -m[0] * m[6] * m[15] +
|
| 82 |
+
m[0] * m[7] * m[14] +
|
| 83 |
+
m[4] * m[2] * m[15] -
|
| 84 |
+
m[4] * m[3] * m[14] -
|
| 85 |
+
m[12] * m[2] * m[7] +
|
| 86 |
+
m[12] * m[3] * m[6];
|
| 87 |
+
|
| 88 |
+
inv[10] = m[0] * m[5] * m[15] -
|
| 89 |
+
m[0] * m[7] * m[13] -
|
| 90 |
+
m[4] * m[1] * m[15] +
|
| 91 |
+
m[4] * m[3] * m[13] +
|
| 92 |
+
m[12] * m[1] * m[7] -
|
| 93 |
+
m[12] * m[3] * m[5];
|
| 94 |
+
|
| 95 |
+
inv[14] = -m[0] * m[5] * m[14] +
|
| 96 |
+
m[0] * m[6] * m[13] +
|
| 97 |
+
m[4] * m[1] * m[14] -
|
| 98 |
+
m[4] * m[2] * m[13] -
|
| 99 |
+
m[12] * m[1] * m[6] +
|
| 100 |
+
m[12] * m[2] * m[5];
|
| 101 |
+
|
| 102 |
+
inv[3] = -m[1] * m[6] * m[11] +
|
| 103 |
+
m[1] * m[7] * m[10] +
|
| 104 |
+
m[5] * m[2] * m[11] -
|
| 105 |
+
m[5] * m[3] * m[10] -
|
| 106 |
+
m[9] * m[2] * m[7] +
|
| 107 |
+
m[9] * m[3] * m[6];
|
| 108 |
+
|
| 109 |
+
inv[7] = m[0] * m[6] * m[11] -
|
| 110 |
+
m[0] * m[7] * m[10] -
|
| 111 |
+
m[4] * m[2] * m[11] +
|
| 112 |
+
m[4] * m[3] * m[10] +
|
| 113 |
+
m[8] * m[2] * m[7] -
|
| 114 |
+
m[8] * m[3] * m[6];
|
| 115 |
+
|
| 116 |
+
inv[11] = -m[0] * m[5] * m[11] +
|
| 117 |
+
m[0] * m[7] * m[9] +
|
| 118 |
+
m[4] * m[1] * m[11] -
|
| 119 |
+
m[4] * m[3] * m[9] -
|
| 120 |
+
m[8] * m[1] * m[7] +
|
| 121 |
+
m[8] * m[3] * m[5];
|
| 122 |
+
|
| 123 |
+
inv[15] = m[0] * m[5] * m[10] -
|
| 124 |
+
m[0] * m[6] * m[9] -
|
| 125 |
+
m[4] * m[1] * m[10] +
|
| 126 |
+
m[4] * m[2] * m[9] +
|
| 127 |
+
m[8] * m[1] * m[6] -
|
| 128 |
+
m[8] * m[2] * m[5];
|
| 129 |
+
|
| 130 |
+
double det = m[0] * inv[0] + m[1] * inv[4] + m[2] * inv[8] + m[3] * inv[12];
|
| 131 |
+
|
| 132 |
+
if (abs(det) < 1e-9) {
|
| 133 |
+
return false;
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
det = 1.0 / det;
|
| 138 |
+
|
| 139 |
+
for (int i = 0; i < 4; i++) {
|
| 140 |
+
for (int j = 0; j < 4; j++) {
|
| 141 |
+
inv_out[i][j] = inv[i * 4 + j] * det;
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
return true;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
extern "C"
|
| 149 |
+
__global__ void best_local_affine_kernel(
|
| 150 |
+
float *output, float *input, float *affine_model,
|
| 151 |
+
int h, int w, float epsilon, int kernel_radius
|
| 152 |
+
)
|
| 153 |
+
{
|
| 154 |
+
int size = h * w;
|
| 155 |
+
int id = blockIdx.x * blockDim.x + threadIdx.x;
|
| 156 |
+
|
| 157 |
+
if (id < size) {
|
| 158 |
+
int x = id % w, y = id / w;
|
| 159 |
+
|
| 160 |
+
double Mt_M[4][4] = {}; // 4x4
|
| 161 |
+
double invMt_M[4][4] = {};
|
| 162 |
+
double Mt_S[3][4] = {}; // RGB -> 1x4
|
| 163 |
+
double A[3][4] = {};
|
| 164 |
+
for (int i = 0; i < 4; i++)
|
| 165 |
+
for (int j = 0; j < 4; j++) {
|
| 166 |
+
Mt_M[i][j] = 0, invMt_M[i][j] = 0;
|
| 167 |
+
if (i != 3) {
|
| 168 |
+
Mt_S[i][j] = 0, A[i][j] = 0;
|
| 169 |
+
if (i == j)
|
| 170 |
+
Mt_M[i][j] = 1e-3;
|
| 171 |
+
}
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
for (int dy = -kernel_radius; dy <= kernel_radius; dy++) {
|
| 175 |
+
for (int dx = -kernel_radius; dx <= kernel_radius; dx++) {
|
| 176 |
+
|
| 177 |
+
int xx = x + dx, yy = y + dy;
|
| 178 |
+
int id2 = yy * w + xx;
|
| 179 |
+
|
| 180 |
+
if (0 <= xx && xx < w && 0 <= yy && yy < h) {
|
| 181 |
+
|
| 182 |
+
Mt_M[0][0] += input[id2 + 2*size] * input[id2 + 2*size];
|
| 183 |
+
Mt_M[0][1] += input[id2 + 2*size] * input[id2 + size];
|
| 184 |
+
Mt_M[0][2] += input[id2 + 2*size] * input[id2];
|
| 185 |
+
Mt_M[0][3] += input[id2 + 2*size];
|
| 186 |
+
|
| 187 |
+
Mt_M[1][0] += input[id2 + size] * input[id2 + 2*size];
|
| 188 |
+
Mt_M[1][1] += input[id2 + size] * input[id2 + size];
|
| 189 |
+
Mt_M[1][2] += input[id2 + size] * input[id2];
|
| 190 |
+
Mt_M[1][3] += input[id2 + size];
|
| 191 |
+
|
| 192 |
+
Mt_M[2][0] += input[id2] * input[id2 + 2*size];
|
| 193 |
+
Mt_M[2][1] += input[id2] * input[id2 + size];
|
| 194 |
+
Mt_M[2][2] += input[id2] * input[id2];
|
| 195 |
+
Mt_M[2][3] += input[id2];
|
| 196 |
+
|
| 197 |
+
Mt_M[3][0] += input[id2 + 2*size];
|
| 198 |
+
Mt_M[3][1] += input[id2 + size];
|
| 199 |
+
Mt_M[3][2] += input[id2];
|
| 200 |
+
Mt_M[3][3] += 1;
|
| 201 |
+
|
| 202 |
+
Mt_S[0][0] += input[id2 + 2*size] * output[id2 + 2*size];
|
| 203 |
+
Mt_S[0][1] += input[id2 + size] * output[id2 + 2*size];
|
| 204 |
+
Mt_S[0][2] += input[id2] * output[id2 + 2*size];
|
| 205 |
+
Mt_S[0][3] += output[id2 + 2*size];
|
| 206 |
+
|
| 207 |
+
Mt_S[1][0] += input[id2 + 2*size] * output[id2 + size];
|
| 208 |
+
Mt_S[1][1] += input[id2 + size] * output[id2 + size];
|
| 209 |
+
Mt_S[1][2] += input[id2] * output[id2 + size];
|
| 210 |
+
Mt_S[1][3] += output[id2 + size];
|
| 211 |
+
|
| 212 |
+
Mt_S[2][0] += input[id2 + 2*size] * output[id2];
|
| 213 |
+
Mt_S[2][1] += input[id2 + size] * output[id2];
|
| 214 |
+
Mt_S[2][2] += input[id2] * output[id2];
|
| 215 |
+
Mt_S[2][3] += output[id2];
|
| 216 |
+
}
|
| 217 |
+
}
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
bool success = InverseMat4x4(Mt_M, invMt_M);
|
| 221 |
+
|
| 222 |
+
for (int i = 0; i < 3; i++) {
|
| 223 |
+
for (int j = 0; j < 4; j++) {
|
| 224 |
+
for (int k = 0; k < 4; k++) {
|
| 225 |
+
A[i][j] += invMt_M[j][k] * Mt_S[i][k];
|
| 226 |
+
}
|
| 227 |
+
}
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
for (int i = 0; i < 3; i++) {
|
| 231 |
+
for (int j = 0; j < 4; j++) {
|
| 232 |
+
int affine_id = i * 4 + j;
|
| 233 |
+
affine_model[12 * id + affine_id] = A[i][j];
|
| 234 |
+
}
|
| 235 |
+
}
|
| 236 |
+
}
|
| 237 |
+
return ;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
extern "C"
|
| 241 |
+
__global__ void bilateral_smooth_kernel(
|
| 242 |
+
float *affine_model, float *filtered_affine_model, float *guide,
|
| 243 |
+
int h, int w, int kernel_radius, float sigma1, float sigma2
|
| 244 |
+
)
|
| 245 |
+
{
|
| 246 |
+
int id = blockIdx.x * blockDim.x + threadIdx.x;
|
| 247 |
+
int size = h * w;
|
| 248 |
+
if (id < size) {
|
| 249 |
+
int x = id % w;
|
| 250 |
+
int y = id / w;
|
| 251 |
+
|
| 252 |
+
double sum_affine[12] = {};
|
| 253 |
+
double sum_weight = 0;
|
| 254 |
+
for (int dx = -kernel_radius; dx <= kernel_radius; dx++) {
|
| 255 |
+
for (int dy = -kernel_radius; dy <= kernel_radius; dy++) {
|
| 256 |
+
int yy = y + dy, xx = x + dx;
|
| 257 |
+
int id2 = yy * w + xx;
|
| 258 |
+
if (0 <= xx && xx < w && 0 <= yy && yy < h) {
|
| 259 |
+
float color_diff1 = guide[yy*w + xx] - guide[y*w + x];
|
| 260 |
+
float color_diff2 = guide[yy*w + xx + size] - guide[y*w + x + size];
|
| 261 |
+
float color_diff3 = guide[yy*w + xx + 2*size] - guide[y*w + x + 2*size];
|
| 262 |
+
float color_diff_sqr =
|
| 263 |
+
(color_diff1*color_diff1 + color_diff2*color_diff2 + color_diff3*color_diff3) / 3;
|
| 264 |
+
|
| 265 |
+
float v1 = exp(-(dx * dx + dy * dy) / (2 * sigma1 * sigma1));
|
| 266 |
+
float v2 = exp(-(color_diff_sqr) / (2 * sigma2 * sigma2));
|
| 267 |
+
float weight = v1 * v2;
|
| 268 |
+
|
| 269 |
+
for (int i = 0; i < 3; i++) {
|
| 270 |
+
for (int j = 0; j < 4; j++) {
|
| 271 |
+
int affine_id = i * 4 + j;
|
| 272 |
+
sum_affine[affine_id] += weight * affine_model[id2*12 + affine_id];
|
| 273 |
+
}
|
| 274 |
+
}
|
| 275 |
+
sum_weight += weight;
|
| 276 |
+
}
|
| 277 |
+
}
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
for (int i = 0; i < 3; i++) {
|
| 281 |
+
for (int j = 0; j < 4; j++) {
|
| 282 |
+
int affine_id = i * 4 + j;
|
| 283 |
+
filtered_affine_model[id*12 + affine_id] = sum_affine[affine_id] / sum_weight;
|
| 284 |
+
}
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
return ;
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
extern "C"
|
| 292 |
+
__global__ void reconstruction_best_kernel(
|
| 293 |
+
float *input, float *filtered_affine_model, float *filtered_best_output,
|
| 294 |
+
int h, int w
|
| 295 |
+
)
|
| 296 |
+
{
|
| 297 |
+
int id = blockIdx.x * blockDim.x + threadIdx.x;
|
| 298 |
+
int size = h * w;
|
| 299 |
+
if (id < size) {
|
| 300 |
+
double out1 =
|
| 301 |
+
input[id + 2*size] * filtered_affine_model[id*12 + 0] + // A[0][0] +
|
| 302 |
+
input[id + size] * filtered_affine_model[id*12 + 1] + // A[0][1] +
|
| 303 |
+
input[id] * filtered_affine_model[id*12 + 2] + // A[0][2] +
|
| 304 |
+
filtered_affine_model[id*12 + 3]; //A[0][3];
|
| 305 |
+
double out2 =
|
| 306 |
+
input[id + 2*size] * filtered_affine_model[id*12 + 4] + //A[1][0] +
|
| 307 |
+
input[id + size] * filtered_affine_model[id*12 + 5] + //A[1][1] +
|
| 308 |
+
input[id] * filtered_affine_model[id*12 + 6] + //A[1][2] +
|
| 309 |
+
filtered_affine_model[id*12 + 7]; //A[1][3];
|
| 310 |
+
double out3 =
|
| 311 |
+
input[id + 2*size] * filtered_affine_model[id*12 + 8] + //A[2][0] +
|
| 312 |
+
input[id + size] * filtered_affine_model[id*12 + 9] + //A[2][1] +
|
| 313 |
+
input[id] * filtered_affine_model[id*12 + 10] + //A[2][2] +
|
| 314 |
+
filtered_affine_model[id*12 + 11]; // A[2][3];
|
| 315 |
+
|
| 316 |
+
filtered_best_output[id] = out1;
|
| 317 |
+
filtered_best_output[id + size] = out2;
|
| 318 |
+
filtered_best_output[id + 2*size] = out3;
|
| 319 |
+
}
|
| 320 |
+
return ;
|
| 321 |
+
}
|
| 322 |
+
'''
|
| 323 |
+
|
| 324 |
+
import torch
|
| 325 |
+
import numpy as np
|
| 326 |
+
from PIL import Image
|
| 327 |
+
from cupy.cuda import function
|
| 328 |
+
from pynvrtc.compiler import Program
|
| 329 |
+
from collections import namedtuple
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def smooth_local_affine(output_cpu, input_cpu, epsilon, patch, h, w, f_r, f_e):
    """Smooth a stylized image with per-pixel local affine color models (CUDA only).

    :param output_cpu: stylized image, float32 CHW array (values expected in [0, 1])
    :param input_cpu: guidance (content) image, float32 CHW array in the same range
    :param epsilon: regularization value passed to the affine-fit kernel
    :param patch: patch side length used when fitting each local affine model
    :param h: image height in pixels
    :param w: image width in pixels
    :param f_r: bilateral filter radius
    :param f_e: bilateral edge-stopping sigma
    :return: smoothed image as a float32 CHW numpy array
    """
    # program = Program(src.encode('utf-8'), 'best_local_affine_kernel.cu'.encode('utf-8'))
    # ptx = program.compile(['-I/usr/local/cuda/include'.encode('utf-8')])
    # Compile the module-level CUDA source with NVRTC and load the PTX.
    program = Program(src, 'best_local_affine_kernel.cu')
    ptx = program.compile(['-I/usr/local/cuda/include'])
    m = function.Module()
    m.load(bytes(ptx.encode()))

    _reconstruction_best_kernel = m.get_function('reconstruction_best_kernel')
    _bilateral_smooth_kernel = m.get_function('bilateral_smooth_kernel')
    _best_local_affine_kernel = m.get_function('best_local_affine_kernel')
    # Launch on PyTorch's current CUDA stream so the kernels are ordered with
    # respect to the tensor allocations/copies below.
    Stream = namedtuple('Stream', ['ptr'])
    s = Stream(ptr=torch.cuda.current_stream().cuda_stream)

    filter_radius = f_r
    sigma1 = filter_radius / 3  # spatial sigma of the bilateral filter
    sigma2 = f_e                # range (color) sigma of the bilateral filter
    radius = (patch - 1) / 2    # half patch size for the affine fit

    filtered_best_output = torch.zeros(np.shape(input_cpu)).cuda()
    affine_model = torch.zeros((h * w, 12)).cuda()            # one 3x4 model per pixel
    filtered_affine_model = torch.zeros((h * w, 12)).cuda()

    input_ = torch.from_numpy(input_cpu).cuda()
    output_ = torch.from_numpy(output_cpu).cuda()
    # Stage 1: fit a local affine color mapping (content -> stylized) per pixel.
    _best_local_affine_kernel(
        grid=(int((h * w) / 256 + 1), 1),
        block=(256, 1, 1),
        args=[output_.data_ptr(), input_.data_ptr(), affine_model.data_ptr(),
            np.int32(h), np.int32(w), np.float32(epsilon), np.int32(radius)], stream=s
    )

    # Stage 2: bilaterally smooth the affine coefficients, guided by the content image.
    _bilateral_smooth_kernel(
        grid=(int((h * w) / 256 + 1), 1),
        block=(256, 1, 1),
        args=[affine_model.data_ptr(), filtered_affine_model.data_ptr(), input_.data_ptr(), np.int32(h), np.int32(w), np.int32(f_r), np.float32(sigma1), np.float32(sigma2)], stream=s
    )

    # Stage 3: apply the filtered models to the content image.
    _reconstruction_best_kernel(
        grid=(int((h * w) / 256 + 1), 1),
        block=(256, 1, 1),
        args=[input_.data_ptr(), filtered_affine_model.data_ptr(), filtered_best_output.data_ptr(),
            np.int32(h), np.int32(w)], stream=s
    )
    numpy_filtered_best_output = filtered_best_output.cpu().numpy()
    return numpy_filtered_best_output
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def smooth_filter(initImg, contentImg, f_radius=15, f_edge=1e-1):
    '''
    Guided smoothing of a stylized image against its content image (CUDA only).

    :param initImg: intermediate output. Either image path or PIL Image
    :param contentImg: content image. Either path or PIL Image
    :param f_radius: bilateral filter radius in pixels
    :param f_edge: bilateral edge-stopping strength (sigma on color distance)
    :return: stylized output image. PIL Image
    '''
    if isinstance(initImg, str):
        initImg = Image.open(initImg).convert("RGB")
    best_image_bgr = np.array(initImg, dtype=np.float32)
    # NumPy image arrays are (height, width, channels).
    height, width, _ = best_image_bgr.shape
    best_image_bgr = best_image_bgr[:, :, ::-1]           # RGB -> BGR
    best_image_bgr = best_image_bgr.transpose((2, 0, 1))  # HWC -> CHW

    if isinstance(contentImg, str):
        contentImg = Image.open(contentImg).convert("RGB")
    # PIL's resize takes (width, height): match the stylized image's size.
    content_input = contentImg.resize((width, height))
    content_input = np.array(content_input, dtype=np.float32)
    content_input = content_input[:, :, ::-1]
    content_input = content_input.transpose((2, 0, 1))
    input_ = np.ascontiguousarray(content_input, dtype=np.float32) / 255.
    _, H, W = np.shape(input_)
    output_ = np.ascontiguousarray(best_image_bgr, dtype=np.float32) / 255.
    best_ = smooth_local_affine(output_, input_, 1e-7, 3, H, W, f_radius, f_edge)
    best_ = best_.transpose(1, 2, 0)  # CHW -> HWC for PIL
    result = Image.fromarray(np.uint8(np.clip(best_ * 255., 0, 255.)))
    return result
|