Restart please

#9
by Alchemist85 - opened
Files changed (3) hide show
  1. README.md +1 -1
  2. app_zero.py +145 -430
  3. requirements.txt +8 -15
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: ✨
4
  colorFrom: indigo
5
  colorTo: pink
6
  sdk: gradio
7
- sdk_version: 6.14.0
8
  app_file: app_zero.py
9
  pinned: false
10
  short_description: Magnify subject details and enhance image quality
 
4
  colorFrom: indigo
5
  colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 5.44.0
8
  app_file: app_zero.py
9
  pinned: false
10
  short_description: Magnify subject details and enhance image quality
app_zero.py CHANGED
@@ -1,229 +1,55 @@
1
- import sys
2
- import types
3
- import datetime
4
- import re
5
- from pathlib import Path
6
-
7
- import huggingface_hub
8
-
9
- # -------------------------------------------------------------------
10
- # Compatibility shim: older diffusers may still expect cached_download
11
- # -------------------------------------------------------------------
12
- if not hasattr(huggingface_hub, "cached_download"):
13
- def cached_download(*args, **kwargs):
14
- return huggingface_hub.hf_hub_download(*args, **kwargs)
15
-
16
- huggingface_hub.cached_download = cached_download
17
-
18
  import torch
19
- import numpy as np
20
- import einops
21
- import spaces
22
- import gradio as gr
23
- from PIL import Image
24
- from torchvision import transforms
25
- import torch.nn.functional as F
26
- from torchvision.models import resnet50, ResNet50_Weights
27
- from pytorch_lightning import seed_everything
28
- from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
29
- from diffusers import (
30
- AutoencoderKL,
31
- DDIMScheduler,
32
- PNDMScheduler,
33
- DPMSolverMultistepScheduler,
34
- UniPCMultistepScheduler,
35
- )
36
-
37
- # -------------------------------------------------------------------
38
- # GPU spoof for Spaces env compatibility
39
- # -------------------------------------------------------------------
40
  torch.cuda.get_device_capability = lambda *args, **kwargs: (8, 6)
41
- torch.cuda.get_device_properties = lambda *args, **kwargs: types.SimpleNamespace(
42
- name="NVIDIA A10G",
43
- major=8,
44
- minor=6,
45
- total_memory=23836033024,
46
- multi_processor_count=80,
47
- )
48
 
49
- # -------------------------------------------------------------------
50
- # Download required assets
51
- # -------------------------------------------------------------------
52
  huggingface_hub.snapshot_download(
53
- repo_id="camenduru/PASD",
54
  allow_patterns=[
55
- "pasd/**",
56
- "pasd_light/**",
57
- "pasd_light_rrdb/**",
58
- "pasd_rrdb/**",
59
  ],
60
- local_dir="PASD/runs",
 
61
  )
62
-
63
  huggingface_hub.hf_hub_download(
64
- repo_id="camenduru/PASD",
65
- filename="majicmixRealistic_v6.safetensors",
66
- local_dir="PASD/checkpoints/personalized_models",
 
67
  )
68
-
69
  huggingface_hub.hf_hub_download(
70
- repo_id="akhaliq/RetinaFace-R50",
71
- filename="RetinaFace-R50.pth",
72
- local_dir="PASD/annotator/ckpts",
 
73
  )
74
 
75
- # -------------------------------------------------------------------
76
- # PASD local path
77
- # -------------------------------------------------------------------
78
- sys.path.append("./PASD")
79
-
80
- # -------------------------------------------------------------------
81
- # Runtime patching helpers
82
- # -------------------------------------------------------------------
83
- def patch_file(path_str: str, replacements: list[tuple[str, str]]) -> None:
84
- path = Path(path_str)
85
- if not path.exists():
86
- print(f"[patch] file not found: {path}")
87
- return
88
-
89
- try:
90
- text = path.read_text(encoding="utf-8")
91
- except Exception as e:
92
- print(f"[patch] failed reading {path}: {e}")
93
- return
94
-
95
- original = text
96
- for old, new in replacements:
97
- text = text.replace(old, new)
98
-
99
- if text != original:
100
- try:
101
- path.write_text(text, encoding="utf-8")
102
- print(f"[patch] updated: {path}")
103
- except Exception as e:
104
- print(f"[patch] failed writing {path}: {e}")
105
- else:
106
- print(f"[patch] no changes: {path}")
107
-
108
-
109
- def patch_controlnet_loader_import(path_str: str) -> None:
110
- path = Path(path_str)
111
- if not path.exists():
112
- print(f"[patch] file not found: {path}")
113
- return
114
-
115
- try:
116
- text = path.read_text(encoding="utf-8")
117
- except Exception as e:
118
- print(f"[patch] failed reading {path}: {e}")
119
- return
120
-
121
- safe_block = """try:
122
- from diffusers.loaders import FromOriginalControlNetMixin as FromOriginalControlnetMixin
123
- except Exception:
124
- try:
125
- from diffusers.loaders import FromOriginalControlnetMixin
126
- except Exception:
127
- class FromOriginalControlnetMixin:
128
- pass
129
-
130
- """
131
- original = text
132
-
133
- # Enlève d'anciens imports simples
134
- text = re.sub(
135
- r"(?m)^from diffusers\.loaders[^\n]*FromOriginalControl\w*Mixin[^\n]*\n",
136
- "",
137
- text,
138
- )
139
- text = re.sub(
140
- r"(?m)^from diffusers\.loaders\.single_file_model[^\n]*FromOriginal\w+[^\n]*\n",
141
- "",
142
- text,
143
- )
144
-
145
- # Enlève d'anciens blocs try/except cassés liés à ce mixin
146
- text = re.sub(
147
- r"(?ms)^try:\n(?:(?: |\t).*\n)+?except Exception:\n(?:(?: |\t).*\n)+?(?=^(?:class|def|@|from |import |\Z))",
148
- lambda m: "" if "FromOriginalControl" in m.group(0) else m.group(0),
149
- text,
150
- )
151
-
152
- # Normalise la référence de mixin dans le reste du fichier
153
- text = text.replace("FromOriginalControlNetMixin", "FromOriginalControlnetMixin")
154
-
155
- marker = "class ControlNetConditioningEmbedding"
156
- if safe_block not in text:
157
- idx = text.find(marker)
158
- if idx != -1:
159
- text = text[:idx] + safe_block + text[idx:]
160
- else:
161
- text = safe_block + text
162
-
163
- if text != original:
164
- try:
165
- path.write_text(text, encoding="utf-8")
166
- print(f"[patch] normalized: {path}")
167
- except Exception as e:
168
- print(f"[patch] failed writing {path}: {e}")
169
- else:
170
- print(f"[patch] no changes: {path}")
171
-
172
-
173
- def patch_pasd_for_diffusers() -> None:
174
- # pipeline_utils path moved
175
- patch_file(
176
- "./PASD/pipelines/pipeline_pasd.py",
177
- [
178
- (
179
- "from diffusers.pipeline_utils import DiffusionPipeline",
180
- "from diffusers import DiffusionPipeline",
181
- ),
182
- ],
183
- )
184
-
185
- # PositionNet -> GLIGENTextBoundingboxProjection alias
186
- patch_file(
187
- "./PASD/models/pasd/unet_2d_condition.py",
188
- [
189
- (" PositionNet,\n", ""),
190
- (
191
- " GLIGENTextBoundingboxProjection,\n",
192
- " GLIGENTextBoundingboxProjection as PositionNet,\n",
193
- ),
194
- ],
195
- )
196
-
197
- # internal module paths moved in newer diffusers
198
- patch_file(
199
- "./PASD/models/pasd/unet_2d_blocks.py",
200
- [
201
- (
202
- "from diffusers.models.attention import AdaGroupNorm",
203
- "from diffusers.models.normalization import AdaGroupNorm",
204
- ),
205
- (
206
- "from diffusers.models.dual_transformer_2d import DualTransformer2DModel",
207
- "from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel",
208
- ),
209
- (
210
- "from diffusers.models.transformer_2d import Transformer2DModel",
211
- "from diffusers.models.transformers.transformer_2d import Transformer2DModel",
212
- ),
213
- ],
214
- )
215
-
216
- # robust controlnet patch
217
- patch_controlnet_loader_import("./PASD/models/pasd/controlnet.py")
218
-
219
 
220
- patch_pasd_for_diffusers()
 
 
221
 
222
- # -------------------------------------------------------------------
223
- # Import PASD modules only after patching
224
- # -------------------------------------------------------------------
225
  from pipelines.pipeline_pasd import StableDiffusionControlNetPipeline
226
- from myutils.misc import load_dreambooth_lora
227
  from myutils.wavelet_color_fix import wavelet_color_fix
228
  from annotator.retinaface import RetinaFaceDetection
229
 
@@ -237,57 +63,27 @@ else:
237
  from models.pasd.unet_2d_condition import UNet2DConditionModel
238
  from models.pasd.controlnet import ControlNetModel
239
 
240
- # -------------------------------------------------------------------
241
- # Model setup
242
- # -------------------------------------------------------------------
243
- pretrained_model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
244
  ckpt_path = "PASD/runs/pasd/checkpoint-100000"
 
245
  dreambooth_lora_path = "PASD/checkpoints/personalized_models/majicmixRealistic_v6.safetensors"
246
-
247
  weight_dtype = torch.float16
248
  device = "cuda"
249
 
250
- scheduler = UniPCMultistepScheduler.from_pretrained(
251
- pretrained_model_path,
252
- subfolder="scheduler",
253
- )
254
- text_encoder = CLIPTextModel.from_pretrained(
255
- pretrained_model_path,
256
- subfolder="text_encoder",
257
- )
258
- tokenizer = CLIPTokenizer.from_pretrained(
259
- pretrained_model_path,
260
- subfolder="tokenizer",
261
- )
262
- vae = AutoencoderKL.from_pretrained(
263
- pretrained_model_path,
264
- subfolder="vae",
265
- )
266
- feature_extractor = CLIPImageProcessor.from_pretrained(
267
- pretrained_model_path,
268
- subfolder="feature_extractor",
269
- )
270
-
271
- unet = UNet2DConditionModel.from_pretrained(
272
- ckpt_path,
273
- subfolder="unet",
274
- )
275
- controlnet = ControlNetModel.from_pretrained(
276
- ckpt_path,
277
- subfolder="controlnet",
278
- )
279
-
280
  vae.requires_grad_(False)
281
  text_encoder.requires_grad_(False)
282
  unet.requires_grad_(False)
283
  controlnet.requires_grad_(False)
284
 
285
- unet, vae, text_encoder = load_dreambooth_lora(
286
- unet,
287
- vae,
288
- text_encoder,
289
- dreambooth_lora_path,
290
- )
291
 
292
  text_encoder.to(device, dtype=weight_dtype)
293
  vae.to(device, dtype=weight_dtype)
@@ -295,133 +91,101 @@ unet.to(device, dtype=weight_dtype)
295
  controlnet.to(device, dtype=weight_dtype)
296
 
297
  validation_pipeline = StableDiffusionControlNetPipeline(
298
- vae=vae,
299
- text_encoder=text_encoder,
300
- tokenizer=tokenizer,
301
- feature_extractor=feature_extractor,
302
- unet=unet,
303
- controlnet=controlnet,
304
- scheduler=scheduler,
305
- safety_checker=None,
306
- requires_safety_checker=False,
307
- )
308
-
309
  validation_pipeline._init_tiled_vae(decoder_tile_size=224)
310
 
311
- # -------------------------------------------------------------------
312
- # ResNet helper
313
- # -------------------------------------------------------------------
314
  weights = ResNet50_Weights.DEFAULT
315
  preprocess = weights.transforms()
316
  resnet = resnet50(weights=weights)
317
  resnet.eval()
318
 
319
-
320
- def resize_image(image_path: str, target_height: int) -> Image.Image:
321
  with Image.open(image_path) as img:
 
322
  ratio = target_height / float(img.size[1])
 
323
  new_width = int(float(img.size[0]) * ratio)
324
- return img.resize((new_width, target_height), Image.LANCZOS)
325
-
 
 
 
326
 
327
  @spaces.GPU(enable_queue=True)
328
- def super_resolve_image(
329
- input_image,
330
- prompt,
331
- added_prompt,
332
- negative_prompt,
333
- denoise_steps,
334
- upscale,
335
- alpha,
336
- guidance_scale,
337
- seed,
338
- progress=gr.Progress(track_tqdm=True),
339
- ):
340
- """
341
- Super-resolve an input image with PASD and optional prompt guidance.
342
-
343
- Use this tool when you need to generate a higher-resolution restored image from an input image.
344
-
345
- Args:
346
- input_image (str): File path to the input image.
347
- prompt (str): Main text prompt describing the desired image content.
348
- added_prompt (str): Additional quality or style prompt appended to the main prompt.
349
- negative_prompt (str): Negative prompt describing unwanted visual qualities.
350
- denoise_steps (int): Number of denoising steps used by the diffusion pipeline.
351
- upscale (int): Integer upscale factor applied to the image.
352
- alpha (float): Conditioning scale passed to the ControlNet pipeline.
353
- guidance_scale (float): Classifier-free guidance scale passed to the diffusion pipeline.
354
- seed (int): Random seed, where -1 is converted to 0.
355
-
356
- Returns:
357
- tuple: Input image path, result image path, and downloadable result image path.
358
- """
359
  if seed == -1:
360
  seed = 0
361
-
362
  input_image = resize_image(input_image, 512)
 
 
 
 
 
 
363
  timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
364
 
365
  with torch.no_grad():
366
  seed_everything(seed)
367
  generator = torch.Generator(device=device)
368
- generator.manual_seed(seed)
369
-
370
- input_image = input_image.convert("RGB")
371
 
 
372
  batch = preprocess(input_image).unsqueeze(0)
373
  prediction = resnet(batch).squeeze(0).softmax(0)
374
  class_id = prediction.argmax().item()
375
  score = prediction[class_id].item()
376
  category_name = weights.meta["categories"][class_id]
377
-
378
  if score >= 0.1:
379
- prompt += f"{category_name}" if prompt == "" else f", {category_name}"
380
 
381
- prompt = added_prompt if prompt == "" else f"{prompt}, {added_prompt}"
382
 
383
  ori_width, ori_height = input_image.size
384
- rscale = upscale
385
 
386
- input_image = input_image.resize(
387
- (input_image.size[0] * rscale, input_image.size[1] * rscale)
388
- )
389
- input_image = input_image.resize(
390
- (input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8)
391
- )
392
 
 
393
  width, height = input_image.size
 
394
 
395
  try:
396
  image = validation_pipeline(
397
- None,
398
- prompt,
399
- input_image,
400
- num_inference_steps=denoise_steps,
401
- generator=generator,
402
- height=height,
403
- width=width,
404
- guidance_scale=guidance_scale,
405
- negative_prompt=negative_prompt,
406
- conditioning_scale=alpha,
407
- eta=0.0,
408
- ).images[0]
409
-
410
- image = wavelet_color_fix(image, input_image)
411
- image = image.resize((ori_width * rscale, ori_height * rscale))
412
-
413
  except Exception as e:
414
- print(f"[inference] error: {e}")
415
  image = Image.new(mode="RGB", size=(512, 512))
 
 
 
416
 
417
- result_path = f"result_{timestamp}.jpg"
418
- input_path = f"input_{timestamp}.jpg"
419
-
420
- image.save(result_path, "JPEG")
421
- input_image.save(input_path, "JPEG")
422
-
423
- return input_path, result_path, result_path
424
 
 
 
 
 
425
 
426
  css = """
427
  #col-container{
@@ -439,102 +203,53 @@ css = """
439
  }
440
  """
441
 
442
- with gr.Blocks() as demo:
443
  with gr.Column(elem_id="col-container"):
444
- gr.Markdown("""
445
- ## PASD Magnify
446
-
447
- Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
448
-
449
- <a href='https://arxiv.org/abs/2308.14469' target='_blank'><img src='https://img.shields.io/badge/arXiv-2308.14469-red'></a> <a href='https://github.com/yangxy/PASD' target='_blank'><img src='https://img.shields.io/badge/GitHub-Code-blue'></a>
450
-
451
- """)
452
-
 
 
 
 
 
 
 
453
  with gr.Row():
454
  with gr.Column():
455
- input_image = gr.Image(
456
- type="filepath",
457
- sources=["upload"],
458
- value="PASD/samples/frog.png",
459
- label="Input image",
460
- )
461
  prompt_in = gr.Textbox(label="Prompt", value="Frog")
462
-
463
  with gr.Accordion(label="Advanced settings", open=False):
464
- added_prompt = gr.Textbox(
465
- label="Added Prompt",
466
- value="clean, high-resolution, 8k, best quality, masterpiece",
467
- )
468
- neg_prompt = gr.Textbox(
469
- label="Negative Prompt",
470
- value="dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
471
- )
472
- denoise_steps = gr.Slider(
473
- label="Denoise Steps",
474
- minimum=10,
475
- maximum=50,
476
- value=20,
477
- step=1,
478
- )
479
- upsample_scale = gr.Slider(
480
- label="Upsample Scale",
481
- minimum=1,
482
- maximum=4,
483
- value=2,
484
- step=1,
485
- )
486
- condition_scale = gr.Slider(
487
- label="Conditioning Scale",
488
- minimum=0.5,
489
- maximum=1.5,
490
- value=1.1,
491
- step=0.1,
492
- )
493
- classifier_free_guidance = gr.Slider(
494
- label="Classifier-free Guidance",
495
- minimum=0.1,
496
- maximum=10.0,
497
- value=7.5,
498
- step=0.1,
499
- )
500
- seed = gr.Slider(
501
- label="Seed",
502
- minimum=-1,
503
- maximum=2147483647,
504
- step=1,
505
- randomize=True,
506
- )
507
-
508
  submit_btn = gr.Button("Submit")
509
-
510
  with gr.Column():
511
- before_img = gr.Image(label="Input")
512
- after_img = gr.Image(label="Result")
513
  file_output = gr.File(label="Downloadable image result")
514
-
515
- submit_btn.click(
516
- fn=super_resolve_image,
517
- inputs=[
518
- input_image,
519
- prompt_in,
520
- added_prompt,
521
- neg_prompt,
522
- denoise_steps,
523
- upsample_scale,
524
- condition_scale,
525
- classifier_free_guidance,
526
- seed,
527
- ],
528
- outputs=[
529
- before_img,
530
- after_img,
531
- file_output,
532
- ],
533
- api_visibility="public",
534
- )
535
-
536
- demo.queue(max_size=10).launch(
537
- ssr_mode=False,
538
- mcp_server=True,
539
- css=css,
540
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import types
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  torch.cuda.get_device_capability = lambda *args, **kwargs: (8, 6)
4
+ torch.cuda.get_device_properties = lambda *args, **kwargs: types.SimpleNamespace(name='NVIDIA A10G', major=8, minor=6, total_memory=23836033024, multi_processor_count=80)
 
 
 
 
 
 
5
 
6
+ import huggingface_hub
 
 
7
  huggingface_hub.snapshot_download(
8
+ repo_id='camenduru/PASD',
9
  allow_patterns=[
10
+ 'pasd/**',
11
+ 'pasd_light/**',
12
+ 'pasd_light_rrdb/**',
13
+ 'pasd_rrdb/**',
14
  ],
15
+ local_dir='PASD/runs',
16
+ local_dir_use_symlinks=False,
17
  )
 
18
  huggingface_hub.hf_hub_download(
19
+ repo_id='camenduru/PASD',
20
+ filename='majicmixRealistic_v6.safetensors',
21
+ local_dir='PASD/checkpoints/personalized_models',
22
+ local_dir_use_symlinks=False,
23
  )
 
24
  huggingface_hub.hf_hub_download(
25
+ repo_id='akhaliq/RetinaFace-R50',
26
+ filename='RetinaFace-R50.pth',
27
+ local_dir='PASD/annotator/ckpts',
28
+ local_dir_use_symlinks=False,
29
  )
30
 
31
+ import sys; sys.path.append('./PASD')
32
+ import spaces
33
+ import os
34
+ import datetime
35
+ import einops
36
+ import gradio as gr
37
+ from gradio_imageslider import ImageSlider
38
+ import numpy as np
39
+ import torch
40
+ import random
41
+ from PIL import Image
42
+ from pathlib import Path
43
+ from torchvision import transforms
44
+ import torch.nn.functional as F
45
+ from torchvision.models import resnet50, ResNet50_Weights
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ from pytorch_lightning import seed_everything
48
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
49
+ from diffusers import AutoencoderKL, DDIMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler
50
 
 
 
 
51
  from pipelines.pipeline_pasd import StableDiffusionControlNetPipeline
52
+ from myutils.misc import load_dreambooth_lora, rand_name
53
  from myutils.wavelet_color_fix import wavelet_color_fix
54
  from annotator.retinaface import RetinaFaceDetection
55
 
 
63
  from models.pasd.unet_2d_condition import UNet2DConditionModel
64
  from models.pasd.controlnet import ControlNetModel
65
 
66
+ pretrained_model_path = "runwayml/stable-diffusion-v1-5"
 
 
 
67
  ckpt_path = "PASD/runs/pasd/checkpoint-100000"
68
+ #dreambooth_lora_path = "checkpoints/personalized_models/toonyou_beta3.safetensors"
69
  dreambooth_lora_path = "PASD/checkpoints/personalized_models/majicmixRealistic_v6.safetensors"
70
+ #dreambooth_lora_path = "checkpoints/personalized_models/Realistic_Vision_V5.1.safetensors"
71
  weight_dtype = torch.float16
72
  device = "cuda"
73
 
74
+ scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
75
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
76
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
77
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
78
+ feature_extractor = CLIPImageProcessor.from_pretrained(pretrained_model_path, subfolder="feature_extractor")
79
+ unet = UNet2DConditionModel.from_pretrained(ckpt_path, subfolder="unet")
80
+ controlnet = ControlNetModel.from_pretrained(ckpt_path, subfolder="controlnet")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  vae.requires_grad_(False)
82
  text_encoder.requires_grad_(False)
83
  unet.requires_grad_(False)
84
  controlnet.requires_grad_(False)
85
 
86
+ unet, vae, text_encoder = load_dreambooth_lora(unet, vae, text_encoder, dreambooth_lora_path)
 
 
 
 
 
87
 
88
  text_encoder.to(device, dtype=weight_dtype)
89
  vae.to(device, dtype=weight_dtype)
 
91
  controlnet.to(device, dtype=weight_dtype)
92
 
93
  validation_pipeline = StableDiffusionControlNetPipeline(
94
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, feature_extractor=feature_extractor,
95
+ unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=None, requires_safety_checker=False,
96
+ )
97
+ #validation_pipeline.enable_vae_tiling()
 
 
 
 
 
 
 
98
  validation_pipeline._init_tiled_vae(decoder_tile_size=224)
99
 
 
 
 
100
  weights = ResNet50_Weights.DEFAULT
101
  preprocess = weights.transforms()
102
  resnet = resnet50(weights=weights)
103
  resnet.eval()
104
 
105
+ def resize_image(image_path, target_height):
106
+ # Open the image file
107
  with Image.open(image_path) as img:
108
+ # Calculate the ratio to resize the image to the target height
109
  ratio = target_height / float(img.size[1])
110
+ # Calculate the new width based on the aspect ratio
111
  new_width = int(float(img.size[0]) * ratio)
112
+ # Resize the image
113
+ resized_img = img.resize((new_width, target_height), Image.LANCZOS)
114
+ # Save the resized image
115
+ #resized_img.save(output_path)
116
+ return resized_img
117
 
118
  @spaces.GPU(enable_queue=True)
119
+ def inference(input_image, prompt, a_prompt, n_prompt, denoise_steps, upscale, alpha, cfg, seed):
120
+
121
+ #tempo fix for seed equals-1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  if seed == -1:
123
  seed = 0
124
+
125
  input_image = resize_image(input_image, 512)
126
+ process_size = 768
127
+ resize_preproc = transforms.Compose([
128
+ transforms.Resize(process_size, interpolation=transforms.InterpolationMode.BILINEAR),
129
+ ])
130
+
131
+ # Get the current timestamp
132
  timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
133
 
134
  with torch.no_grad():
135
  seed_everything(seed)
136
  generator = torch.Generator(device=device)
 
 
 
137
 
138
+ input_image = input_image.convert('RGB')
139
  batch = preprocess(input_image).unsqueeze(0)
140
  prediction = resnet(batch).squeeze(0).softmax(0)
141
  class_id = prediction.argmax().item()
142
  score = prediction[class_id].item()
143
  category_name = weights.meta["categories"][class_id]
 
144
  if score >= 0.1:
145
+ prompt += f"{category_name}" if prompt=='' else f", {category_name}"
146
 
147
+ prompt = a_prompt if prompt=='' else f"{prompt}, {a_prompt}"
148
 
149
  ori_width, ori_height = input_image.size
150
+ resize_flag = False
151
 
152
+ rscale = upscale
153
+ input_image = input_image.resize((input_image.size[0]*rscale, input_image.size[1]*rscale))
154
+
155
+ #if min(validation_image.size) < process_size:
156
+ # validation_image = resize_preproc(validation_image)
 
157
 
158
+ input_image = input_image.resize((input_image.size[0]//8*8, input_image.size[1]//8*8))
159
  width, height = input_image.size
160
+ resize_flag = True #
161
 
162
  try:
163
  image = validation_pipeline(
164
+ None, prompt, input_image, num_inference_steps=denoise_steps, generator=generator, height=height, width=width, guidance_scale=cfg,
165
+ negative_prompt=n_prompt, conditioning_scale=alpha, eta=0.0,
166
+ ).images[0]
167
+
168
+ if True: #alpha<1.0:
169
+ image = wavelet_color_fix(image, input_image)
170
+
171
+ if resize_flag:
172
+ image = image.resize((ori_width*rscale, ori_height*rscale))
 
 
 
 
 
 
 
173
  except Exception as e:
174
+ print(e)
175
  image = Image.new(mode="RGB", size=(512, 512))
176
+
177
+ # Convert and save the image as JPEG
178
+ image.save(f'result_{timestamp}.jpg', 'JPEG')
179
 
180
+ # Convert and save the image as JPEG
181
+ input_image.save(f'input_{timestamp}.jpg', 'JPEG')
182
+
183
+ return (f"input_{timestamp}.jpg", f"result_{timestamp}.jpg"), f"result_{timestamp}.jpg"
 
 
 
184
 
185
+ title = "Pixel-Aware Stable Diffusion for Real-ISR"
186
+ description = "Gradio Demo for PASD Real-ISR. To use it, simply upload your image, or click one of the examples to load them."
187
+ article = "<a href='https://github.com/yangxy/PASD' target='_blank'>Github Repo Pytorch</a>"
188
+ #examples=[['samples/27d38eeb2dbbe7c9.png'],['samples/629e4da70703193b.png']]
189
 
190
  css = """
191
  #col-container{
 
203
  }
204
  """
205
 
206
+ with gr.Blocks(css=css) as demo:
207
  with gr.Column(elem_id="col-container"):
208
+ gr.HTML(f"""
209
+ <h2 style="text-align: center;">
210
+ PASD Magnify
211
+ </h2>
212
+ <p style="text-align: center;">
213
+ Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
214
+ </p>
215
+ <p id="project-links" align="center">
216
+ <a href='https://github.com/yangxy/PASD'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://huggingface.co/papers/2308.14469'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
217
+ </p>
218
+ <p style="margin:12px auto;display: flex;justify-content: center;">
219
+ <a href="https://huggingface.co/spaces/fffiloni/PASD?duplicate=true"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space"></a>
220
+ </p>
221
+
222
+ """)
223
+
224
  with gr.Row():
225
  with gr.Column():
226
+ input_image = gr.Image(type="filepath", sources=["upload"], value="PASD/samples/frog.png")
 
 
 
 
 
227
  prompt_in = gr.Textbox(label="Prompt", value="Frog")
 
228
  with gr.Accordion(label="Advanced settings", open=False):
229
+ added_prompt = gr.Textbox(label="Added Prompt", value='clean, high-resolution, 8k, best quality, masterpiece')
230
+ neg_prompt = gr.Textbox(label="Negative Prompt",value='dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
231
+ denoise_steps = gr.Slider(label="Denoise Steps", minimum=10, maximum=50, value=20, step=1)
232
+ upsample_scale = gr.Slider(label="Upsample Scale", minimum=1, maximum=4, value=2, step=1)
233
+ condition_scale = gr.Slider(label="Conditioning Scale", minimum=0.5, maximum=1.5, value=1.1, step=0.1)
234
+ classifier_free_guidance = gr.Slider(label="Classier-free Guidance", minimum=0.1, maximum=10.0, value=7.5, step=0.1)
235
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  submit_btn = gr.Button("Submit")
 
237
  with gr.Column():
238
+ b_a_slider = ImageSlider(label="B/A result", position=0.5)
 
239
  file_output = gr.File(label="Downloadable image result")
240
+
241
+ submit_btn.click(
242
+ fn = inference,
243
+ inputs = [
244
+ input_image, prompt_in,
245
+ added_prompt, neg_prompt,
246
+ denoise_steps,
247
+ upsample_scale, condition_scale,
248
+ classifier_free_guidance, seed
249
+ ],
250
+ outputs = [
251
+ b_a_slider,
252
+ file_output
253
+ ]
254
+ )
255
+ demo.queue(max_size=10).launch(show_api=False)
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,20 +1,13 @@
1
- #gradio==6.12.0
2
- spaces>=0.48.1
3
-
4
- huggingface_hub==0.33.5
5
- diffusers==0.27.2
6
- transformers==4.35.2
7
- accelerate==0.24.1
8
-
9
- torch==2.8.0
10
- torchvision==0.23.0
11
-
12
- basicsr-fixed
13
  ultralytics
14
  salesforce-lavis
15
  webdataset
16
  pytorch_lightning
 
 
17
  spacy
18
- einops
19
- numpy
20
- pillow
 
1
+ diffusers==0.28.2
2
+ accelerate
3
+ transformers==4.52.3
4
+ xformers==0.0.29.post1
5
+ basicsr
 
 
 
 
 
 
 
6
  ultralytics
7
  salesforce-lavis
8
  webdataset
9
  pytorch_lightning
10
+ torch==2.5.1
11
+ torchvision==0.20.1
12
  spacy
13
+ gradio_imageslider