[Admin maintenance] Migrate grant to ZeroGPU

#21
by multimodalart HF Staff - opened
Files changed (1) hide show
  1. app.py +60 -9
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import torch
3
  import numpy as np
@@ -26,7 +27,47 @@ pipe.scheduler = DPMSolverMultistepSchedulerInject.from_pretrained(sd_model_id,
26
  blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
27
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base",torch_dtype=torch.float16).to(device)
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  ## IMAGE CAPTIONING ##
 
30
  def caption_image(input_image):
31
  inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
32
  pixel_values = inputs.pixel_values
@@ -51,6 +92,7 @@ def sample(zs, wts, attention_store, text_cross_attention_maps, prompt_tar="", c
51
  return img.images[0], attention_store, text_cross_attention_maps
52
 
53
 
 
54
  def reconstruct(
55
  tar_prompt,
56
  image_caption,
@@ -64,6 +106,10 @@ def reconstruct(
64
  reconstruction,
65
  reconstruct_button,
66
  ):
 
 
 
 
67
  if reconstruct_button == "Hide Reconstruction":
68
  return (
69
  reconstruction,
@@ -130,6 +176,7 @@ def load_and_invert(
130
 
131
  ## SEGA ##
132
 
 
133
  def edit(input_image,
134
  wts, zs, attention_store, text_cross_attention_maps,
135
  tar_prompt,
@@ -143,15 +190,19 @@ def edit(input_image,
143
  neg_guidance_1, neg_guidance_2, neg_guidance_3,
144
  threshold_1, threshold_2, threshold_3,
145
  do_reconstruction,
146
- reconstruction,
147
  # for inversion in case it needs to be re computed (and avoid delay):
148
  do_inversion,
149
- seed,
150
  randomize_seed,
151
  src_prompt,
152
  src_cfg_scale,
153
  mask_type,
154
  progress=gr.Progress(track_tqdm=True)):
 
 
 
 
155
  show_share_button = gr.update(visible=True)
156
  if(mask_type == "No mask"):
157
  use_cross_attn_mask = False
@@ -207,18 +258,18 @@ def edit(input_image,
207
  # wts=wts.value,
208
  zs=zs, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps, **editing_args)
209
 
210
- return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button
211
-
212
-
213
  else: # if sega concepts were not added, performs regular ddpm sampling
214
-
215
  if do_reconstruction: # if ddpm sampling wasn't computed
216
  pure_ddpm_img, attention_store, text_cross_attention_maps = sample(zs, wts, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
217
  reconstruction = pure_ddpm_img
218
  do_reconstruction = False
219
- return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button
220
-
221
- return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button
222
 
223
 
224
  def randomize_seed_fn(seed, is_random):
 
1
+ import spaces
2
  import gradio as gr
3
  import torch
4
  import numpy as np
 
27
  blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
28
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base",torch_dtype=torch.float16).to(device)
29
 
30
+ ## Helpers to bounce CUDA tensors across gr.State (multiprocessing pickle barrier on ZeroGPU)
31
def _to_cpu(obj, _seen=None):
    """Return ``obj`` with every CUDA tensor found inside it moved to the CPU.

    Helper for passing state through ``gr.State`` on ZeroGPU, where values
    cross a multiprocessing pickle barrier.

    Handling by type:
      * ``torch.Tensor``: a CUDA tensor is detached and copied to the CPU;
        a tensor that is not on CUDA is returned as the same object.
      * ``list`` / ``tuple`` / ``dict``: a new container is built from the
        converted elements. Named tuples keep their own type; list and
        dict subclasses come back as plain ``list`` / ``dict``.
      * any other object exposing ``__dict__``: each attribute is converted
        and re-assigned IN PLACE with ``setattr``; the same object is
        returned.
      * anything else is returned unchanged.

    Args:
        obj: The value to convert.
        _seen: Internal set of ``id()`` values of attribute-bearing objects
            already visited, so reference cycles between such objects
            terminate. Plain containers are not tracked. Callers should
            leave this as ``None``.

    Returns:
        The converted value (a new container, or ``obj`` itself).
    """
    if _seen is None:
        _seen = set()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu() if obj.is_cuda else obj
    if isinstance(obj, list):
        return [_to_cpu(item, _seen) for item in obj]
    if isinstance(obj, tuple):
        converted = [_to_cpu(item, _seen) for item in obj]
        # Named tuples take their fields as positional arguments; rebuilding
        # them as a plain tuple would silently drop access by field name.
        if hasattr(obj, "_fields"):
            return type(obj)(*converted)
        return tuple(converted)
    if isinstance(obj, dict):
        return {key: _to_cpu(value, _seen) for key, value in obj.items()}
    if hasattr(obj, "__dict__") and id(obj) not in _seen:
        # Record the object before descending so a self-reference stops here.
        _seen.add(id(obj))
        # Snapshot the items: setattr below writes to the dict being walked.
        for attr_name, attr_value in list(obj.__dict__.items()):
            setattr(obj, attr_name, _to_cpu(attr_value, _seen))
        return obj
    return obj
48
+
49
+
50
def _to_cuda(obj, _seen=None):
    """Return ``obj`` with every non-CUDA tensor found inside it moved to ``device``.

    Helper for restoring state received through ``gr.State`` on ZeroGPU,
    where values cross a multiprocessing pickle barrier. ``device`` is a
    module-level name defined outside this function.

    Handling by type:
      * ``torch.Tensor``: a tensor that is not on CUDA is moved with
        ``obj.to(device)``; a tensor already on CUDA is returned as the
        same object.
      * ``list`` / ``tuple`` / ``dict``: a new container is built from the
        converted elements. Named tuples keep their own type; list and
        dict subclasses come back as plain ``list`` / ``dict``.
      * any other object exposing ``__dict__``: each attribute is converted
        and re-assigned IN PLACE with ``setattr``; the same object is
        returned.
      * anything else is returned unchanged.

    Args:
        obj: The value to convert.
        _seen: Internal set of ``id()`` values of attribute-bearing objects
            already visited, so reference cycles between such objects
            terminate. Plain containers are not tracked. Callers should
            leave this as ``None``.

    Returns:
        The converted value (a new container, or ``obj`` itself).
    """
    if _seen is None:
        _seen = set()
    if isinstance(obj, torch.Tensor):
        return obj.to(device) if not obj.is_cuda else obj
    if isinstance(obj, list):
        return [_to_cuda(item, _seen) for item in obj]
    if isinstance(obj, tuple):
        converted = [_to_cuda(item, _seen) for item in obj]
        # Named tuples take their fields as positional arguments; rebuilding
        # them as a plain tuple would silently drop access by field name.
        if hasattr(obj, "_fields"):
            return type(obj)(*converted)
        return tuple(converted)
    if isinstance(obj, dict):
        return {key: _to_cuda(value, _seen) for key, value in obj.items()}
    if hasattr(obj, "__dict__") and id(obj) not in _seen:
        # Record the object before descending so a self-reference stops here.
        _seen.add(id(obj))
        # Snapshot the items: setattr below writes to the dict being walked.
        for attr_name, attr_value in list(obj.__dict__.items()):
            setattr(obj, attr_name, _to_cuda(attr_value, _seen))
        return obj
    return obj
67
+
68
+
69
  ## IMAGE CAPTIONING ##
70
+ @spaces.GPU
71
  def caption_image(input_image):
72
  inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
73
  pixel_values = inputs.pixel_values
 
92
  return img.images[0], attention_store, text_cross_attention_maps
93
 
94
 
95
+ @spaces.GPU
96
  def reconstruct(
97
  tar_prompt,
98
  image_caption,
 
106
  reconstruction,
107
  reconstruct_button,
108
  ):
109
+ wts = _to_cuda(wts)
110
+ zs = _to_cuda(zs)
111
+ attention_store = _to_cuda(attention_store)
112
+ text_cross_attention_maps = _to_cuda(text_cross_attention_maps)
113
  if reconstruct_button == "Hide Reconstruction":
114
  return (
115
  reconstruction,
 
176
 
177
  ## SEGA ##
178
 
179
+ @spaces.GPU
180
  def edit(input_image,
181
  wts, zs, attention_store, text_cross_attention_maps,
182
  tar_prompt,
 
190
  neg_guidance_1, neg_guidance_2, neg_guidance_3,
191
  threshold_1, threshold_2, threshold_3,
192
  do_reconstruction,
193
+ reconstruction,
194
  # for inversion in case it needs to be re computed (and avoid delay):
195
  do_inversion,
196
+ seed,
197
  randomize_seed,
198
  src_prompt,
199
  src_cfg_scale,
200
  mask_type,
201
  progress=gr.Progress(track_tqdm=True)):
202
+ wts = _to_cuda(wts)
203
+ zs = _to_cuda(zs)
204
+ attention_store = _to_cuda(attention_store)
205
+ text_cross_attention_maps = _to_cuda(text_cross_attention_maps)
206
  show_share_button = gr.update(visible=True)
207
  if(mask_type == "No mask"):
208
  use_cross_attn_mask = False
 
258
  # wts=wts.value,
259
  zs=zs, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps, **editing_args)
260
 
261
+ return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, _to_cpu(wts), _to_cpu(zs), _to_cpu(attention_store), _to_cpu(text_cross_attention_maps), do_inversion, show_share_button
262
+
263
+
264
  else: # if sega concepts were not added, performs regular ddpm sampling
265
+
266
  if do_reconstruction: # if ddpm sampling wasn't computed
267
  pure_ddpm_img, attention_store, text_cross_attention_maps = sample(zs, wts, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
268
  reconstruction = pure_ddpm_img
269
  do_reconstruction = False
270
+ return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, _to_cpu(wts), _to_cpu(zs), _to_cpu(attention_store), _to_cpu(text_cross_attention_maps), do_inversion, show_share_button
271
+
272
+ return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, _to_cpu(wts), _to_cpu(zs), _to_cpu(attention_store), _to_cpu(text_cross_attention_maps), do_inversion, show_share_button
273
 
274
 
275
  def randomize_seed_fn(seed, is_random):