fffiloni committed on
Commit
ea08edb
·
verified ·
1 Parent(s): 0e55432

Update app_zero.py

Browse files
Files changed (1) hide show
  1. app_zero.py +155 -100
app_zero.py CHANGED
@@ -1,55 +1,70 @@
1
- import torch
 
2
  import types
3
- torch.cuda.get_device_capability = lambda *args, **kwargs: (8, 6)
4
- torch.cuda.get_device_properties = lambda *args, **kwargs: types.SimpleNamespace(name='NVIDIA A10G', major=8, minor=6, total_memory=23836033024, multi_processor_count=80)
 
5
 
 
 
 
 
 
6
  import huggingface_hub
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  huggingface_hub.snapshot_download(
8
- repo_id='camenduru/PASD',
9
  allow_patterns=[
10
- 'pasd/**',
11
- 'pasd_light/**',
12
- 'pasd_light_rrdb/**',
13
- 'pasd_rrdb/**',
14
  ],
15
- local_dir='PASD/runs',
16
- #local_dir_use_symlinks=False,
17
  )
 
18
  huggingface_hub.hf_hub_download(
19
- repo_id='camenduru/PASD',
20
- filename='majicmixRealistic_v6.safetensors',
21
- local_dir='PASD/checkpoints/personalized_models',
22
- #local_dir_use_symlinks=False,
23
  )
 
24
  huggingface_hub.hf_hub_download(
25
- repo_id='akhaliq/RetinaFace-R50',
26
- filename='RetinaFace-R50.pth',
27
- local_dir='PASD/annotator/ckpts',
28
- #local_dir_use_symlinks=False,
29
  )
30
 
31
- import sys; sys.path.append('./PASD')
32
- import spaces
33
- import os
34
- import datetime
35
- import einops
36
- import gradio as gr
37
- from gradio_imageslider import ImageSlider
38
- import numpy as np
39
- import torch
40
- import random
41
- from PIL import Image
42
- from pathlib import Path
43
- from torchvision import transforms
44
- import torch.nn.functional as F
45
- from torchvision.models import resnet50, ResNet50_Weights
46
-
47
- from pytorch_lightning import seed_everything
48
- from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
49
- from diffusers import AutoencoderKL, DDIMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler
50
 
51
- # compat shim for older custom pipelines expecting diffusers.pipeline_utils
52
- import sys
53
  try:
54
  import diffusers.pipelines.pipeline_utils as _pipeline_utils
55
  sys.modules["diffusers.pipeline_utils"] = _pipeline_utils
@@ -61,6 +76,7 @@ from myutils.misc import load_dreambooth_lora, rand_name
61
  from myutils.wavelet_color_fix import wavelet_color_fix
62
  from annotator.retinaface import RetinaFaceDetection
63
 
 
64
  use_pasd_light = False
65
  face_detector = RetinaFaceDetection()
66
 
@@ -73,12 +89,12 @@ else:
73
 
74
  pretrained_model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
75
  ckpt_path = "PASD/runs/pasd/checkpoint-100000"
76
- #dreambooth_lora_path = "checkpoints/personalized_models/toonyou_beta3.safetensors"
77
  dreambooth_lora_path = "PASD/checkpoints/personalized_models/majicmixRealistic_v6.safetensors"
78
- #dreambooth_lora_path = "checkpoints/personalized_models/Realistic_Vision_V5.1.safetensors"
79
  weight_dtype = torch.float16
80
  device = "cuda"
81
 
 
82
  scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
83
  text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
84
  tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
@@ -86,6 +102,7 @@ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
86
  feature_extractor = CLIPImageProcessor.from_pretrained(pretrained_model_path, subfolder="feature_extractor")
87
  unet = UNet2DConditionModel.from_pretrained(ckpt_path, subfolder="unet")
88
  controlnet = ControlNetModel.from_pretrained(ckpt_path, subfolder="controlnet")
 
89
  vae.requires_grad_(False)
90
  text_encoder.requires_grad_(False)
91
  unet.requires_grad_(False)
@@ -99,101 +116,106 @@ unet.to(device, dtype=weight_dtype)
99
  controlnet.to(device, dtype=weight_dtype)
100
 
101
  validation_pipeline = StableDiffusionControlNetPipeline(
102
- vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, feature_extractor=feature_extractor,
103
- unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=None, requires_safety_checker=False,
104
- )
105
- #validation_pipeline.enable_vae_tiling()
 
 
 
 
 
 
 
 
106
  validation_pipeline._init_tiled_vae(decoder_tile_size=224)
107
 
 
108
  weights = ResNet50_Weights.DEFAULT
109
  preprocess = weights.transforms()
110
  resnet = resnet50(weights=weights)
111
  resnet.eval()
112
 
 
113
  def resize_image(image_path, target_height):
114
- # Open the image file
115
  with Image.open(image_path) as img:
116
- # Calculate the ratio to resize the image to the target height
117
  ratio = target_height / float(img.size[1])
118
- # Calculate the new width based on the aspect ratio
119
  new_width = int(float(img.size[0]) * ratio)
120
- # Resize the image
121
  resized_img = img.resize((new_width, target_height), Image.LANCZOS)
122
- # Save the resized image
123
- #resized_img.save(output_path)
124
  return resized_img
125
 
 
126
  @spaces.GPU(enable_queue=True)
127
  def inference(input_image, prompt, a_prompt, n_prompt, denoise_steps, upscale, alpha, cfg, seed):
128
-
129
- #tempo fix for seed equals-1
130
  if seed == -1:
131
  seed = 0
132
-
133
  input_image = resize_image(input_image, 512)
134
  process_size = 768
135
  resize_preproc = transforms.Compose([
136
  transforms.Resize(process_size, interpolation=transforms.InterpolationMode.BILINEAR),
137
  ])
138
-
139
- # Get the current timestamp
140
  timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
141
 
142
  with torch.no_grad():
143
  seed_everything(seed)
144
  generator = torch.Generator(device=device)
 
145
 
146
- input_image = input_image.convert('RGB')
147
  batch = preprocess(input_image).unsqueeze(0)
148
  prediction = resnet(batch).squeeze(0).softmax(0)
149
  class_id = prediction.argmax().item()
150
  score = prediction[class_id].item()
151
  category_name = weights.meta["categories"][class_id]
 
152
  if score >= 0.1:
153
- prompt += f"{category_name}" if prompt=='' else f", {category_name}"
154
 
155
- prompt = a_prompt if prompt=='' else f"{prompt}, {a_prompt}"
156
 
157
  ori_width, ori_height = input_image.size
158
  resize_flag = False
159
 
160
  rscale = upscale
161
- input_image = input_image.resize((input_image.size[0]*rscale, input_image.size[1]*rscale))
162
-
163
- #if min(validation_image.size) < process_size:
164
- # validation_image = resize_preproc(validation_image)
165
 
166
- input_image = input_image.resize((input_image.size[0]//8*8, input_image.size[1]//8*8))
167
  width, height = input_image.size
168
- resize_flag = True #
169
 
170
  try:
171
  image = validation_pipeline(
172
- None, prompt, input_image, num_inference_steps=denoise_steps, generator=generator, height=height, width=width, guidance_scale=cfg,
173
- negative_prompt=n_prompt, conditioning_scale=alpha, eta=0.0,
174
- ).images[0]
175
-
176
- if True: #alpha<1.0:
177
- image = wavelet_color_fix(image, input_image)
178
-
179
- if resize_flag:
180
- image = image.resize((ori_width*rscale, ori_height*rscale))
 
 
 
 
 
 
 
 
181
  except Exception as e:
182
  print(e)
183
  image = Image.new(mode="RGB", size=(512, 512))
184
-
185
- # Convert and save the image as JPEG
186
- image.save(f'result_{timestamp}.jpg', 'JPEG')
187
 
188
- # Convert and save the image as JPEG
189
- input_image.save(f'input_{timestamp}.jpg', 'JPEG')
190
-
191
- return f"input_{timestamp}.jpg", f"result_{timestamp}.jpg", f"result_{timestamp}.jpg"
 
 
 
192
 
193
- title = "Pixel-Aware Stable Diffusion for Real-ISR"
194
- description = "Gradio Demo for PASD Real-ISR. To use it, simply upload your image, or click one of the examples to load them."
195
- article = "<a href='https://github.com/yangxy/PASD' target='_blank'>Github Repo Pytorch</a>"
196
- #examples=[['samples/27d38eeb2dbbe7c9.png'],['samples/629e4da70703193b.png']]
197
 
198
  css = """
199
  #col-container{
@@ -221,9 +243,9 @@ with gr.Blocks() as demo:
221
  Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
222
  </p>
223
  <p id="project-links" align="center">
224
- <a href='https://github.com/yangxy/PASD'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
225
- <a href='https://huggingface.co/papers/2308.14469'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
226
- </p>
227
  <p style="margin:12px auto;display: flex;justify-content: center;">
228
  <a href="https://huggingface.co/spaces/fffiloni/PASD?duplicate=true">
229
  <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space">
@@ -233,16 +255,41 @@ with gr.Blocks() as demo:
233
 
234
  with gr.Row():
235
  with gr.Column():
236
- input_image = gr.Image(type="filepath", sources=["upload"], value="PASD/samples/frog.png")
 
 
 
 
 
237
  prompt_in = gr.Textbox(label="Prompt", value="Frog")
 
238
  with gr.Accordion(label="Advanced settings", open=False):
239
- added_prompt = gr.Textbox(label="Added Prompt", value='clean, high-resolution, 8k, best quality, masterpiece')
240
- neg_prompt = gr.Textbox(label="Negative Prompt", value='dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
 
 
 
 
 
 
241
  denoise_steps = gr.Slider(label="Denoise Steps", minimum=10, maximum=50, value=20, step=1)
242
  upsample_scale = gr.Slider(label="Upsample Scale", minimum=1, maximum=4, value=2, step=1)
243
  condition_scale = gr.Slider(label="Conditioning Scale", minimum=0.5, maximum=1.5, value=1.1, step=0.1)
244
- classifier_free_guidance = gr.Slider(label="Classifier-free Guidance", minimum=0.1, maximum=10.0, value=7.5, step=0.1)
245
- seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  submit_btn = gr.Button("Submit")
247
 
248
  with gr.Column():
@@ -253,18 +300,26 @@ with gr.Blocks() as demo:
253
  submit_btn.click(
254
  fn=inference,
255
  inputs=[
256
- input_image, prompt_in,
257
- added_prompt, neg_prompt,
 
 
258
  denoise_steps,
259
- upsample_scale, condition_scale,
260
- classifier_free_guidance, seed
 
 
261
  ],
262
  outputs=[
263
  before_img,
264
  after_img,
265
- file_output
266
  ],
267
- api_visibility="private"
268
  )
269
 
270
- demo.queue(max_size=10).launch(ssr_mode=False, mcp_server=False, css=css)
 
 
 
 
 
1
+ import sys
2
+ import os
3
  import types
4
+ import random
5
+ import datetime
6
+ from pathlib import Path
7
 
8
+ import torch
9
+ import numpy as np
10
+ import einops
11
+ import spaces
12
+ import gradio as gr
13
  import huggingface_hub
14
+
15
+ from PIL import Image
16
+ from torchvision import transforms
17
+ import torch.nn.functional as F
18
+ from torchvision.models import resnet50, ResNet50_Weights
19
+
20
+ from pytorch_lightning import seed_everything
21
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
22
+ from diffusers import (
23
+ AutoencoderKL,
24
+ DDIMScheduler,
25
+ PNDMScheduler,
26
+ DPMSolverMultistepScheduler,
27
+ UniPCMultistepScheduler,
28
+ )
29
+
30
+ # ---- GPU spoof for Spaces env compatibility ----
31
+ torch.cuda.get_device_capability = lambda *args, **kwargs: (8, 6)
32
+ torch.cuda.get_device_properties = lambda *args, **kwargs: types.SimpleNamespace(
33
+ name="NVIDIA A10G",
34
+ major=8,
35
+ minor=6,
36
+ total_memory=23836033024,
37
+ multi_processor_count=80,
38
+ )
39
+
40
+ # ---- Download required assets ----
41
  huggingface_hub.snapshot_download(
42
+ repo_id="camenduru/PASD",
43
  allow_patterns=[
44
+ "pasd/**",
45
+ "pasd_light/**",
46
+ "pasd_light_rrdb/**",
47
+ "pasd_rrdb/**",
48
  ],
49
+ local_dir="PASD/runs",
 
50
  )
51
+
52
  huggingface_hub.hf_hub_download(
53
+ repo_id="camenduru/PASD",
54
+ filename="majicmixRealistic_v6.safetensors",
55
+ local_dir="PASD/checkpoints/personalized_models",
 
56
  )
57
+
58
  huggingface_hub.hf_hub_download(
59
+ repo_id="akhaliq/RetinaFace-R50",
60
+ filename="RetinaFace-R50.pth",
61
+ local_dir="PASD/annotator/ckpts",
 
62
  )
63
 
64
+ # ---- Local PASD path ----
65
+ sys.path.append("./PASD")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
+ # ---- Compat shim for older custom pipelines expecting diffusers.pipeline_utils ----
 
68
  try:
69
  import diffusers.pipelines.pipeline_utils as _pipeline_utils
70
  sys.modules["diffusers.pipeline_utils"] = _pipeline_utils
 
76
  from myutils.wavelet_color_fix import wavelet_color_fix
77
  from annotator.retinaface import RetinaFaceDetection
78
 
79
+ # ---- Model selection ----
80
  use_pasd_light = False
81
  face_detector = RetinaFaceDetection()
82
 
 
89
 
90
  pretrained_model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
91
  ckpt_path = "PASD/runs/pasd/checkpoint-100000"
 
92
  dreambooth_lora_path = "PASD/checkpoints/personalized_models/majicmixRealistic_v6.safetensors"
93
+
94
  weight_dtype = torch.float16
95
  device = "cuda"
96
 
97
+ # ---- Load models ----
98
  scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
99
  text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
100
  tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
 
102
  feature_extractor = CLIPImageProcessor.from_pretrained(pretrained_model_path, subfolder="feature_extractor")
103
  unet = UNet2DConditionModel.from_pretrained(ckpt_path, subfolder="unet")
104
  controlnet = ControlNetModel.from_pretrained(ckpt_path, subfolder="controlnet")
105
+
106
  vae.requires_grad_(False)
107
  text_encoder.requires_grad_(False)
108
  unet.requires_grad_(False)
 
116
  controlnet.to(device, dtype=weight_dtype)
117
 
118
  validation_pipeline = StableDiffusionControlNetPipeline(
119
+ vae=vae,
120
+ text_encoder=text_encoder,
121
+ tokenizer=tokenizer,
122
+ feature_extractor=feature_extractor,
123
+ unet=unet,
124
+ controlnet=controlnet,
125
+ scheduler=scheduler,
126
+ safety_checker=None,
127
+ requires_safety_checker=False,
128
+ )
129
+
130
+ # validation_pipeline.enable_vae_tiling()
131
  validation_pipeline._init_tiled_vae(decoder_tile_size=224)
132
 
133
+ # ---- ResNet auto-tag helper ----
134
  weights = ResNet50_Weights.DEFAULT
135
  preprocess = weights.transforms()
136
  resnet = resnet50(weights=weights)
137
  resnet.eval()
138
 
139
+
140
  def resize_image(image_path, target_height):
 
141
  with Image.open(image_path) as img:
 
142
  ratio = target_height / float(img.size[1])
 
143
  new_width = int(float(img.size[0]) * ratio)
 
144
  resized_img = img.resize((new_width, target_height), Image.LANCZOS)
 
 
145
  return resized_img
146
 
147
+
148
  @spaces.GPU(enable_queue=True)
149
  def inference(input_image, prompt, a_prompt, n_prompt, denoise_steps, upscale, alpha, cfg, seed):
 
 
150
  if seed == -1:
151
  seed = 0
152
+
153
  input_image = resize_image(input_image, 512)
154
  process_size = 768
155
  resize_preproc = transforms.Compose([
156
  transforms.Resize(process_size, interpolation=transforms.InterpolationMode.BILINEAR),
157
  ])
158
+
 
159
  timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
160
 
161
  with torch.no_grad():
162
  seed_everything(seed)
163
  generator = torch.Generator(device=device)
164
+ generator.manual_seed(seed)
165
 
166
+ input_image = input_image.convert("RGB")
167
  batch = preprocess(input_image).unsqueeze(0)
168
  prediction = resnet(batch).squeeze(0).softmax(0)
169
  class_id = prediction.argmax().item()
170
  score = prediction[class_id].item()
171
  category_name = weights.meta["categories"][class_id]
172
+
173
  if score >= 0.1:
174
+ prompt += f"{category_name}" if prompt == "" else f", {category_name}"
175
 
176
+ prompt = a_prompt if prompt == "" else f"{prompt}, {a_prompt}"
177
 
178
  ori_width, ori_height = input_image.size
179
  resize_flag = False
180
 
181
  rscale = upscale
182
+ input_image = input_image.resize((input_image.size[0] * rscale, input_image.size[1] * rscale))
 
 
 
183
 
184
+ input_image = input_image.resize((input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8))
185
  width, height = input_image.size
186
+ resize_flag = True
187
 
188
  try:
189
  image = validation_pipeline(
190
+ None,
191
+ prompt,
192
+ input_image,
193
+ num_inference_steps=denoise_steps,
194
+ generator=generator,
195
+ height=height,
196
+ width=width,
197
+ guidance_scale=cfg,
198
+ negative_prompt=n_prompt,
199
+ conditioning_scale=alpha,
200
+ eta=0.0,
201
+ ).images[0]
202
+
203
+ image = wavelet_color_fix(image, input_image)
204
+
205
+ if resize_flag:
206
+ image = image.resize((ori_width * rscale, ori_height * rscale))
207
  except Exception as e:
208
  print(e)
209
  image = Image.new(mode="RGB", size=(512, 512))
 
 
 
210
 
211
+ result_path = f"result_{timestamp}.jpg"
212
+ input_path = f"input_{timestamp}.jpg"
213
+
214
+ image.save(result_path, "JPEG")
215
+ input_image.save(input_path, "JPEG")
216
+
217
+ return input_path, result_path, result_path
218
 
 
 
 
 
219
 
220
  css = """
221
  #col-container{
 
243
  Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
244
  </p>
245
  <p id="project-links" align="center">
246
+ <a href="https://github.com/yangxy/PASD"><img src="https://img.shields.io/badge/Project-Page-Green"></a>
247
+ <a href="https://huggingface.co/papers/2308.14469"><img src="https://img.shields.io/badge/Paper-Arxiv-red"></a>
248
+ </p>
249
  <p style="margin:12px auto;display: flex;justify-content: center;">
250
  <a href="https://huggingface.co/spaces/fffiloni/PASD?duplicate=true">
251
  <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space">
 
255
 
256
  with gr.Row():
257
  with gr.Column():
258
+ input_image = gr.Image(
259
+ type="filepath",
260
+ sources=["upload"],
261
+ value="PASD/samples/frog.png",
262
+ label="Input image",
263
+ )
264
  prompt_in = gr.Textbox(label="Prompt", value="Frog")
265
+
266
  with gr.Accordion(label="Advanced settings", open=False):
267
+ added_prompt = gr.Textbox(
268
+ label="Added Prompt",
269
+ value="clean, high-resolution, 8k, best quality, masterpiece",
270
+ )
271
+ neg_prompt = gr.Textbox(
272
+ label="Negative Prompt",
273
+ value="dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
274
+ )
275
  denoise_steps = gr.Slider(label="Denoise Steps", minimum=10, maximum=50, value=20, step=1)
276
  upsample_scale = gr.Slider(label="Upsample Scale", minimum=1, maximum=4, value=2, step=1)
277
  condition_scale = gr.Slider(label="Conditioning Scale", minimum=0.5, maximum=1.5, value=1.1, step=0.1)
278
+ classifier_free_guidance = gr.Slider(
279
+ label="Classifier-free Guidance",
280
+ minimum=0.1,
281
+ maximum=10.0,
282
+ value=7.5,
283
+ step=0.1,
284
+ )
285
+ seed = gr.Slider(
286
+ label="Seed",
287
+ minimum=-1,
288
+ maximum=2147483647,
289
+ step=1,
290
+ randomize=True,
291
+ )
292
+
293
  submit_btn = gr.Button("Submit")
294
 
295
  with gr.Column():
 
300
  submit_btn.click(
301
  fn=inference,
302
  inputs=[
303
+ input_image,
304
+ prompt_in,
305
+ added_prompt,
306
+ neg_prompt,
307
  denoise_steps,
308
+ upsample_scale,
309
+ condition_scale,
310
+ classifier_free_guidance,
311
+ seed,
312
  ],
313
  outputs=[
314
  before_img,
315
  after_img,
316
+ file_output,
317
  ],
318
+ api_visibility="private",
319
  )
320
 
321
+ demo.queue(max_size=10).launch(
322
+ ssr_mode=False,
323
+ mcp_server=False,
324
+ css=css,
325
+ )