fffiloni committed on
Commit
7ab1133
·
verified ·
1 Parent(s): 29f5a64

Update app_zero.py

Browse files
Files changed (1) hide show
  1. app_zero.py +83 -63
app_zero.py CHANGED
@@ -12,6 +12,7 @@ import huggingface_hub
12
  if not hasattr(huggingface_hub, "cached_download"):
13
  def cached_download(*args, **kwargs):
14
  return huggingface_hub.hf_hub_download(*args, **kwargs)
 
15
  huggingface_hub.cached_download = cached_download
16
 
17
  import torch
@@ -19,12 +20,10 @@ import numpy as np
19
  import einops
20
  import spaces
21
  import gradio as gr
22
-
23
  from PIL import Image
24
  from torchvision import transforms
25
  import torch.nn.functional as F
26
  from torchvision.models import resnet50, ResNet50_Weights
27
-
28
  from pytorch_lightning import seed_everything
29
  from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
30
  from diffusers import (
@@ -78,7 +77,6 @@ huggingface_hub.hf_hub_download(
78
  # -------------------------------------------------------------------
79
  sys.path.append("./PASD")
80
 
81
-
82
  # -------------------------------------------------------------------
83
  # Runtime patching helpers
84
  # -------------------------------------------------------------------
@@ -130,7 +128,6 @@ except Exception:
130
  pass
131
 
132
  """
133
-
134
  original = text
135
 
136
  # Enlève d'anciens imports simples
@@ -147,7 +144,7 @@ except Exception:
147
 
148
  # Enlève d'anciens blocs try/except cassés liés à ce mixin
149
  text = re.sub(
150
- r"(?ms)^try:\n(?:(?: |\t).*\n)+?except Exception:\n(?:(?: |\t).*\n)+?(?=^(?:class|def|@|from |import |\Z))",
151
  lambda m: "" if "FromOriginalControl" in m.group(0) else m.group(0),
152
  text,
153
  )
@@ -189,10 +186,10 @@ def patch_pasd_for_diffusers() -> None:
189
  patch_file(
190
  "./PASD/models/pasd/unet_2d_condition.py",
191
  [
192
- (" PositionNet,\n", ""),
193
  (
194
- " GLIGENTextBoundingboxProjection,\n",
195
- " GLIGENTextBoundingboxProjection as PositionNet,\n",
196
  ),
197
  ],
198
  )
@@ -251,25 +248,33 @@ weight_dtype = torch.float16
251
  device = "cuda"
252
 
253
  scheduler = UniPCMultistepScheduler.from_pretrained(
254
- pretrained_model_path, subfolder="scheduler"
 
255
  )
256
  text_encoder = CLIPTextModel.from_pretrained(
257
- pretrained_model_path, subfolder="text_encoder"
 
258
  )
259
  tokenizer = CLIPTokenizer.from_pretrained(
260
- pretrained_model_path, subfolder="tokenizer"
 
261
  )
262
  vae = AutoencoderKL.from_pretrained(
263
- pretrained_model_path, subfolder="vae"
 
264
  )
265
  feature_extractor = CLIPImageProcessor.from_pretrained(
266
- pretrained_model_path, subfolder="feature_extractor"
 
267
  )
 
268
  unet = UNet2DConditionModel.from_pretrained(
269
- ckpt_path, subfolder="unet"
 
270
  )
271
  controlnet = ControlNetModel.from_pretrained(
272
- ckpt_path, subfolder="controlnet"
 
273
  )
274
 
275
  vae.requires_grad_(False)
@@ -278,7 +283,10 @@ unet.requires_grad_(False)
278
  controlnet.requires_grad_(False)
279
 
280
  unet, vae, text_encoder = load_dreambooth_lora(
281
- unet, vae, text_encoder, dreambooth_lora_path
 
 
 
282
  )
283
 
284
  text_encoder.to(device, dtype=weight_dtype)
@@ -317,18 +325,37 @@ def resize_image(image_path: str, target_height: int) -> Image.Image:
317
 
318
 
319
  @spaces.GPU(enable_queue=True)
320
- def inference(
321
  input_image,
322
  prompt,
323
- a_prompt,
324
- n_prompt,
325
  denoise_steps,
326
  upscale,
327
  alpha,
328
- cfg,
329
  seed,
330
- progress=gr.Progress(track_tqdm=True)
331
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  if seed == -1:
333
  seed = 0
334
 
@@ -351,17 +378,18 @@ def inference(
351
  if score >= 0.1:
352
  prompt += f"{category_name}" if prompt == "" else f", {category_name}"
353
 
354
- prompt = a_prompt if prompt == "" else f"{prompt}, {a_prompt}"
355
 
356
  ori_width, ori_height = input_image.size
357
-
358
  rscale = upscale
 
359
  input_image = input_image.resize(
360
  (input_image.size[0] * rscale, input_image.size[1] * rscale)
361
  )
362
  input_image = input_image.resize(
363
  (input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8)
364
  )
 
365
  width, height = input_image.size
366
 
367
  try:
@@ -373,14 +401,15 @@ def inference(
373
  generator=generator,
374
  height=height,
375
  width=width,
376
- guidance_scale=cfg,
377
- negative_prompt=n_prompt,
378
  conditioning_scale=alpha,
379
  eta=0.0,
380
  ).images[0]
381
 
382
  image = wavelet_color_fix(image, input_image)
383
  image = image.resize((ori_width * rscale, ori_height * rscale))
 
384
  except Exception as e:
385
  print(f"[inference] error: {e}")
386
  image = Image.new(mode="RGB", size=(512, 512))
@@ -412,23 +441,14 @@ css = """
412
 
413
  with gr.Blocks() as demo:
414
  with gr.Column(elem_id="col-container"):
415
- gr.HTML("""
416
- <h2 style="text-align: center;">
417
- PASD Magnify
418
- </h2>
419
- <p style="text-align: center;">
420
- Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
421
- </p>
422
- <p id="project-links" align="center">
423
- <a href="https://github.com/yangxy/PASD"><img src="https://img.shields.io/badge/Project-Page-Green"></a>
424
- <a href="https://huggingface.co/papers/2308.14469"><img src="https://img.shields.io/badge/Paper-Arxiv-red"></a>
425
- </p>
426
- <p style="margin:12px auto;display: flex;justify-content: center;">
427
- <a href="https://huggingface.co/spaces/fffiloni/PASD?duplicate=true">
428
- <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space">
429
- </a>
430
- </p>
431
- """)
432
 
433
  with gr.Row():
434
  with gr.Column():
@@ -492,29 +512,29 @@ with gr.Blocks() as demo:
492
  after_img = gr.Image(label="Result")
493
  file_output = gr.File(label="Downloadable image result")
494
 
495
- submit_btn.click(
496
- fn=inference,
497
- inputs=[
498
- input_image,
499
- prompt_in,
500
- added_prompt,
501
- neg_prompt,
502
- denoise_steps,
503
- upsample_scale,
504
- condition_scale,
505
- classifier_free_guidance,
506
- seed,
507
- ],
508
- outputs=[
509
- before_img,
510
- after_img,
511
- file_output,
512
- ],
513
- api_visibility="private",
514
- )
515
 
516
  demo.queue(max_size=10).launch(
517
  ssr_mode=False,
518
- mcp_server=False,
519
  css=css,
520
  )
 
12
  if not hasattr(huggingface_hub, "cached_download"):
13
  def cached_download(*args, **kwargs):
14
  return huggingface_hub.hf_hub_download(*args, **kwargs)
15
+
16
  huggingface_hub.cached_download = cached_download
17
 
18
  import torch
 
20
  import einops
21
  import spaces
22
  import gradio as gr
 
23
  from PIL import Image
24
  from torchvision import transforms
25
  import torch.nn.functional as F
26
  from torchvision.models import resnet50, ResNet50_Weights
 
27
  from pytorch_lightning import seed_everything
28
  from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
29
  from diffusers import (
 
77
  # -------------------------------------------------------------------
78
  sys.path.append("./PASD")
79
 
 
80
  # -------------------------------------------------------------------
81
  # Runtime patching helpers
82
  # -------------------------------------------------------------------
 
128
  pass
129
 
130
  """
 
131
  original = text
132
 
133
  # Enlève d'anciens imports simples
 
144
 
145
  # Enlève d'anciens blocs try/except cassés liés à ce mixin
146
  text = re.sub(
147
+ r"(?ms)^try:\n(?:(?: |\t).*\n)+?except Exception:\n(?:(?: |\t).*\n)+?(?=^(?:class|def|@|from |import |\Z))",
148
  lambda m: "" if "FromOriginalControl" in m.group(0) else m.group(0),
149
  text,
150
  )
 
186
  patch_file(
187
  "./PASD/models/pasd/unet_2d_condition.py",
188
  [
189
+ (" PositionNet,\n", ""),
190
  (
191
+ " GLIGENTextBoundingboxProjection,\n",
192
+ " GLIGENTextBoundingboxProjection as PositionNet,\n",
193
  ),
194
  ],
195
  )
 
248
  device = "cuda"
249
 
250
  scheduler = UniPCMultistepScheduler.from_pretrained(
251
+ pretrained_model_path,
252
+ subfolder="scheduler",
253
  )
254
  text_encoder = CLIPTextModel.from_pretrained(
255
+ pretrained_model_path,
256
+ subfolder="text_encoder",
257
  )
258
  tokenizer = CLIPTokenizer.from_pretrained(
259
+ pretrained_model_path,
260
+ subfolder="tokenizer",
261
  )
262
  vae = AutoencoderKL.from_pretrained(
263
+ pretrained_model_path,
264
+ subfolder="vae",
265
  )
266
  feature_extractor = CLIPImageProcessor.from_pretrained(
267
+ pretrained_model_path,
268
+ subfolder="feature_extractor",
269
  )
270
+
271
  unet = UNet2DConditionModel.from_pretrained(
272
+ ckpt_path,
273
+ subfolder="unet",
274
  )
275
  controlnet = ControlNetModel.from_pretrained(
276
+ ckpt_path,
277
+ subfolder="controlnet",
278
  )
279
 
280
  vae.requires_grad_(False)
 
283
  controlnet.requires_grad_(False)
284
 
285
  unet, vae, text_encoder = load_dreambooth_lora(
286
+ unet,
287
+ vae,
288
+ text_encoder,
289
+ dreambooth_lora_path,
290
  )
291
 
292
  text_encoder.to(device, dtype=weight_dtype)
 
325
 
326
 
327
  @spaces.GPU(enable_queue=True)
328
+ def super_resolve_image(
329
  input_image,
330
  prompt,
331
+ added_prompt,
332
+ negative_prompt,
333
  denoise_steps,
334
  upscale,
335
  alpha,
336
+ guidance_scale,
337
  seed,
338
+ progress=gr.Progress(track_tqdm=True),
339
  ):
340
+ """
341
+ Super-resolve an input image with PASD and optional prompt guidance.
342
+
343
+ Use this tool when you need to generate a higher-resolution restored image from an input image.
344
+
345
+ Args:
346
+ input_image (str): File path to the input image.
347
+ prompt (str): Main text prompt describing the desired image content.
348
+ added_prompt (str): Additional quality or style prompt appended to the main prompt.
349
+ negative_prompt (str): Negative prompt describing unwanted visual qualities.
350
+ denoise_steps (int): Number of denoising steps used by the diffusion pipeline.
351
+ upscale (int): Integer upscale factor applied to the image.
352
+ alpha (float): Conditioning scale passed to the ControlNet pipeline.
353
+ guidance_scale (float): Classifier-free guidance scale passed to the diffusion pipeline.
354
+ seed (int): Random seed, where -1 is converted to 0.
355
+
356
+ Returns:
357
+ tuple: Input image path, result image path, and downloadable result image path.
358
+ """
359
  if seed == -1:
360
  seed = 0
361
 
 
378
  if score >= 0.1:
379
  prompt += f"{category_name}" if prompt == "" else f", {category_name}"
380
 
381
+ prompt = added_prompt if prompt == "" else f"{prompt}, {added_prompt}"
382
 
383
  ori_width, ori_height = input_image.size
 
384
  rscale = upscale
385
+
386
  input_image = input_image.resize(
387
  (input_image.size[0] * rscale, input_image.size[1] * rscale)
388
  )
389
  input_image = input_image.resize(
390
  (input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8)
391
  )
392
+
393
  width, height = input_image.size
394
 
395
  try:
 
401
  generator=generator,
402
  height=height,
403
  width=width,
404
+ guidance_scale=guidance_scale,
405
+ negative_prompt=negative_prompt,
406
  conditioning_scale=alpha,
407
  eta=0.0,
408
  ).images[0]
409
 
410
  image = wavelet_color_fix(image, input_image)
411
  image = image.resize((ori_width * rscale, ori_height * rscale))
412
+
413
  except Exception as e:
414
  print(f"[inference] error: {e}")
415
  image = Image.new(mode="RGB", size=(512, 512))
 
441
 
442
  with gr.Blocks() as demo:
443
  with gr.Column(elem_id="col-container"):
444
+ gr.Markdown("""
445
+ ## PASD Magnify
446
+
447
+ Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
448
+
449
+ <a href='https://arxiv.org/abs/2308.14469' target='_blank'><img src='https://img.shields.io/badge/arXiv-2308.14469-red'></a> <a href='https://github.com/yangxy/PASD' target='_blank'><img src='https://img.shields.io/badge/GitHub-Code-blue'></a>
450
+
451
+ """)
 
 
 
 
 
 
 
 
 
452
 
453
  with gr.Row():
454
  with gr.Column():
 
512
  after_img = gr.Image(label="Result")
513
  file_output = gr.File(label="Downloadable image result")
514
 
515
+ submit_btn.click(
516
+ fn=super_resolve_image,
517
+ inputs=[
518
+ input_image,
519
+ prompt_in,
520
+ added_prompt,
521
+ neg_prompt,
522
+ denoise_steps,
523
+ upsample_scale,
524
+ condition_scale,
525
+ classifier_free_guidance,
526
+ seed,
527
+ ],
528
+ outputs=[
529
+ before_img,
530
+ after_img,
531
+ file_output,
532
+ ],
533
+ api_visibility="public",
534
+ )
535
 
536
  demo.queue(max_size=10).launch(
537
  ssr_mode=False,
538
+ mcp_server=True,
539
  css=css,
540
  )