Spaces:
Running on Zero
Running on Zero
[Admin maintenance] Migrate grant to ZeroGPU
#21
by multimodalart HF Staff - opened
app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
|
@@ -26,7 +27,47 @@ pipe.scheduler = DPMSolverMultistepSchedulerInject.from_pretrained(sd_model_id,
|
|
| 26 |
blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
| 27 |
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base",torch_dtype=torch.float16).to(device)
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
## IMAGE CAPTIONING ##
|
|
|
|
| 30 |
def caption_image(input_image):
|
| 31 |
inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
|
| 32 |
pixel_values = inputs.pixel_values
|
|
@@ -51,6 +92,7 @@ def sample(zs, wts, attention_store, text_cross_attention_maps, prompt_tar="", c
|
|
| 51 |
return img.images[0], attention_store, text_cross_attention_maps
|
| 52 |
|
| 53 |
|
|
|
|
| 54 |
def reconstruct(
|
| 55 |
tar_prompt,
|
| 56 |
image_caption,
|
|
@@ -64,6 +106,10 @@ def reconstruct(
|
|
| 64 |
reconstruction,
|
| 65 |
reconstruct_button,
|
| 66 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
if reconstruct_button == "Hide Reconstruction":
|
| 68 |
return (
|
| 69 |
reconstruction,
|
|
@@ -130,6 +176,7 @@ def load_and_invert(
|
|
| 130 |
|
| 131 |
## SEGA ##
|
| 132 |
|
|
|
|
| 133 |
def edit(input_image,
|
| 134 |
wts, zs, attention_store, text_cross_attention_maps,
|
| 135 |
tar_prompt,
|
|
@@ -143,15 +190,19 @@ def edit(input_image,
|
|
| 143 |
neg_guidance_1, neg_guidance_2, neg_guidance_3,
|
| 144 |
threshold_1, threshold_2, threshold_3,
|
| 145 |
do_reconstruction,
|
| 146 |
-
reconstruction,
|
| 147 |
# for inversion in case it needs to be re computed (and avoid delay):
|
| 148 |
do_inversion,
|
| 149 |
-
seed,
|
| 150 |
randomize_seed,
|
| 151 |
src_prompt,
|
| 152 |
src_cfg_scale,
|
| 153 |
mask_type,
|
| 154 |
progress=gr.Progress(track_tqdm=True)):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
show_share_button = gr.update(visible=True)
|
| 156 |
if(mask_type == "No mask"):
|
| 157 |
use_cross_attn_mask = False
|
|
@@ -207,18 +258,18 @@ def edit(input_image,
|
|
| 207 |
# wts=wts.value,
|
| 208 |
zs=zs, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps, **editing_args)
|
| 209 |
|
| 210 |
-
return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button
|
| 211 |
-
|
| 212 |
-
|
| 213 |
else: # if sega concepts were not added, performs regular ddpm sampling
|
| 214 |
-
|
| 215 |
if do_reconstruction: # if ddpm sampling wasn't computed
|
| 216 |
pure_ddpm_img, attention_store, text_cross_attention_maps = sample(zs, wts, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
|
| 217 |
reconstruction = pure_ddpm_img
|
| 218 |
do_reconstruction = False
|
| 219 |
-
return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button
|
| 220 |
-
|
| 221 |
-
return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button
|
| 222 |
|
| 223 |
|
| 224 |
def randomize_seed_fn(seed, is_random):
|
|
|
|
| 1 |
+
import spaces
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
|
|
|
| 27 |
blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
| 28 |
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base",torch_dtype=torch.float16).to(device)
|
| 29 |
|
| 30 |
+
## Helpers to bounce CUDA tensors across gr.State (multiprocessing pickle barrier on ZeroGPU)
|
| 31 |
+
def _to_cpu(obj, _seen=None):
    """Recursively move every CUDA tensor reachable from ``obj`` to the CPU.

    Used on values handed back to ``gr.State`` so that no CUDA tensor has to
    cross the process boundary (see the "pickle barrier on ZeroGPU" note on
    these helpers).

    Args:
        obj: A tensor, a list/tuple/dict of such values (nested arbitrarily),
            or any object exposing ``__dict__`` whose attributes hold them.
        _seen: Internal set of ``id()`` values of attribute-bearing objects
            already visited; callers should leave it as ``None``.

    Returns:
        For tensors: a detached CPU copy if the tensor was on CUDA, otherwise
        the very same tensor. Lists, tuples and dicts are rebuilt as new
        containers (namedtuples keep their class). Objects with ``__dict__``
        are updated IN PLACE and returned as the same instance. Anything else
        is returned unchanged.
    """
    if _seen is None:
        _seen = set()
    if isinstance(obj, torch.Tensor):
        # CPU tensors are passed through untouched (no detach, same object).
        return obj.detach().cpu() if obj.is_cuda else obj
    if isinstance(obj, list):
        return [_to_cpu(item, _seen) for item in obj]
    if isinstance(obj, tuple):
        moved = [_to_cpu(item, _seen) for item in obj]
        if hasattr(obj, "_fields"):
            # namedtuple: rebuild with the original class so that field
            # access (``out.name``) keeps working after the round trip.
            return type(obj)(*moved)
        return tuple(moved)
    if isinstance(obj, dict):
        return {key: _to_cpu(value, _seen) for key, value in obj.items()}
    if isinstance(obj, type):
        # A class object is never a tensor container; walking its __dict__
        # would try to overwrite read-only slots such as ``__dict__`` itself.
        return obj
    if hasattr(obj, "__dict__") and id(obj) not in _seen:
        # Record the object before descending so reference cycles between
        # attribute-bearing objects terminate.
        _seen.add(id(obj))
        # Snapshot the items: setattr below mutates obj.__dict__.
        for name, value in list(obj.__dict__.items()):
            setattr(obj, name, _to_cpu(value, _seen))
    return obj
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _to_cuda(obj, _seen=None):
    """Recursively move every non-CUDA tensor reachable from ``obj`` to ``device``.

    Counterpart of the CPU-side helper: called at the top of GPU-decorated
    handlers on values that arrive from ``gr.State``. The target is the
    module-level ``device`` defined elsewhere in this file.

    Args:
        obj: A tensor, a list/tuple/dict of such values (nested arbitrarily),
            or any object exposing ``__dict__`` whose attributes hold them.
        _seen: Internal set of ``id()`` values of attribute-bearing objects
            already visited; callers should leave it as ``None``.

    Returns:
        For tensors: the result of ``obj.to(device)`` if the tensor was not on
        CUDA, otherwise the very same tensor. Lists, tuples and dicts are
        rebuilt as new containers (namedtuples keep their class). Objects with
        ``__dict__`` are updated IN PLACE and returned as the same instance.
        Anything else is returned unchanged.
    """
    if _seen is None:
        _seen = set()
    if isinstance(obj, torch.Tensor):
        # Tensors already on CUDA are passed through untouched.
        return obj if obj.is_cuda else obj.to(device)
    if isinstance(obj, list):
        return [_to_cuda(item, _seen) for item in obj]
    if isinstance(obj, tuple):
        moved = [_to_cuda(item, _seen) for item in obj]
        if hasattr(obj, "_fields"):
            # namedtuple: rebuild with the original class so that field
            # access (``out.name``) keeps working after the round trip.
            return type(obj)(*moved)
        return tuple(moved)
    if isinstance(obj, dict):
        return {key: _to_cuda(value, _seen) for key, value in obj.items()}
    if isinstance(obj, type):
        # A class object is never a tensor container; walking its __dict__
        # would try to overwrite read-only slots such as ``__dict__`` itself.
        return obj
    if hasattr(obj, "__dict__") and id(obj) not in _seen:
        # Record the object before descending so reference cycles between
        # attribute-bearing objects terminate.
        _seen.add(id(obj))
        # Snapshot the items: setattr below mutates obj.__dict__.
        for name, value in list(obj.__dict__.items()):
            setattr(obj, name, _to_cuda(value, _seen))
    return obj
|
| 67 |
+
|
| 68 |
+
|
| 69 |
## IMAGE CAPTIONING ##
|
| 70 |
+
@spaces.GPU
|
| 71 |
def caption_image(input_image):
|
| 72 |
inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
|
| 73 |
pixel_values = inputs.pixel_values
|
|
|
|
| 92 |
return img.images[0], attention_store, text_cross_attention_maps
|
| 93 |
|
| 94 |
|
| 95 |
+
@spaces.GPU
|
| 96 |
def reconstruct(
|
| 97 |
tar_prompt,
|
| 98 |
image_caption,
|
|
|
|
| 106 |
reconstruction,
|
| 107 |
reconstruct_button,
|
| 108 |
):
|
| 109 |
+
wts = _to_cuda(wts)
|
| 110 |
+
zs = _to_cuda(zs)
|
| 111 |
+
attention_store = _to_cuda(attention_store)
|
| 112 |
+
text_cross_attention_maps = _to_cuda(text_cross_attention_maps)
|
| 113 |
if reconstruct_button == "Hide Reconstruction":
|
| 114 |
return (
|
| 115 |
reconstruction,
|
|
|
|
| 176 |
|
| 177 |
## SEGA ##
|
| 178 |
|
| 179 |
+
@spaces.GPU
|
| 180 |
def edit(input_image,
|
| 181 |
wts, zs, attention_store, text_cross_attention_maps,
|
| 182 |
tar_prompt,
|
|
|
|
| 190 |
neg_guidance_1, neg_guidance_2, neg_guidance_3,
|
| 191 |
threshold_1, threshold_2, threshold_3,
|
| 192 |
do_reconstruction,
|
| 193 |
+
reconstruction,
|
| 194 |
# for inversion in case it needs to be re computed (and avoid delay):
|
| 195 |
do_inversion,
|
| 196 |
+
seed,
|
| 197 |
randomize_seed,
|
| 198 |
src_prompt,
|
| 199 |
src_cfg_scale,
|
| 200 |
mask_type,
|
| 201 |
progress=gr.Progress(track_tqdm=True)):
|
| 202 |
+
wts = _to_cuda(wts)
|
| 203 |
+
zs = _to_cuda(zs)
|
| 204 |
+
attention_store = _to_cuda(attention_store)
|
| 205 |
+
text_cross_attention_maps = _to_cuda(text_cross_attention_maps)
|
| 206 |
show_share_button = gr.update(visible=True)
|
| 207 |
if(mask_type == "No mask"):
|
| 208 |
use_cross_attn_mask = False
|
|
|
|
| 258 |
# wts=wts.value,
|
| 259 |
zs=zs, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps, **editing_args)
|
| 260 |
|
| 261 |
+
return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, _to_cpu(wts), _to_cpu(zs), _to_cpu(attention_store), _to_cpu(text_cross_attention_maps), do_inversion, show_share_button
|
| 262 |
+
|
| 263 |
+
|
| 264 |
else: # if sega concepts were not added, performs regular ddpm sampling
|
| 265 |
+
|
| 266 |
if do_reconstruction: # if ddpm sampling wasn't computed
|
| 267 |
pure_ddpm_img, attention_store, text_cross_attention_maps = sample(zs, wts, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
|
| 268 |
reconstruction = pure_ddpm_img
|
| 269 |
do_reconstruction = False
|
| 270 |
+
return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, _to_cpu(wts), _to_cpu(zs), _to_cpu(attention_store), _to_cpu(text_cross_attention_maps), do_inversion, show_share_button
|
| 271 |
+
|
| 272 |
+
return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, _to_cpu(wts), _to_cpu(zs), _to_cpu(attention_store), _to_cpu(text_cross_attention_maps), do_inversion, show_share_button
|
| 273 |
|
| 274 |
|
| 275 |
def randomize_seed_fn(seed, is_random):
|