Spaces:
Runtime error
Runtime error
anonymous
committed on
Commit
·
37119a2
1
Parent(s):
328dfe7
add consistency decoder
Browse files
- app.py +15 -8
- examples/remove_censorship.yaml +3 -0
app.py
CHANGED
|
@@ -8,7 +8,7 @@ import gradio as gr
|
|
| 8 |
import torch
|
| 9 |
import torchvision
|
| 10 |
import safetensors
|
| 11 |
-
from diffusers import AutoencoderKL
|
| 12 |
from peft import get_peft_model, LoraConfig, set_peft_model_state_dict
|
| 13 |
from huggingface_hub import snapshot_download
|
| 14 |
|
|
@@ -43,7 +43,10 @@ def prepare_model():
|
|
| 43 |
set_peft_model_state_dict(vae, state_dict)
|
| 44 |
|
| 45 |
print('Done.')
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
|
| 49 |
@spaces.GPU
|
|
@@ -102,15 +105,18 @@ def add_censorship(input_image, mode, pixelation_block_size, blur_kernel_size, s
|
|
| 102 |
|
| 103 |
@spaces.GPU
|
| 104 |
@torch.no_grad()
|
| 105 |
-
def remove_censorship(input_image, x1, y1, x2, y2):
|
| 106 |
background, layers, _ = input_image.values()
|
| 107 |
images = torch.from_numpy(background).permute(2, 0, 1)[None, :3] / 255
|
| 108 |
mask = torch.from_numpy(layers[0]).permute(2, 0, 1)[None, -1:] / 255
|
| 109 |
images = images * (1 - mask)
|
| 110 |
images = images[..., y1:y2, x1:x2]
|
| 111 |
latents = vae.encode((images * 2 - 1).to(vae.device)).latent_dist.mean
|
| 112 |
-
|
| 113 |
-
images =
|
|
|
|
|
|
|
|
|
|
| 114 |
# denormalize
|
| 115 |
images = images / 2 + 0.5
|
| 116 |
images *= 255
|
|
@@ -119,7 +125,7 @@ def remove_censorship(input_image, x1, y1, x2, y2):
|
|
| 119 |
|
| 120 |
# @@@@@@@ Start of the program @@@@@@@@
|
| 121 |
|
| 122 |
-
vae = prepare_model()
|
| 123 |
|
| 124 |
css = '''
|
| 125 |
.my-disabled {
|
|
@@ -177,6 +183,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 177 |
with gr.Row():
|
| 178 |
with gr.Column():
|
| 179 |
input_image = gr.ImageEditor()
|
|
|
|
| 180 |
with gr.Accordion('Manual cropping', open=False):
|
| 181 |
with gr.Row():
|
| 182 |
with gr.Row():
|
|
@@ -197,13 +204,13 @@ with gr.Blocks(css=css) as demo:
|
|
| 197 |
|
| 198 |
submit_btn.click(
|
| 199 |
fn=remove_censorship,
|
| 200 |
-
inputs=[input_image, x1, y1, x2, y2],
|
| 201 |
outputs=output_image
|
| 202 |
)
|
| 203 |
gr.Examples(
|
| 204 |
examples=remove_censor_examples,
|
| 205 |
fn=remove_censorship,
|
| 206 |
-
inputs=[input_image, x1, y1, x2, y2],
|
| 207 |
outputs=output_image,
|
| 208 |
cache_examples=False,
|
| 209 |
)
|
|
|
|
| 8 |
import torch
|
| 9 |
import torchvision
|
| 10 |
import safetensors
|
| 11 |
+
from diffusers import AutoencoderKL, ConsistencyDecoderVAE
|
| 12 |
from peft import get_peft_model, LoraConfig, set_peft_model_state_dict
|
| 13 |
from huggingface_hub import snapshot_download
|
| 14 |
|
|
|
|
| 43 |
set_peft_model_state_dict(vae, state_dict)
|
| 44 |
|
| 45 |
print('Done.')
|
| 46 |
+
cd_vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
|
| 47 |
+
vae = vae.to('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
|
| 48 |
+
cd_vae = cd_vae.to('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
|
| 49 |
+
return vae, cd_vae
|
| 50 |
|
| 51 |
|
| 52 |
@spaces.GPU
|
|
|
|
| 105 |
|
| 106 |
@spaces.GPU
|
| 107 |
@torch.no_grad()
|
| 108 |
+
def remove_censorship(input_image, use_cd, x1, y1, x2, y2):
|
| 109 |
background, layers, _ = input_image.values()
|
| 110 |
images = torch.from_numpy(background).permute(2, 0, 1)[None, :3] / 255
|
| 111 |
mask = torch.from_numpy(layers[0]).permute(2, 0, 1)[None, -1:] / 255
|
| 112 |
images = images * (1 - mask)
|
| 113 |
images = images[..., y1:y2, x1:x2]
|
| 114 |
latents = vae.encode((images * 2 - 1).to(vae.device)).latent_dist.mean
|
| 115 |
+
if use_cd:
|
| 116 |
+
images = cd_vae.decode(latents.to(cd_vae.dtype), return_dict=False)[0]
|
| 117 |
+
else:
|
| 118 |
+
with vae.disable_adapter():
|
| 119 |
+
images = vae.decode(latents, return_dict=False)[0]
|
| 120 |
# denormalize
|
| 121 |
images = images / 2 + 0.5
|
| 122 |
images *= 255
|
|
|
|
| 125 |
|
| 126 |
# @@@@@@@ Start of the program @@@@@@@@
|
| 127 |
|
| 128 |
+
vae, cd_vae = prepare_model()
|
| 129 |
|
| 130 |
css = '''
|
| 131 |
.my-disabled {
|
|
|
|
| 183 |
with gr.Row():
|
| 184 |
with gr.Column():
|
| 185 |
input_image = gr.ImageEditor()
|
| 186 |
+
use_cd = gr.Checkbox(label='Use Consistency Decoder (slower)')
|
| 187 |
with gr.Accordion('Manual cropping', open=False):
|
| 188 |
with gr.Row():
|
| 189 |
with gr.Row():
|
|
|
|
| 204 |
|
| 205 |
submit_btn.click(
|
| 206 |
fn=remove_censorship,
|
| 207 |
+
inputs=[input_image, use_cd, x1, y1, x2, y2],
|
| 208 |
outputs=output_image
|
| 209 |
)
|
| 210 |
gr.Examples(
|
| 211 |
examples=remove_censor_examples,
|
| 212 |
fn=remove_censorship,
|
| 213 |
+
inputs=[input_image, use_cd, x1, y1, x2, y2],
|
| 214 |
outputs=output_image,
|
| 215 |
cache_examples=False,
|
| 216 |
)
|
examples/remove_censorship.yaml
CHANGED
|
@@ -1,14 +1,17 @@
|
|
| 1 |
- - examples/images/processed/car.png
|
|
|
|
| 2 |
- 0
|
| 3 |
- 0
|
| 4 |
- 10000
|
| 5 |
- 10000
|
| 6 |
- - examples/images/processed/obama.png
|
|
|
|
| 7 |
- 0
|
| 8 |
- 0
|
| 9 |
- 10000
|
| 10 |
- 10000
|
| 11 |
- - examples/images/processed/steam-clock.png
|
|
|
|
| 12 |
- 0
|
| 13 |
- 0
|
| 14 |
- 10000
|
|
|
|
| 1 |
- - examples/images/processed/car.png
|
| 2 |
+
- false
|
| 3 |
- 0
|
| 4 |
- 0
|
| 5 |
- 10000
|
| 6 |
- 10000
|
| 7 |
- - examples/images/processed/obama.png
|
| 8 |
+
- false
|
| 9 |
- 0
|
| 10 |
- 0
|
| 11 |
- 10000
|
| 12 |
- 10000
|
| 13 |
- - examples/images/processed/steam-clock.png
|
| 14 |
+
- false
|
| 15 |
- 0
|
| 16 |
- 0
|
| 17 |
- 10000
|