Spaces:
Running on Zero
Running on Zero
Update app_zero.py
Browse files- app_zero.py +83 -63
app_zero.py
CHANGED
|
@@ -12,6 +12,7 @@ import huggingface_hub
|
|
| 12 |
if not hasattr(huggingface_hub, "cached_download"):
|
| 13 |
def cached_download(*args, **kwargs):
|
| 14 |
return huggingface_hub.hf_hub_download(*args, **kwargs)
|
|
|
|
| 15 |
huggingface_hub.cached_download = cached_download
|
| 16 |
|
| 17 |
import torch
|
|
@@ -19,12 +20,10 @@ import numpy as np
|
|
| 19 |
import einops
|
| 20 |
import spaces
|
| 21 |
import gradio as gr
|
| 22 |
-
|
| 23 |
from PIL import Image
|
| 24 |
from torchvision import transforms
|
| 25 |
import torch.nn.functional as F
|
| 26 |
from torchvision.models import resnet50, ResNet50_Weights
|
| 27 |
-
|
| 28 |
from pytorch_lightning import seed_everything
|
| 29 |
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
|
| 30 |
from diffusers import (
|
|
@@ -78,7 +77,6 @@ huggingface_hub.hf_hub_download(
|
|
| 78 |
# -------------------------------------------------------------------
|
| 79 |
sys.path.append("./PASD")
|
| 80 |
|
| 81 |
-
|
| 82 |
# -------------------------------------------------------------------
|
| 83 |
# Runtime patching helpers
|
| 84 |
# -------------------------------------------------------------------
|
|
@@ -130,7 +128,6 @@ except Exception:
|
|
| 130 |
pass
|
| 131 |
|
| 132 |
"""
|
| 133 |
-
|
| 134 |
original = text
|
| 135 |
|
| 136 |
# Enlève d'anciens imports simples
|
|
@@ -147,7 +144,7 @@ except Exception:
|
|
| 147 |
|
| 148 |
# Enlève d'anciens blocs try/except cassés liés à ce mixin
|
| 149 |
text = re.sub(
|
| 150 |
-
r"(?ms)^try:\n(?:(?:
|
| 151 |
lambda m: "" if "FromOriginalControl" in m.group(0) else m.group(0),
|
| 152 |
text,
|
| 153 |
)
|
|
@@ -189,10 +186,10 @@ def patch_pasd_for_diffusers() -> None:
|
|
| 189 |
patch_file(
|
| 190 |
"./PASD/models/pasd/unet_2d_condition.py",
|
| 191 |
[
|
| 192 |
-
("
|
| 193 |
(
|
| 194 |
-
"
|
| 195 |
-
"
|
| 196 |
),
|
| 197 |
],
|
| 198 |
)
|
|
@@ -251,25 +248,33 @@ weight_dtype = torch.float16
|
|
| 251 |
device = "cuda"
|
| 252 |
|
| 253 |
scheduler = UniPCMultistepScheduler.from_pretrained(
|
| 254 |
-
pretrained_model_path,
|
|
|
|
| 255 |
)
|
| 256 |
text_encoder = CLIPTextModel.from_pretrained(
|
| 257 |
-
pretrained_model_path,
|
|
|
|
| 258 |
)
|
| 259 |
tokenizer = CLIPTokenizer.from_pretrained(
|
| 260 |
-
pretrained_model_path,
|
|
|
|
| 261 |
)
|
| 262 |
vae = AutoencoderKL.from_pretrained(
|
| 263 |
-
pretrained_model_path,
|
|
|
|
| 264 |
)
|
| 265 |
feature_extractor = CLIPImageProcessor.from_pretrained(
|
| 266 |
-
pretrained_model_path,
|
|
|
|
| 267 |
)
|
|
|
|
| 268 |
unet = UNet2DConditionModel.from_pretrained(
|
| 269 |
-
ckpt_path,
|
|
|
|
| 270 |
)
|
| 271 |
controlnet = ControlNetModel.from_pretrained(
|
| 272 |
-
ckpt_path,
|
|
|
|
| 273 |
)
|
| 274 |
|
| 275 |
vae.requires_grad_(False)
|
|
@@ -278,7 +283,10 @@ unet.requires_grad_(False)
|
|
| 278 |
controlnet.requires_grad_(False)
|
| 279 |
|
| 280 |
unet, vae, text_encoder = load_dreambooth_lora(
|
| 281 |
-
unet,
|
|
|
|
|
|
|
|
|
|
| 282 |
)
|
| 283 |
|
| 284 |
text_encoder.to(device, dtype=weight_dtype)
|
|
@@ -317,18 +325,37 @@ def resize_image(image_path: str, target_height: int) -> Image.Image:
|
|
| 317 |
|
| 318 |
|
| 319 |
@spaces.GPU(enable_queue=True)
|
| 320 |
-
def
|
| 321 |
input_image,
|
| 322 |
prompt,
|
| 323 |
-
|
| 324 |
-
|
| 325 |
denoise_steps,
|
| 326 |
upscale,
|
| 327 |
alpha,
|
| 328 |
-
|
| 329 |
seed,
|
| 330 |
-
progress=gr.Progress(track_tqdm=True)
|
| 331 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
if seed == -1:
|
| 333 |
seed = 0
|
| 334 |
|
|
@@ -351,17 +378,18 @@ def inference(
|
|
| 351 |
if score >= 0.1:
|
| 352 |
prompt += f"{category_name}" if prompt == "" else f", {category_name}"
|
| 353 |
|
| 354 |
-
prompt =
|
| 355 |
|
| 356 |
ori_width, ori_height = input_image.size
|
| 357 |
-
|
| 358 |
rscale = upscale
|
|
|
|
| 359 |
input_image = input_image.resize(
|
| 360 |
(input_image.size[0] * rscale, input_image.size[1] * rscale)
|
| 361 |
)
|
| 362 |
input_image = input_image.resize(
|
| 363 |
(input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8)
|
| 364 |
)
|
|
|
|
| 365 |
width, height = input_image.size
|
| 366 |
|
| 367 |
try:
|
|
@@ -373,14 +401,15 @@ def inference(
|
|
| 373 |
generator=generator,
|
| 374 |
height=height,
|
| 375 |
width=width,
|
| 376 |
-
guidance_scale=
|
| 377 |
-
negative_prompt=
|
| 378 |
conditioning_scale=alpha,
|
| 379 |
eta=0.0,
|
| 380 |
).images[0]
|
| 381 |
|
| 382 |
image = wavelet_color_fix(image, input_image)
|
| 383 |
image = image.resize((ori_width * rscale, ori_height * rscale))
|
|
|
|
| 384 |
except Exception as e:
|
| 385 |
print(f"[inference] error: {e}")
|
| 386 |
image = Image.new(mode="RGB", size=(512, 512))
|
|
@@ -412,23 +441,14 @@ css = """
|
|
| 412 |
|
| 413 |
with gr.Blocks() as demo:
|
| 414 |
with gr.Column(elem_id="col-container"):
|
| 415 |
-
gr.
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
<a href="https://github.com/yangxy/PASD"><img src="https://img.shields.io/badge/Project-Page-Green"></a>
|
| 424 |
-
<a href="https://huggingface.co/papers/2308.14469"><img src="https://img.shields.io/badge/Paper-Arxiv-red"></a>
|
| 425 |
-
</p>
|
| 426 |
-
<p style="margin:12px auto;display: flex;justify-content: center;">
|
| 427 |
-
<a href="https://huggingface.co/spaces/fffiloni/PASD?duplicate=true">
|
| 428 |
-
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space">
|
| 429 |
-
</a>
|
| 430 |
-
</p>
|
| 431 |
-
""")
|
| 432 |
|
| 433 |
with gr.Row():
|
| 434 |
with gr.Column():
|
|
@@ -492,29 +512,29 @@ with gr.Blocks() as demo:
|
|
| 492 |
after_img = gr.Image(label="Result")
|
| 493 |
file_output = gr.File(label="Downloadable image result")
|
| 494 |
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
|
| 516 |
demo.queue(max_size=10).launch(
|
| 517 |
ssr_mode=False,
|
| 518 |
-
mcp_server=
|
| 519 |
css=css,
|
| 520 |
)
|
|
|
|
| 12 |
if not hasattr(huggingface_hub, "cached_download"):
|
| 13 |
def cached_download(*args, **kwargs):
|
| 14 |
return huggingface_hub.hf_hub_download(*args, **kwargs)
|
| 15 |
+
|
| 16 |
huggingface_hub.cached_download = cached_download
|
| 17 |
|
| 18 |
import torch
|
|
|
|
| 20 |
import einops
|
| 21 |
import spaces
|
| 22 |
import gradio as gr
|
|
|
|
| 23 |
from PIL import Image
|
| 24 |
from torchvision import transforms
|
| 25 |
import torch.nn.functional as F
|
| 26 |
from torchvision.models import resnet50, ResNet50_Weights
|
|
|
|
| 27 |
from pytorch_lightning import seed_everything
|
| 28 |
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
|
| 29 |
from diffusers import (
|
|
|
|
| 77 |
# -------------------------------------------------------------------
|
| 78 |
sys.path.append("./PASD")
|
| 79 |
|
|
|
|
| 80 |
# -------------------------------------------------------------------
|
| 81 |
# Runtime patching helpers
|
| 82 |
# -------------------------------------------------------------------
|
|
|
|
| 128 |
pass
|
| 129 |
|
| 130 |
"""
|
|
|
|
| 131 |
original = text
|
| 132 |
|
| 133 |
# Enlève d'anciens imports simples
|
|
|
|
| 144 |
|
| 145 |
# Enlève d'anciens blocs try/except cassés liés à ce mixin
|
| 146 |
text = re.sub(
|
| 147 |
+
r"(?ms)^try:\n(?:(?: |\t).*\n)+?except Exception:\n(?:(?: |\t).*\n)+?(?=^(?:class|def|@|from |import |\Z))",
|
| 148 |
lambda m: "" if "FromOriginalControl" in m.group(0) else m.group(0),
|
| 149 |
text,
|
| 150 |
)
|
|
|
|
| 186 |
patch_file(
|
| 187 |
"./PASD/models/pasd/unet_2d_condition.py",
|
| 188 |
[
|
| 189 |
+
(" PositionNet,\n", ""),
|
| 190 |
(
|
| 191 |
+
" GLIGENTextBoundingboxProjection,\n",
|
| 192 |
+
" GLIGENTextBoundingboxProjection as PositionNet,\n",
|
| 193 |
),
|
| 194 |
],
|
| 195 |
)
|
|
|
|
| 248 |
device = "cuda"
|
| 249 |
|
| 250 |
scheduler = UniPCMultistepScheduler.from_pretrained(
|
| 251 |
+
pretrained_model_path,
|
| 252 |
+
subfolder="scheduler",
|
| 253 |
)
|
| 254 |
text_encoder = CLIPTextModel.from_pretrained(
|
| 255 |
+
pretrained_model_path,
|
| 256 |
+
subfolder="text_encoder",
|
| 257 |
)
|
| 258 |
tokenizer = CLIPTokenizer.from_pretrained(
|
| 259 |
+
pretrained_model_path,
|
| 260 |
+
subfolder="tokenizer",
|
| 261 |
)
|
| 262 |
vae = AutoencoderKL.from_pretrained(
|
| 263 |
+
pretrained_model_path,
|
| 264 |
+
subfolder="vae",
|
| 265 |
)
|
| 266 |
feature_extractor = CLIPImageProcessor.from_pretrained(
|
| 267 |
+
pretrained_model_path,
|
| 268 |
+
subfolder="feature_extractor",
|
| 269 |
)
|
| 270 |
+
|
| 271 |
unet = UNet2DConditionModel.from_pretrained(
|
| 272 |
+
ckpt_path,
|
| 273 |
+
subfolder="unet",
|
| 274 |
)
|
| 275 |
controlnet = ControlNetModel.from_pretrained(
|
| 276 |
+
ckpt_path,
|
| 277 |
+
subfolder="controlnet",
|
| 278 |
)
|
| 279 |
|
| 280 |
vae.requires_grad_(False)
|
|
|
|
| 283 |
controlnet.requires_grad_(False)
|
| 284 |
|
| 285 |
unet, vae, text_encoder = load_dreambooth_lora(
|
| 286 |
+
unet,
|
| 287 |
+
vae,
|
| 288 |
+
text_encoder,
|
| 289 |
+
dreambooth_lora_path,
|
| 290 |
)
|
| 291 |
|
| 292 |
text_encoder.to(device, dtype=weight_dtype)
|
|
|
|
| 325 |
|
| 326 |
|
| 327 |
@spaces.GPU(enable_queue=True)
|
| 328 |
+
def super_resolve_image(
|
| 329 |
input_image,
|
| 330 |
prompt,
|
| 331 |
+
added_prompt,
|
| 332 |
+
negative_prompt,
|
| 333 |
denoise_steps,
|
| 334 |
upscale,
|
| 335 |
alpha,
|
| 336 |
+
guidance_scale,
|
| 337 |
seed,
|
| 338 |
+
progress=gr.Progress(track_tqdm=True),
|
| 339 |
):
|
| 340 |
+
"""
|
| 341 |
+
Super-resolve an input image with PASD and optional prompt guidance.
|
| 342 |
+
|
| 343 |
+
Use this tool when you need to generate a higher-resolution restored image from an input image.
|
| 344 |
+
|
| 345 |
+
Args:
|
| 346 |
+
input_image (str): File path to the input image.
|
| 347 |
+
prompt (str): Main text prompt describing the desired image content.
|
| 348 |
+
added_prompt (str): Additional quality or style prompt appended to the main prompt.
|
| 349 |
+
negative_prompt (str): Negative prompt describing unwanted visual qualities.
|
| 350 |
+
denoise_steps (int): Number of denoising steps used by the diffusion pipeline.
|
| 351 |
+
upscale (int): Integer upscale factor applied to the image.
|
| 352 |
+
alpha (float): Conditioning scale passed to the ControlNet pipeline.
|
| 353 |
+
guidance_scale (float): Classifier-free guidance scale passed to the diffusion pipeline.
|
| 354 |
+
seed (int): Random seed, where -1 is converted to 0.
|
| 355 |
+
|
| 356 |
+
Returns:
|
| 357 |
+
tuple: Input image path, result image path, and downloadable result image path.
|
| 358 |
+
"""
|
| 359 |
if seed == -1:
|
| 360 |
seed = 0
|
| 361 |
|
|
|
|
| 378 |
if score >= 0.1:
|
| 379 |
prompt += f"{category_name}" if prompt == "" else f", {category_name}"
|
| 380 |
|
| 381 |
+
prompt = added_prompt if prompt == "" else f"{prompt}, {added_prompt}"
|
| 382 |
|
| 383 |
ori_width, ori_height = input_image.size
|
|
|
|
| 384 |
rscale = upscale
|
| 385 |
+
|
| 386 |
input_image = input_image.resize(
|
| 387 |
(input_image.size[0] * rscale, input_image.size[1] * rscale)
|
| 388 |
)
|
| 389 |
input_image = input_image.resize(
|
| 390 |
(input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8)
|
| 391 |
)
|
| 392 |
+
|
| 393 |
width, height = input_image.size
|
| 394 |
|
| 395 |
try:
|
|
|
|
| 401 |
generator=generator,
|
| 402 |
height=height,
|
| 403 |
width=width,
|
| 404 |
+
guidance_scale=guidance_scale,
|
| 405 |
+
negative_prompt=negative_prompt,
|
| 406 |
conditioning_scale=alpha,
|
| 407 |
eta=0.0,
|
| 408 |
).images[0]
|
| 409 |
|
| 410 |
image = wavelet_color_fix(image, input_image)
|
| 411 |
image = image.resize((ori_width * rscale, ori_height * rscale))
|
| 412 |
+
|
| 413 |
except Exception as e:
|
| 414 |
print(f"[inference] error: {e}")
|
| 415 |
image = Image.new(mode="RGB", size=(512, 512))
|
|
|
|
| 441 |
|
| 442 |
with gr.Blocks() as demo:
|
| 443 |
with gr.Column(elem_id="col-container"):
|
| 444 |
+
gr.Markdown("""
|
| 445 |
+
## PASD Magnify
|
| 446 |
+
|
| 447 |
+
Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
|
| 448 |
+
|
| 449 |
+
<a href='https://arxiv.org/abs/2308.14469' target='_blank'><img src='https://img.shields.io/badge/arXiv-2308.14469-red'></a> <a href='https://github.com/yangxy/PASD' target='_blank'><img src='https://img.shields.io/badge/GitHub-Code-blue'></a>
|
| 450 |
+
|
| 451 |
+
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
|
| 453 |
with gr.Row():
|
| 454 |
with gr.Column():
|
|
|
|
| 512 |
after_img = gr.Image(label="Result")
|
| 513 |
file_output = gr.File(label="Downloadable image result")
|
| 514 |
|
| 515 |
+
submit_btn.click(
|
| 516 |
+
fn=super_resolve_image,
|
| 517 |
+
inputs=[
|
| 518 |
+
input_image,
|
| 519 |
+
prompt_in,
|
| 520 |
+
added_prompt,
|
| 521 |
+
neg_prompt,
|
| 522 |
+
denoise_steps,
|
| 523 |
+
upsample_scale,
|
| 524 |
+
condition_scale,
|
| 525 |
+
classifier_free_guidance,
|
| 526 |
+
seed,
|
| 527 |
+
],
|
| 528 |
+
outputs=[
|
| 529 |
+
before_img,
|
| 530 |
+
after_img,
|
| 531 |
+
file_output,
|
| 532 |
+
],
|
| 533 |
+
api_visibility="public",
|
| 534 |
+
)
|
| 535 |
|
| 536 |
demo.queue(max_size=10).launch(
|
| 537 |
ssr_mode=False,
|
| 538 |
+
mcp_server=True,
|
| 539 |
css=css,
|
| 540 |
)
|