Spaces:
Runtime error
Runtime error
Delete app
Browse files- app/app_sana.py +0 -502
- app/app_sana_4bit.py +0 -409
- app/app_sana_4bit_compare_bf16.py +0 -313
- app/app_sana_controlnet_hed.py +0 -306
- app/app_sana_multithread.py +0 -565
- app/safety_check.py +0 -72
- app/sana_controlnet_pipeline.py +0 -353
- app/sana_pipeline.py +0 -304
app/app_sana.py
DELETED
|
@@ -1,502 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
#
|
| 16 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 17 |
-
from __future__ import annotations
|
| 18 |
-
|
| 19 |
-
import argparse
|
| 20 |
-
import os
|
| 21 |
-
import random
|
| 22 |
-
import socket
|
| 23 |
-
import sqlite3
|
| 24 |
-
import time
|
| 25 |
-
import uuid
|
| 26 |
-
from datetime import datetime
|
| 27 |
-
|
| 28 |
-
import gradio as gr
|
| 29 |
-
import numpy as np
|
| 30 |
-
import spaces
|
| 31 |
-
import torch
|
| 32 |
-
from PIL import Image
|
| 33 |
-
from torchvision.utils import make_grid, save_image
|
| 34 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 35 |
-
|
| 36 |
-
from app import safety_check
|
| 37 |
-
from app.sana_pipeline import SanaPipeline
|
| 38 |
-
|
| 39 |
-
MAX_SEED = np.iinfo(np.int32).max
|
| 40 |
-
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
| 41 |
-
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
|
| 42 |
-
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
| 43 |
-
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
| 44 |
-
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
|
| 45 |
-
os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
|
| 46 |
-
COUNTER_DB = os.getenv("COUNTER_DB", ".count.db")
|
| 47 |
-
|
| 48 |
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 49 |
-
|
| 50 |
-
style_list = [
|
| 51 |
-
{
|
| 52 |
-
"name": "(No style)",
|
| 53 |
-
"prompt": "{prompt}",
|
| 54 |
-
"negative_prompt": "",
|
| 55 |
-
},
|
| 56 |
-
{
|
| 57 |
-
"name": "Cinematic",
|
| 58 |
-
"prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
|
| 59 |
-
"cinemascope, moody, epic, gorgeous, film grain, grainy",
|
| 60 |
-
"negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
| 61 |
-
},
|
| 62 |
-
{
|
| 63 |
-
"name": "Photographic",
|
| 64 |
-
"prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
| 65 |
-
"negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
| 66 |
-
},
|
| 67 |
-
{
|
| 68 |
-
"name": "Anime",
|
| 69 |
-
"prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
|
| 70 |
-
"negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
|
| 71 |
-
},
|
| 72 |
-
{
|
| 73 |
-
"name": "Manga",
|
| 74 |
-
"prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
|
| 75 |
-
"negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
| 76 |
-
},
|
| 77 |
-
{
|
| 78 |
-
"name": "Digital Art",
|
| 79 |
-
"prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
|
| 80 |
-
"negative_prompt": "photo, photorealistic, realism, ugly",
|
| 81 |
-
},
|
| 82 |
-
{
|
| 83 |
-
"name": "Pixel art",
|
| 84 |
-
"prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
|
| 85 |
-
"negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
| 86 |
-
},
|
| 87 |
-
{
|
| 88 |
-
"name": "Fantasy art",
|
| 89 |
-
"prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
|
| 90 |
-
"majestic, magical, fantasy art, cover art, dreamy",
|
| 91 |
-
"negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
|
| 92 |
-
"glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
|
| 93 |
-
"disfigured, sloppy, duplicate, mutated, black and white",
|
| 94 |
-
},
|
| 95 |
-
{
|
| 96 |
-
"name": "Neonpunk",
|
| 97 |
-
"prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
|
| 98 |
-
"detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
|
| 99 |
-
"ultra detailed, intricate, professional",
|
| 100 |
-
"negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
| 101 |
-
},
|
| 102 |
-
{
|
| 103 |
-
"name": "3D Model",
|
| 104 |
-
"prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
|
| 105 |
-
"negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
|
| 106 |
-
},
|
| 107 |
-
]
|
| 108 |
-
|
| 109 |
-
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
|
| 110 |
-
STYLE_NAMES = list(styles.keys())
|
| 111 |
-
DEFAULT_STYLE_NAME = "(No style)"
|
| 112 |
-
SCHEDULE_NAME = ["Flow_DPM_Solver"]
|
| 113 |
-
DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
|
| 114 |
-
NUM_IMAGES_PER_PROMPT = 1
|
| 115 |
-
INFER_SPEED = 0
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
def norm_ip(img, low, high):
|
| 119 |
-
img.clamp_(min=low, max=high)
|
| 120 |
-
img.sub_(low).div_(max(high - low, 1e-5))
|
| 121 |
-
return img
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
def open_db():
|
| 125 |
-
db = sqlite3.connect(COUNTER_DB)
|
| 126 |
-
db.execute("CREATE TABLE IF NOT EXISTS counter(app CHARS PRIMARY KEY UNIQUE, value INTEGER)")
|
| 127 |
-
db.execute('INSERT OR IGNORE INTO counter(app, value) VALUES("Sana", 0)')
|
| 128 |
-
return db
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
def read_inference_count():
|
| 132 |
-
with open_db() as db:
|
| 133 |
-
cur = db.execute('SELECT value FROM counter WHERE app="Sana"')
|
| 134 |
-
db.commit()
|
| 135 |
-
return cur.fetchone()[0]
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
def write_inference_count(count):
|
| 139 |
-
count = max(0, int(count))
|
| 140 |
-
with open_db() as db:
|
| 141 |
-
db.execute(f'UPDATE counter SET value=value+{count} WHERE app="Sana"')
|
| 142 |
-
db.commit()
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
def run_inference(num_imgs=1):
|
| 146 |
-
write_inference_count(num_imgs)
|
| 147 |
-
count = read_inference_count()
|
| 148 |
-
|
| 149 |
-
return (
|
| 150 |
-
f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
|
| 151 |
-
f"16px; color:red; font-weight: bold;'>{count}</span>"
|
| 152 |
-
)
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
def update_inference_count():
|
| 156 |
-
count = read_inference_count()
|
| 157 |
-
return (
|
| 158 |
-
f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
|
| 159 |
-
f"16px; color:red; font-weight: bold;'>{count}</span>"
|
| 160 |
-
)
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
|
| 164 |
-
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
| 165 |
-
if not negative:
|
| 166 |
-
negative = ""
|
| 167 |
-
return p.replace("{prompt}", positive), n + negative
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
def get_args():
|
| 171 |
-
parser = argparse.ArgumentParser()
|
| 172 |
-
parser.add_argument("--config", type=str, help="config")
|
| 173 |
-
parser.add_argument(
|
| 174 |
-
"--model_path",
|
| 175 |
-
nargs="?",
|
| 176 |
-
default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
|
| 177 |
-
type=str,
|
| 178 |
-
help="Path to the model file (positional)",
|
| 179 |
-
)
|
| 180 |
-
parser.add_argument("--output", default="./", type=str)
|
| 181 |
-
parser.add_argument("--bs", default=1, type=int)
|
| 182 |
-
parser.add_argument("--image_size", default=1024, type=int)
|
| 183 |
-
parser.add_argument("--cfg_scale", default=5.0, type=float)
|
| 184 |
-
parser.add_argument("--pag_scale", default=2.0, type=float)
|
| 185 |
-
parser.add_argument("--seed", default=42, type=int)
|
| 186 |
-
parser.add_argument("--step", default=-1, type=int)
|
| 187 |
-
parser.add_argument("--custom_image_size", default=None, type=int)
|
| 188 |
-
parser.add_argument("--share", action="store_true")
|
| 189 |
-
parser.add_argument(
|
| 190 |
-
"--shield_model_path",
|
| 191 |
-
type=str,
|
| 192 |
-
help="The path to shield model, we employ ShieldGemma-2B by default.",
|
| 193 |
-
default="google/shieldgemma-2b",
|
| 194 |
-
)
|
| 195 |
-
|
| 196 |
-
return parser.parse_known_args()[0]
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
args = get_args()
|
| 200 |
-
|
| 201 |
-
if torch.cuda.is_available():
|
| 202 |
-
model_path = args.model_path
|
| 203 |
-
pipe = SanaPipeline(args.config)
|
| 204 |
-
pipe.from_pretrained(model_path)
|
| 205 |
-
pipe.register_progress_bar(gr.Progress())
|
| 206 |
-
|
| 207 |
-
# safety checker
|
| 208 |
-
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
|
| 209 |
-
safety_checker_model = AutoModelForCausalLM.from_pretrained(
|
| 210 |
-
args.shield_model_path,
|
| 211 |
-
device_map="auto",
|
| 212 |
-
torch_dtype=torch.bfloat16,
|
| 213 |
-
).to(device)
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
def save_image_sana(img, seed="", save_img=False):
|
| 217 |
-
unique_name = f"{str(uuid.uuid4())}_{seed}.png"
|
| 218 |
-
save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
|
| 219 |
-
os.umask(0o000) # file permission: 666; dir permission: 777
|
| 220 |
-
os.makedirs(save_path, exist_ok=True)
|
| 221 |
-
unique_name = os.path.join(save_path, unique_name)
|
| 222 |
-
if save_img:
|
| 223 |
-
save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
|
| 224 |
-
|
| 225 |
-
return unique_name
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
| 229 |
-
if randomize_seed:
|
| 230 |
-
seed = random.randint(0, MAX_SEED)
|
| 231 |
-
return seed
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
@torch.no_grad()
|
| 235 |
-
@torch.inference_mode()
|
| 236 |
-
@spaces.GPU(enable_queue=True)
|
| 237 |
-
def generate(
|
| 238 |
-
prompt: str = None,
|
| 239 |
-
negative_prompt: str = "",
|
| 240 |
-
style: str = DEFAULT_STYLE_NAME,
|
| 241 |
-
use_negative_prompt: bool = False,
|
| 242 |
-
num_imgs: int = 1,
|
| 243 |
-
seed: int = 0,
|
| 244 |
-
height: int = 1024,
|
| 245 |
-
width: int = 1024,
|
| 246 |
-
flow_dpms_guidance_scale: float = 5.0,
|
| 247 |
-
flow_dpms_pag_guidance_scale: float = 2.0,
|
| 248 |
-
flow_dpms_inference_steps: int = 20,
|
| 249 |
-
randomize_seed: bool = False,
|
| 250 |
-
):
|
| 251 |
-
global INFER_SPEED
|
| 252 |
-
# seed = 823753551
|
| 253 |
-
box = run_inference(num_imgs)
|
| 254 |
-
seed = int(randomize_seed_fn(seed, randomize_seed))
|
| 255 |
-
generator = torch.Generator(device=device).manual_seed(seed)
|
| 256 |
-
print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
|
| 257 |
-
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
|
| 258 |
-
prompt = "A red heart."
|
| 259 |
-
|
| 260 |
-
print(prompt)
|
| 261 |
-
|
| 262 |
-
num_inference_steps = flow_dpms_inference_steps
|
| 263 |
-
guidance_scale = flow_dpms_guidance_scale
|
| 264 |
-
pag_guidance_scale = flow_dpms_pag_guidance_scale
|
| 265 |
-
|
| 266 |
-
if not use_negative_prompt:
|
| 267 |
-
negative_prompt = None # type: ignore
|
| 268 |
-
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
| 269 |
-
|
| 270 |
-
pipe.progress_fn(0, desc="Sana Start")
|
| 271 |
-
|
| 272 |
-
time_start = time.time()
|
| 273 |
-
images = pipe(
|
| 274 |
-
prompt=prompt,
|
| 275 |
-
height=height,
|
| 276 |
-
width=width,
|
| 277 |
-
negative_prompt=negative_prompt,
|
| 278 |
-
guidance_scale=guidance_scale,
|
| 279 |
-
pag_guidance_scale=pag_guidance_scale,
|
| 280 |
-
num_inference_steps=num_inference_steps,
|
| 281 |
-
num_images_per_prompt=num_imgs,
|
| 282 |
-
generator=generator,
|
| 283 |
-
)
|
| 284 |
-
|
| 285 |
-
pipe.progress_fn(1.0, desc="Sana End")
|
| 286 |
-
INFER_SPEED = (time.time() - time_start) / num_imgs
|
| 287 |
-
|
| 288 |
-
save_img = False
|
| 289 |
-
if save_img:
|
| 290 |
-
img = [save_image_sana(img, seed, save_img=save_image) for img in images]
|
| 291 |
-
print(img)
|
| 292 |
-
else:
|
| 293 |
-
img = [
|
| 294 |
-
Image.fromarray(
|
| 295 |
-
norm_ip(img, -1, 1)
|
| 296 |
-
.mul(255)
|
| 297 |
-
.add_(0.5)
|
| 298 |
-
.clamp_(0, 255)
|
| 299 |
-
.permute(1, 2, 0)
|
| 300 |
-
.to("cpu", torch.uint8)
|
| 301 |
-
.numpy()
|
| 302 |
-
.astype(np.uint8)
|
| 303 |
-
)
|
| 304 |
-
for img in images
|
| 305 |
-
]
|
| 306 |
-
|
| 307 |
-
torch.cuda.empty_cache()
|
| 308 |
-
|
| 309 |
-
return (
|
| 310 |
-
img,
|
| 311 |
-
seed,
|
| 312 |
-
f"<span style='font-size: 16px; font-weight: bold;'>Inference Speed: {INFER_SPEED:.3f} s/Img</span>",
|
| 313 |
-
box,
|
| 314 |
-
)
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
model_size = "1.6" if "1600M" in args.model_path else "0.6"
|
| 318 |
-
title = f"""
|
| 319 |
-
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
|
| 320 |
-
<img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
|
| 321 |
-
</div>
|
| 322 |
-
"""
|
| 323 |
-
DESCRIPTION = f"""
|
| 324 |
-
<p><span style="font-size: 36px; font-weight: bold;">Sana-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
|
| 325 |
-
<p style="font-size: 16px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
|
| 326 |
-
<p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p>
|
| 327 |
-
<p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, </p>running on node {socket.gethostname()}.
|
| 328 |
-
<p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
|
| 329 |
-
"""
|
| 330 |
-
if model_size == "0.6":
|
| 331 |
-
DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
|
| 332 |
-
if not torch.cuda.is_available():
|
| 333 |
-
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
| 334 |
-
|
| 335 |
-
examples = [
|
| 336 |
-
'a cyberpunk cat with a neon sign that says "Sana"',
|
| 337 |
-
"A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
|
| 338 |
-
"Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
|
| 339 |
-
"portrait photo of a girl, photograph, highly detailed face, depth of field",
|
| 340 |
-
'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
|
| 341 |
-
"🐶 Wearing 🕶 flying on the 🌈",
|
| 342 |
-
"👧 with 🌹 in the ❄️",
|
| 343 |
-
"an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
|
| 344 |
-
"professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
|
| 345 |
-
"Astronaut in a jungle, cold color palette, muted colors, detailed",
|
| 346 |
-
"a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
|
| 347 |
-
]
|
| 348 |
-
|
| 349 |
-
css = """
|
| 350 |
-
.gradio-container{max-width: 640px !important}
|
| 351 |
-
h1{text-align:center}
|
| 352 |
-
"""
|
| 353 |
-
with gr.Blocks(css=css, title="Sana") as demo:
|
| 354 |
-
gr.Markdown(title)
|
| 355 |
-
gr.HTML(DESCRIPTION)
|
| 356 |
-
gr.DuplicateButton(
|
| 357 |
-
value="Duplicate Space for private use",
|
| 358 |
-
elem_id="duplicate-button",
|
| 359 |
-
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
|
| 360 |
-
)
|
| 361 |
-
info_box = gr.Markdown(
|
| 362 |
-
value=f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: 16px; color:red; font-weight: bold;'>{read_inference_count()}</span>"
|
| 363 |
-
)
|
| 364 |
-
demo.load(fn=update_inference_count, outputs=info_box) # update the value when re-loading the page
|
| 365 |
-
# with gr.Row(equal_height=False):
|
| 366 |
-
with gr.Group():
|
| 367 |
-
with gr.Row():
|
| 368 |
-
prompt = gr.Text(
|
| 369 |
-
label="Prompt",
|
| 370 |
-
show_label=False,
|
| 371 |
-
max_lines=1,
|
| 372 |
-
placeholder="Enter your prompt",
|
| 373 |
-
container=False,
|
| 374 |
-
)
|
| 375 |
-
run_button = gr.Button("Run", scale=0)
|
| 376 |
-
result = gr.Gallery(label="Result", show_label=False, columns=NUM_IMAGES_PER_PROMPT, format="png")
|
| 377 |
-
speed_box = gr.Markdown(
|
| 378 |
-
value=f"<span style='font-size: 16px; font-weight: bold;'>Inference speed: {INFER_SPEED} s/Img</span>"
|
| 379 |
-
)
|
| 380 |
-
with gr.Accordion("Advanced options", open=False):
|
| 381 |
-
with gr.Group():
|
| 382 |
-
with gr.Row(visible=True):
|
| 383 |
-
height = gr.Slider(
|
| 384 |
-
label="Height",
|
| 385 |
-
minimum=256,
|
| 386 |
-
maximum=MAX_IMAGE_SIZE,
|
| 387 |
-
step=32,
|
| 388 |
-
value=args.image_size,
|
| 389 |
-
)
|
| 390 |
-
width = gr.Slider(
|
| 391 |
-
label="Width",
|
| 392 |
-
minimum=256,
|
| 393 |
-
maximum=MAX_IMAGE_SIZE,
|
| 394 |
-
step=32,
|
| 395 |
-
value=args.image_size,
|
| 396 |
-
)
|
| 397 |
-
with gr.Row():
|
| 398 |
-
flow_dpms_inference_steps = gr.Slider(
|
| 399 |
-
label="Sampling steps",
|
| 400 |
-
minimum=5,
|
| 401 |
-
maximum=40,
|
| 402 |
-
step=1,
|
| 403 |
-
value=20,
|
| 404 |
-
)
|
| 405 |
-
flow_dpms_guidance_scale = gr.Slider(
|
| 406 |
-
label="CFG Guidance scale",
|
| 407 |
-
minimum=1,
|
| 408 |
-
maximum=10,
|
| 409 |
-
step=0.1,
|
| 410 |
-
value=4.5,
|
| 411 |
-
)
|
| 412 |
-
flow_dpms_pag_guidance_scale = gr.Slider(
|
| 413 |
-
label="PAG Guidance scale",
|
| 414 |
-
minimum=1,
|
| 415 |
-
maximum=4,
|
| 416 |
-
step=0.5,
|
| 417 |
-
value=1.0,
|
| 418 |
-
)
|
| 419 |
-
with gr.Row():
|
| 420 |
-
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
|
| 421 |
-
negative_prompt = gr.Text(
|
| 422 |
-
label="Negative prompt",
|
| 423 |
-
max_lines=1,
|
| 424 |
-
placeholder="Enter a negative prompt",
|
| 425 |
-
visible=True,
|
| 426 |
-
)
|
| 427 |
-
style_selection = gr.Radio(
|
| 428 |
-
show_label=True,
|
| 429 |
-
container=True,
|
| 430 |
-
interactive=True,
|
| 431 |
-
choices=STYLE_NAMES,
|
| 432 |
-
value=DEFAULT_STYLE_NAME,
|
| 433 |
-
label="Image Style",
|
| 434 |
-
)
|
| 435 |
-
seed = gr.Slider(
|
| 436 |
-
label="Seed",
|
| 437 |
-
minimum=0,
|
| 438 |
-
maximum=MAX_SEED,
|
| 439 |
-
step=1,
|
| 440 |
-
value=0,
|
| 441 |
-
)
|
| 442 |
-
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
| 443 |
-
with gr.Row(visible=True):
|
| 444 |
-
schedule = gr.Radio(
|
| 445 |
-
show_label=True,
|
| 446 |
-
container=True,
|
| 447 |
-
interactive=True,
|
| 448 |
-
choices=SCHEDULE_NAME,
|
| 449 |
-
value=DEFAULT_SCHEDULE_NAME,
|
| 450 |
-
label="Sampler Schedule",
|
| 451 |
-
visible=True,
|
| 452 |
-
)
|
| 453 |
-
num_imgs = gr.Slider(
|
| 454 |
-
label="Num Images",
|
| 455 |
-
minimum=1,
|
| 456 |
-
maximum=6,
|
| 457 |
-
step=1,
|
| 458 |
-
value=1,
|
| 459 |
-
)
|
| 460 |
-
|
| 461 |
-
gr.Examples(
|
| 462 |
-
examples=examples,
|
| 463 |
-
inputs=prompt,
|
| 464 |
-
outputs=[result, seed],
|
| 465 |
-
fn=generate,
|
| 466 |
-
cache_examples=CACHE_EXAMPLES,
|
| 467 |
-
)
|
| 468 |
-
|
| 469 |
-
use_negative_prompt.change(
|
| 470 |
-
fn=lambda x: gr.update(visible=x),
|
| 471 |
-
inputs=use_negative_prompt,
|
| 472 |
-
outputs=negative_prompt,
|
| 473 |
-
api_name=False,
|
| 474 |
-
)
|
| 475 |
-
|
| 476 |
-
gr.on(
|
| 477 |
-
triggers=[
|
| 478 |
-
prompt.submit,
|
| 479 |
-
negative_prompt.submit,
|
| 480 |
-
run_button.click,
|
| 481 |
-
],
|
| 482 |
-
fn=generate,
|
| 483 |
-
inputs=[
|
| 484 |
-
prompt,
|
| 485 |
-
negative_prompt,
|
| 486 |
-
style_selection,
|
| 487 |
-
use_negative_prompt,
|
| 488 |
-
num_imgs,
|
| 489 |
-
seed,
|
| 490 |
-
height,
|
| 491 |
-
width,
|
| 492 |
-
flow_dpms_guidance_scale,
|
| 493 |
-
flow_dpms_pag_guidance_scale,
|
| 494 |
-
flow_dpms_inference_steps,
|
| 495 |
-
randomize_seed,
|
| 496 |
-
],
|
| 497 |
-
outputs=[result, seed, speed_box, info_box],
|
| 498 |
-
api_name="run",
|
| 499 |
-
)
|
| 500 |
-
|
| 501 |
-
if __name__ == "__main__":
|
| 502 |
-
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/app_sana_4bit.py
DELETED
|
@@ -1,409 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
#!/usr/bin/env python
|
| 6 |
-
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 7 |
-
#
|
| 8 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
-
# you may not use this file except in compliance with the License.
|
| 10 |
-
# You may obtain a copy of the License at
|
| 11 |
-
#
|
| 12 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
-
#
|
| 14 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
-
# See the License for the specific language governing permissions and
|
| 18 |
-
# limitations under the License.
|
| 19 |
-
#
|
| 20 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 21 |
-
from __future__ import annotations
|
| 22 |
-
|
| 23 |
-
import argparse
|
| 24 |
-
import os
|
| 25 |
-
import random
|
| 26 |
-
import time
|
| 27 |
-
import uuid
|
| 28 |
-
from datetime import datetime
|
| 29 |
-
|
| 30 |
-
import gradio as gr
|
| 31 |
-
import numpy as np
|
| 32 |
-
import spaces
|
| 33 |
-
import torch
|
| 34 |
-
from diffusers import SanaPipeline
|
| 35 |
-
from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
|
| 36 |
-
from torchvision.utils import save_image
|
| 37 |
-
|
| 38 |
-
MAX_SEED = np.iinfo(np.int32).max
|
| 39 |
-
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
| 40 |
-
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
|
| 41 |
-
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
| 42 |
-
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
| 43 |
-
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
|
| 44 |
-
os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
|
| 45 |
-
COUNTER_DB = os.getenv("COUNTER_DB", ".count.db")
|
| 46 |
-
INFER_SPEED = 0
|
| 47 |
-
|
| 48 |
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 49 |
-
|
| 50 |
-
style_list = [
|
| 51 |
-
{
|
| 52 |
-
"name": "(No style)",
|
| 53 |
-
"prompt": "{prompt}",
|
| 54 |
-
"negative_prompt": "",
|
| 55 |
-
},
|
| 56 |
-
{
|
| 57 |
-
"name": "Cinematic",
|
| 58 |
-
"prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
|
| 59 |
-
"cinemascope, moody, epic, gorgeous, film grain, grainy",
|
| 60 |
-
"negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
| 61 |
-
},
|
| 62 |
-
{
|
| 63 |
-
"name": "Photographic",
|
| 64 |
-
"prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
| 65 |
-
"negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
| 66 |
-
},
|
| 67 |
-
{
|
| 68 |
-
"name": "Anime",
|
| 69 |
-
"prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
|
| 70 |
-
"negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
|
| 71 |
-
},
|
| 72 |
-
{
|
| 73 |
-
"name": "Manga",
|
| 74 |
-
"prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
|
| 75 |
-
"negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
| 76 |
-
},
|
| 77 |
-
{
|
| 78 |
-
"name": "Digital Art",
|
| 79 |
-
"prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
|
| 80 |
-
"negative_prompt": "photo, photorealistic, realism, ugly",
|
| 81 |
-
},
|
| 82 |
-
{
|
| 83 |
-
"name": "Pixel art",
|
| 84 |
-
"prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
|
| 85 |
-
"negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
| 86 |
-
},
|
| 87 |
-
{
|
| 88 |
-
"name": "Fantasy art",
|
| 89 |
-
"prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
|
| 90 |
-
"majestic, magical, fantasy art, cover art, dreamy",
|
| 91 |
-
"negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
|
| 92 |
-
"glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
|
| 93 |
-
"disfigured, sloppy, duplicate, mutated, black and white",
|
| 94 |
-
},
|
| 95 |
-
{
|
| 96 |
-
"name": "Neonpunk",
|
| 97 |
-
"prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
|
| 98 |
-
"detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
|
| 99 |
-
"ultra detailed, intricate, professional",
|
| 100 |
-
"negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
| 101 |
-
},
|
| 102 |
-
{
|
| 103 |
-
"name": "3D Model",
|
| 104 |
-
"prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
|
| 105 |
-
"negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
|
| 106 |
-
},
|
| 107 |
-
]
|
| 108 |
-
|
| 109 |
-
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
|
| 110 |
-
STYLE_NAMES = list(styles.keys())
|
| 111 |
-
DEFAULT_STYLE_NAME = "(No style)"
|
| 112 |
-
SCHEDULE_NAME = ["Flow_DPM_Solver"]
|
| 113 |
-
DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
|
| 114 |
-
NUM_IMAGES_PER_PROMPT = 1
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
|
| 118 |
-
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
| 119 |
-
if not negative:
|
| 120 |
-
negative = ""
|
| 121 |
-
return p.replace("{prompt}", positive), n + negative
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
def get_args():
|
| 125 |
-
parser = argparse.ArgumentParser()
|
| 126 |
-
parser.add_argument(
|
| 127 |
-
"--model_path",
|
| 128 |
-
nargs="?",
|
| 129 |
-
default="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
|
| 130 |
-
type=str,
|
| 131 |
-
help="Path to the model file (positional)",
|
| 132 |
-
)
|
| 133 |
-
parser.add_argument("--share", action="store_true")
|
| 134 |
-
|
| 135 |
-
return parser.parse_known_args()[0]
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
args = get_args()
|
| 139 |
-
|
| 140 |
-
if torch.cuda.is_available():
|
| 141 |
-
|
| 142 |
-
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
|
| 143 |
-
pipe = SanaPipeline.from_pretrained(
|
| 144 |
-
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
|
| 145 |
-
transformer=transformer,
|
| 146 |
-
variant="bf16",
|
| 147 |
-
torch_dtype=torch.bfloat16,
|
| 148 |
-
).to(device)
|
| 149 |
-
|
| 150 |
-
pipe.text_encoder.to(torch.bfloat16)
|
| 151 |
-
pipe.vae.to(torch.bfloat16)
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
def save_image_sana(img, seed="", save_img=False):
|
| 155 |
-
unique_name = f"{str(uuid.uuid4())}_{seed}.png"
|
| 156 |
-
save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
|
| 157 |
-
os.umask(0o000) # file permission: 666; dir permission: 777
|
| 158 |
-
os.makedirs(save_path, exist_ok=True)
|
| 159 |
-
unique_name = os.path.join(save_path, unique_name)
|
| 160 |
-
if save_img:
|
| 161 |
-
save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
|
| 162 |
-
|
| 163 |
-
return unique_name
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
| 167 |
-
if randomize_seed:
|
| 168 |
-
seed = random.randint(0, MAX_SEED)
|
| 169 |
-
return seed
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
@torch.no_grad()
|
| 173 |
-
@torch.inference_mode()
|
| 174 |
-
@spaces.GPU(enable_queue=True)
|
| 175 |
-
def generate(
|
| 176 |
-
prompt: str = None,
|
| 177 |
-
negative_prompt: str = "",
|
| 178 |
-
style: str = DEFAULT_STYLE_NAME,
|
| 179 |
-
use_negative_prompt: bool = False,
|
| 180 |
-
num_imgs: int = 1,
|
| 181 |
-
seed: int = 0,
|
| 182 |
-
height: int = 1024,
|
| 183 |
-
width: int = 1024,
|
| 184 |
-
flow_dpms_guidance_scale: float = 5.0,
|
| 185 |
-
flow_dpms_inference_steps: int = 20,
|
| 186 |
-
randomize_seed: bool = False,
|
| 187 |
-
):
|
| 188 |
-
global INFER_SPEED
|
| 189 |
-
# seed = 823753551
|
| 190 |
-
seed = int(randomize_seed_fn(seed, randomize_seed))
|
| 191 |
-
generator = torch.Generator(device=device).manual_seed(seed)
|
| 192 |
-
print(f"PORT: {DEMO_PORT}, model_path: {args.model_path}")
|
| 193 |
-
|
| 194 |
-
print(prompt)
|
| 195 |
-
|
| 196 |
-
num_inference_steps = flow_dpms_inference_steps
|
| 197 |
-
guidance_scale = flow_dpms_guidance_scale
|
| 198 |
-
|
| 199 |
-
if not use_negative_prompt:
|
| 200 |
-
negative_prompt = None # type: ignore
|
| 201 |
-
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
| 202 |
-
|
| 203 |
-
time_start = time.time()
|
| 204 |
-
images = pipe(
|
| 205 |
-
prompt=prompt,
|
| 206 |
-
height=height,
|
| 207 |
-
width=width,
|
| 208 |
-
negative_prompt=negative_prompt,
|
| 209 |
-
guidance_scale=guidance_scale,
|
| 210 |
-
num_inference_steps=num_inference_steps,
|
| 211 |
-
num_images_per_prompt=num_imgs,
|
| 212 |
-
generator=generator,
|
| 213 |
-
).images
|
| 214 |
-
INFER_SPEED = (time.time() - time_start) / num_imgs
|
| 215 |
-
|
| 216 |
-
save_img = False
|
| 217 |
-
if save_img:
|
| 218 |
-
img = [save_image_sana(img, seed, save_img=save_image) for img in images]
|
| 219 |
-
print(img)
|
| 220 |
-
else:
|
| 221 |
-
img = images
|
| 222 |
-
|
| 223 |
-
torch.cuda.empty_cache()
|
| 224 |
-
|
| 225 |
-
return (
|
| 226 |
-
img,
|
| 227 |
-
seed,
|
| 228 |
-
f"<span style='font-size: 16px; font-weight: bold;'>Inference Speed: {INFER_SPEED:.3f} s/Img</span>",
|
| 229 |
-
)
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
model_size = "1.6" if "1600M" in args.model_path else "0.6"
|
| 233 |
-
title = f"""
|
| 234 |
-
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
|
| 235 |
-
<img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="30%" alt="logo"/>
|
| 236 |
-
</div>
|
| 237 |
-
"""
|
| 238 |
-
DESCRIPTION = f"""
|
| 239 |
-
<p style="font-size: 30px; font-weight: bold; text-align: center;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer (4bit version)</p>
|
| 240 |
-
"""
|
| 241 |
-
if model_size == "0.6":
|
| 242 |
-
DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
|
| 243 |
-
if not torch.cuda.is_available():
|
| 244 |
-
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
| 245 |
-
|
| 246 |
-
examples = [
|
| 247 |
-
'a cyberpunk cat with a neon sign that says "Sana"',
|
| 248 |
-
"A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
|
| 249 |
-
"Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
|
| 250 |
-
"portrait photo of a girl, photograph, highly detailed face, depth of field",
|
| 251 |
-
'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
|
| 252 |
-
"🐶 Wearing 🕶 flying on the 🌈",
|
| 253 |
-
"👧 with 🌹 in the ❄️",
|
| 254 |
-
"an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
|
| 255 |
-
"professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
|
| 256 |
-
"Astronaut in a jungle, cold color palette, muted colors, detailed",
|
| 257 |
-
"a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
|
| 258 |
-
]
|
| 259 |
-
|
| 260 |
-
css = """
|
| 261 |
-
.gradio-container {max-width: 850px !important; height: auto !important;}
|
| 262 |
-
h1 {text-align: center;}
|
| 263 |
-
"""
|
| 264 |
-
theme = gr.themes.Base()
|
| 265 |
-
with gr.Blocks(css=css, theme=theme, title="Sana") as demo:
|
| 266 |
-
gr.Markdown(title)
|
| 267 |
-
gr.HTML(DESCRIPTION)
|
| 268 |
-
gr.DuplicateButton(
|
| 269 |
-
value="Duplicate Space for private use",
|
| 270 |
-
elem_id="duplicate-button",
|
| 271 |
-
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
|
| 272 |
-
)
|
| 273 |
-
# with gr.Row(equal_height=False):
|
| 274 |
-
with gr.Group():
|
| 275 |
-
with gr.Row():
|
| 276 |
-
prompt = gr.Text(
|
| 277 |
-
label="Prompt",
|
| 278 |
-
show_label=False,
|
| 279 |
-
max_lines=1,
|
| 280 |
-
placeholder="Enter your prompt",
|
| 281 |
-
container=False,
|
| 282 |
-
)
|
| 283 |
-
run_button = gr.Button("Run", scale=0)
|
| 284 |
-
result = gr.Gallery(
|
| 285 |
-
label="Result",
|
| 286 |
-
show_label=False,
|
| 287 |
-
height=750,
|
| 288 |
-
columns=NUM_IMAGES_PER_PROMPT,
|
| 289 |
-
format="jpeg",
|
| 290 |
-
)
|
| 291 |
-
|
| 292 |
-
speed_box = gr.Markdown(
|
| 293 |
-
value=f"<span style='font-size: 16px; font-weight: bold;'>Inference speed: {INFER_SPEED} s/Img</span>"
|
| 294 |
-
)
|
| 295 |
-
with gr.Accordion("Advanced options", open=False):
|
| 296 |
-
with gr.Group():
|
| 297 |
-
with gr.Row(visible=True):
|
| 298 |
-
height = gr.Slider(
|
| 299 |
-
label="Height",
|
| 300 |
-
minimum=256,
|
| 301 |
-
maximum=MAX_IMAGE_SIZE,
|
| 302 |
-
step=32,
|
| 303 |
-
value=1024,
|
| 304 |
-
)
|
| 305 |
-
width = gr.Slider(
|
| 306 |
-
label="Width",
|
| 307 |
-
minimum=256,
|
| 308 |
-
maximum=MAX_IMAGE_SIZE,
|
| 309 |
-
step=32,
|
| 310 |
-
value=1024,
|
| 311 |
-
)
|
| 312 |
-
with gr.Row():
|
| 313 |
-
flow_dpms_inference_steps = gr.Slider(
|
| 314 |
-
label="Sampling steps",
|
| 315 |
-
minimum=5,
|
| 316 |
-
maximum=40,
|
| 317 |
-
step=1,
|
| 318 |
-
value=20,
|
| 319 |
-
)
|
| 320 |
-
flow_dpms_guidance_scale = gr.Slider(
|
| 321 |
-
label="CFG Guidance scale",
|
| 322 |
-
minimum=1,
|
| 323 |
-
maximum=10,
|
| 324 |
-
step=0.1,
|
| 325 |
-
value=4.5,
|
| 326 |
-
)
|
| 327 |
-
with gr.Row():
|
| 328 |
-
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
|
| 329 |
-
negative_prompt = gr.Text(
|
| 330 |
-
label="Negative prompt",
|
| 331 |
-
max_lines=1,
|
| 332 |
-
placeholder="Enter a negative prompt",
|
| 333 |
-
visible=True,
|
| 334 |
-
)
|
| 335 |
-
style_selection = gr.Radio(
|
| 336 |
-
show_label=True,
|
| 337 |
-
container=True,
|
| 338 |
-
interactive=True,
|
| 339 |
-
choices=STYLE_NAMES,
|
| 340 |
-
value=DEFAULT_STYLE_NAME,
|
| 341 |
-
label="Image Style",
|
| 342 |
-
)
|
| 343 |
-
seed = gr.Slider(
|
| 344 |
-
label="Seed",
|
| 345 |
-
minimum=0,
|
| 346 |
-
maximum=MAX_SEED,
|
| 347 |
-
step=1,
|
| 348 |
-
value=0,
|
| 349 |
-
)
|
| 350 |
-
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
| 351 |
-
with gr.Row(visible=True):
|
| 352 |
-
schedule = gr.Radio(
|
| 353 |
-
show_label=True,
|
| 354 |
-
container=True,
|
| 355 |
-
interactive=True,
|
| 356 |
-
choices=SCHEDULE_NAME,
|
| 357 |
-
value=DEFAULT_SCHEDULE_NAME,
|
| 358 |
-
label="Sampler Schedule",
|
| 359 |
-
visible=True,
|
| 360 |
-
)
|
| 361 |
-
num_imgs = gr.Slider(
|
| 362 |
-
label="Num Images",
|
| 363 |
-
minimum=1,
|
| 364 |
-
maximum=6,
|
| 365 |
-
step=1,
|
| 366 |
-
value=1,
|
| 367 |
-
)
|
| 368 |
-
|
| 369 |
-
gr.Examples(
|
| 370 |
-
examples=examples,
|
| 371 |
-
inputs=prompt,
|
| 372 |
-
outputs=[result, seed],
|
| 373 |
-
fn=generate,
|
| 374 |
-
cache_examples=CACHE_EXAMPLES,
|
| 375 |
-
)
|
| 376 |
-
|
| 377 |
-
use_negative_prompt.change(
|
| 378 |
-
fn=lambda x: gr.update(visible=x),
|
| 379 |
-
inputs=use_negative_prompt,
|
| 380 |
-
outputs=negative_prompt,
|
| 381 |
-
api_name=False,
|
| 382 |
-
)
|
| 383 |
-
|
| 384 |
-
gr.on(
|
| 385 |
-
triggers=[
|
| 386 |
-
prompt.submit,
|
| 387 |
-
negative_prompt.submit,
|
| 388 |
-
run_button.click,
|
| 389 |
-
],
|
| 390 |
-
fn=generate,
|
| 391 |
-
inputs=[
|
| 392 |
-
prompt,
|
| 393 |
-
negative_prompt,
|
| 394 |
-
style_selection,
|
| 395 |
-
use_negative_prompt,
|
| 396 |
-
num_imgs,
|
| 397 |
-
seed,
|
| 398 |
-
height,
|
| 399 |
-
width,
|
| 400 |
-
flow_dpms_guidance_scale,
|
| 401 |
-
flow_dpms_inference_steps,
|
| 402 |
-
randomize_seed,
|
| 403 |
-
],
|
| 404 |
-
outputs=[result, seed, speed_box],
|
| 405 |
-
api_name="run",
|
| 406 |
-
)
|
| 407 |
-
|
| 408 |
-
if __name__ == "__main__":
|
| 409 |
-
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/app_sana_4bit_compare_bf16.py
DELETED
|
@@ -1,313 +0,0 @@
|
|
| 1 |
-
# Changed from https://huggingface.co/spaces/playgroundai/playground-v2.5/blob/main/app.py
|
| 2 |
-
import argparse
|
| 3 |
-
import os
|
| 4 |
-
import random
|
| 5 |
-
import time
|
| 6 |
-
from datetime import datetime
|
| 7 |
-
|
| 8 |
-
import GPUtil
|
| 9 |
-
|
| 10 |
-
# import gradio last to avoid conflicts with other imports
|
| 11 |
-
import gradio as gr
|
| 12 |
-
import safety_check
|
| 13 |
-
import spaces
|
| 14 |
-
import torch
|
| 15 |
-
from diffusers import SanaPipeline
|
| 16 |
-
from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
|
| 17 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 18 |
-
|
| 19 |
-
MAX_IMAGE_SIZE = 2048
|
| 20 |
-
MAX_SEED = 1000000000
|
| 21 |
-
|
| 22 |
-
DEFAULT_HEIGHT = 1024
|
| 23 |
-
DEFAULT_WIDTH = 1024
|
| 24 |
-
|
| 25 |
-
# num_inference_steps, guidance_scale, seed
|
| 26 |
-
EXAMPLES = [
|
| 27 |
-
[
|
| 28 |
-
"🐶 Wearing 🕶 flying on the 🌈",
|
| 29 |
-
1024,
|
| 30 |
-
1024,
|
| 31 |
-
20,
|
| 32 |
-
5,
|
| 33 |
-
2,
|
| 34 |
-
],
|
| 35 |
-
[
|
| 36 |
-
"大漠孤烟直, 长河落日圆",
|
| 37 |
-
1024,
|
| 38 |
-
1024,
|
| 39 |
-
20,
|
| 40 |
-
5,
|
| 41 |
-
23,
|
| 42 |
-
],
|
| 43 |
-
[
|
| 44 |
-
"Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, "
|
| 45 |
-
"volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, "
|
| 46 |
-
"art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
|
| 47 |
-
1024,
|
| 48 |
-
1024,
|
| 49 |
-
20,
|
| 50 |
-
5,
|
| 51 |
-
233,
|
| 52 |
-
],
|
| 53 |
-
[
|
| 54 |
-
"A photo of a Eurasian lynx in a sunlit forest, with tufted ears and a spotted coat. The lynx should be "
|
| 55 |
-
"sharply focused, gazing into the distance, while the background is softly blurred for depth. Use cinematic "
|
| 56 |
-
"lighting with soft rays filtering through the trees, and capture the scene with a shallow depth of field "
|
| 57 |
-
"for a natural, peaceful atmosphere. 8K resolution, highly detailed, photorealistic, "
|
| 58 |
-
"cinematic lighting, ultra-HD.",
|
| 59 |
-
1024,
|
| 60 |
-
1024,
|
| 61 |
-
20,
|
| 62 |
-
5,
|
| 63 |
-
2333,
|
| 64 |
-
],
|
| 65 |
-
[
|
| 66 |
-
"A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. "
|
| 67 |
-
"She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. "
|
| 68 |
-
"She wears sunglasses and red lipstick. She walks confidently and casually. "
|
| 69 |
-
"The street is damp and reflective, creating a mirror effect of the colorful lights. "
|
| 70 |
-
"Many pedestrians walk about.",
|
| 71 |
-
1024,
|
| 72 |
-
1024,
|
| 73 |
-
20,
|
| 74 |
-
5,
|
| 75 |
-
23333,
|
| 76 |
-
],
|
| 77 |
-
[
|
| 78 |
-
"Cozy bedroom with vintage wooden furniture and a large circular window covered in lush green vines, "
|
| 79 |
-
"opening to a misty forest. Soft, ambient lighting highlights the bed with crumpled blankets, a bookshelf, "
|
| 80 |
-
"and a desk. The atmosphere is serene and natural. 8K resolution, highly detailed, photorealistic, "
|
| 81 |
-
"cinematic lighting, ultra-HD.",
|
| 82 |
-
1024,
|
| 83 |
-
1024,
|
| 84 |
-
20,
|
| 85 |
-
5,
|
| 86 |
-
233333,
|
| 87 |
-
],
|
| 88 |
-
]
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
def hash_str_to_int(s: str) -> int:
|
| 92 |
-
"""Hash a string to an integer."""
|
| 93 |
-
modulus = 10**9 + 7 # Large prime modulus
|
| 94 |
-
hash_int = 0
|
| 95 |
-
for char in s:
|
| 96 |
-
hash_int = (hash_int * 31 + ord(char)) % modulus
|
| 97 |
-
return hash_int
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
def get_pipeline(
|
| 101 |
-
precision: str, use_qencoder: bool = False, device: str | torch.device = "cuda", pipeline_init_kwargs: dict = {}
|
| 102 |
-
) -> SanaPipeline:
|
| 103 |
-
if precision == "int4":
|
| 104 |
-
assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
|
| 105 |
-
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
|
| 106 |
-
|
| 107 |
-
pipeline_init_kwargs["transformer"] = transformer
|
| 108 |
-
if use_qencoder:
|
| 109 |
-
raise NotImplementedError("Quantized encoder not supported for Sana for now")
|
| 110 |
-
else:
|
| 111 |
-
assert precision == "bf16"
|
| 112 |
-
pipeline = SanaPipeline.from_pretrained(
|
| 113 |
-
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
|
| 114 |
-
variant="bf16",
|
| 115 |
-
torch_dtype=torch.bfloat16,
|
| 116 |
-
**pipeline_init_kwargs,
|
| 117 |
-
)
|
| 118 |
-
|
| 119 |
-
pipeline = pipeline.to(device)
|
| 120 |
-
return pipeline
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
def get_args() -> argparse.Namespace:
|
| 124 |
-
parser = argparse.ArgumentParser()
|
| 125 |
-
parser.add_argument(
|
| 126 |
-
"-p",
|
| 127 |
-
"--precisions",
|
| 128 |
-
type=str,
|
| 129 |
-
default=["int4"],
|
| 130 |
-
nargs="*",
|
| 131 |
-
choices=["int4", "bf16"],
|
| 132 |
-
help="Which precisions to use",
|
| 133 |
-
)
|
| 134 |
-
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
|
| 135 |
-
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
|
| 136 |
-
parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
|
| 137 |
-
return parser.parse_args()
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
args = get_args()
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
pipelines = []
|
| 144 |
-
pipeline_init_kwargs = {}
|
| 145 |
-
for i, precision in enumerate(args.precisions):
|
| 146 |
-
|
| 147 |
-
pipeline = get_pipeline(
|
| 148 |
-
precision=precision,
|
| 149 |
-
use_qencoder=args.use_qencoder,
|
| 150 |
-
device="cuda",
|
| 151 |
-
pipeline_init_kwargs={**pipeline_init_kwargs},
|
| 152 |
-
)
|
| 153 |
-
pipelines.append(pipeline)
|
| 154 |
-
if i == 0:
|
| 155 |
-
pipeline_init_kwargs["vae"] = pipeline.vae
|
| 156 |
-
pipeline_init_kwargs["text_encoder"] = pipeline.text_encoder
|
| 157 |
-
|
| 158 |
-
# safety checker
|
| 159 |
-
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
|
| 160 |
-
safety_checker_model = AutoModelForCausalLM.from_pretrained(
|
| 161 |
-
args.shield_model_path,
|
| 162 |
-
device_map="auto",
|
| 163 |
-
torch_dtype=torch.bfloat16,
|
| 164 |
-
).to(pipeline.device)
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
@spaces.GPU(enable_queue=True)
|
| 168 |
-
def generate(
|
| 169 |
-
prompt: str = None,
|
| 170 |
-
height: int = 1024,
|
| 171 |
-
width: int = 1024,
|
| 172 |
-
num_inference_steps: int = 4,
|
| 173 |
-
guidance_scale: float = 0,
|
| 174 |
-
seed: int = 0,
|
| 175 |
-
):
|
| 176 |
-
print(f"Prompt: {prompt}")
|
| 177 |
-
is_unsafe_prompt = False
|
| 178 |
-
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
|
| 179 |
-
prompt = "A peaceful world."
|
| 180 |
-
images, latency_strs = [], []
|
| 181 |
-
for i, pipeline in enumerate(pipelines):
|
| 182 |
-
progress = gr.Progress(track_tqdm=True)
|
| 183 |
-
start_time = time.time()
|
| 184 |
-
image = pipeline(
|
| 185 |
-
prompt=prompt,
|
| 186 |
-
height=height,
|
| 187 |
-
width=width,
|
| 188 |
-
guidance_scale=guidance_scale,
|
| 189 |
-
num_inference_steps=num_inference_steps,
|
| 190 |
-
generator=torch.Generator().manual_seed(seed),
|
| 191 |
-
).images[0]
|
| 192 |
-
end_time = time.time()
|
| 193 |
-
latency = end_time - start_time
|
| 194 |
-
if latency < 1:
|
| 195 |
-
latency = latency * 1000
|
| 196 |
-
latency_str = f"{latency:.2f}ms"
|
| 197 |
-
else:
|
| 198 |
-
latency_str = f"{latency:.2f}s"
|
| 199 |
-
images.append(image)
|
| 200 |
-
latency_strs.append(latency_str)
|
| 201 |
-
if is_unsafe_prompt:
|
| 202 |
-
for i in range(len(latency_strs)):
|
| 203 |
-
latency_strs[i] += " (Unsafe prompt detected)"
|
| 204 |
-
torch.cuda.empty_cache()
|
| 205 |
-
|
| 206 |
-
if args.count_use:
|
| 207 |
-
if os.path.exists("use_count.txt"):
|
| 208 |
-
with open("use_count.txt") as f:
|
| 209 |
-
count = int(f.read())
|
| 210 |
-
else:
|
| 211 |
-
count = 0
|
| 212 |
-
count += 1
|
| 213 |
-
current_time = datetime.now()
|
| 214 |
-
print(f"{current_time}: {count}")
|
| 215 |
-
with open("use_count.txt", "w") as f:
|
| 216 |
-
f.write(str(count))
|
| 217 |
-
with open("use_record.txt", "a") as f:
|
| 218 |
-
f.write(f"{current_time}: {count}\n")
|
| 219 |
-
|
| 220 |
-
return *images, *latency_strs
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
with open("./assets/description.html") as f:
|
| 224 |
-
DESCRIPTION = f.read()
|
| 225 |
-
gpus = GPUtil.getGPUs()
|
| 226 |
-
if len(gpus) > 0:
|
| 227 |
-
gpu = gpus[0]
|
| 228 |
-
memory = gpu.memoryTotal / 1024
|
| 229 |
-
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
|
| 230 |
-
else:
|
| 231 |
-
device_info = "Running on CPU 🥶 This demo does not work on CPU."
|
| 232 |
-
notice = f'<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'
|
| 233 |
-
|
| 234 |
-
with gr.Blocks(
|
| 235 |
-
css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
|
| 236 |
-
title=f"SVDQuant SANA-1600M Demo",
|
| 237 |
-
) as demo:
|
| 238 |
-
|
| 239 |
-
def get_header_str():
|
| 240 |
-
|
| 241 |
-
if args.count_use:
|
| 242 |
-
if os.path.exists("use_count.txt"):
|
| 243 |
-
with open("use_count.txt") as f:
|
| 244 |
-
count = int(f.read())
|
| 245 |
-
else:
|
| 246 |
-
count = 0
|
| 247 |
-
count_info = (
|
| 248 |
-
f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
|
| 249 |
-
f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
|
| 250 |
-
f"<span style='font-size: 18px; color:red; font-weight: bold;'> {count}</span></div>"
|
| 251 |
-
)
|
| 252 |
-
else:
|
| 253 |
-
count_info = ""
|
| 254 |
-
header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info)
|
| 255 |
-
return header_str
|
| 256 |
-
|
| 257 |
-
header = gr.HTML(get_header_str())
|
| 258 |
-
demo.load(fn=get_header_str, outputs=header)
|
| 259 |
-
|
| 260 |
-
with gr.Row():
|
| 261 |
-
image_results, latency_results = [], []
|
| 262 |
-
for i, precision in enumerate(args.precisions):
|
| 263 |
-
with gr.Column():
|
| 264 |
-
gr.Markdown(f"# {precision.upper()}", elem_id="image_header")
|
| 265 |
-
with gr.Group():
|
| 266 |
-
image_result = gr.Image(
|
| 267 |
-
format="png",
|
| 268 |
-
image_mode="RGB",
|
| 269 |
-
label="Result",
|
| 270 |
-
show_label=False,
|
| 271 |
-
show_download_button=True,
|
| 272 |
-
interactive=False,
|
| 273 |
-
)
|
| 274 |
-
latency_result = gr.Text(label="Inference Latency", show_label=True)
|
| 275 |
-
image_results.append(image_result)
|
| 276 |
-
latency_results.append(latency_result)
|
| 277 |
-
with gr.Row():
|
| 278 |
-
prompt = gr.Text(
|
| 279 |
-
label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False, scale=4
|
| 280 |
-
)
|
| 281 |
-
run_button = gr.Button("Run", scale=1)
|
| 282 |
-
|
| 283 |
-
with gr.Row():
|
| 284 |
-
seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
|
| 285 |
-
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
|
| 286 |
-
with gr.Accordion("Advanced options", open=False):
|
| 287 |
-
with gr.Group():
|
| 288 |
-
height = gr.Slider(label="Height", minimum=256, maximum=4096, step=32, value=1024)
|
| 289 |
-
width = gr.Slider(label="Width", minimum=256, maximum=4096, step=32, value=1024)
|
| 290 |
-
with gr.Group():
|
| 291 |
-
num_inference_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, step=1, value=20)
|
| 292 |
-
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=5)
|
| 293 |
-
|
| 294 |
-
input_args = [prompt, height, width, num_inference_steps, guidance_scale, seed]
|
| 295 |
-
|
| 296 |
-
gr.Examples(examples=EXAMPLES, inputs=input_args, outputs=[*image_results, *latency_results], fn=generate)
|
| 297 |
-
|
| 298 |
-
gr.on(
|
| 299 |
-
triggers=[prompt.submit, run_button.click],
|
| 300 |
-
fn=generate,
|
| 301 |
-
inputs=input_args,
|
| 302 |
-
outputs=[*image_results, *latency_results],
|
| 303 |
-
api_name="run",
|
| 304 |
-
)
|
| 305 |
-
randomize_seed.click(
|
| 306 |
-
lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
|
| 307 |
-
).then(fn=generate, inputs=input_args, outputs=[*image_results, *latency_results], api_name=False, queue=False)
|
| 308 |
-
|
| 309 |
-
gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
if __name__ == "__main__":
|
| 313 |
-
demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/app_sana_controlnet_hed.py
DELETED
|
@@ -1,306 +0,0 @@
|
|
| 1 |
-
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
|
| 2 |
-
import argparse
|
| 3 |
-
import os
|
| 4 |
-
import random
|
| 5 |
-
import socket
|
| 6 |
-
import tempfile
|
| 7 |
-
import time
|
| 8 |
-
|
| 9 |
-
import gradio as gr
|
| 10 |
-
import numpy as np
|
| 11 |
-
import torch
|
| 12 |
-
from PIL import Image
|
| 13 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 14 |
-
|
| 15 |
-
from app import safety_check
|
| 16 |
-
from app.sana_controlnet_pipeline import SanaControlNetPipeline
|
| 17 |
-
|
| 18 |
-
STYLES = {
|
| 19 |
-
"None": "{prompt}",
|
| 20 |
-
"Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
|
| 21 |
-
"3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
|
| 22 |
-
"Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
|
| 23 |
-
"Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
|
| 24 |
-
"Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
| 25 |
-
"Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics",
|
| 26 |
-
"Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
|
| 27 |
-
"Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
|
| 28 |
-
"Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
|
| 29 |
-
}
|
| 30 |
-
DEFAULT_STYLE_NAME = "None"
|
| 31 |
-
STYLE_NAMES = list(STYLES.keys())
|
| 32 |
-
|
| 33 |
-
MAX_SEED = 1000000000
|
| 34 |
-
DEFAULT_SKETCH_GUIDANCE = 0.28
|
| 35 |
-
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
|
| 36 |
-
|
| 37 |
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 38 |
-
|
| 39 |
-
blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def get_args():
|
| 43 |
-
parser = argparse.ArgumentParser()
|
| 44 |
-
parser.add_argument("--config", type=str, help="config")
|
| 45 |
-
parser.add_argument(
|
| 46 |
-
"--model_path",
|
| 47 |
-
nargs="?",
|
| 48 |
-
default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
|
| 49 |
-
type=str,
|
| 50 |
-
help="Path to the model file (positional)",
|
| 51 |
-
)
|
| 52 |
-
parser.add_argument("--output", default="./", type=str)
|
| 53 |
-
parser.add_argument("--bs", default=1, type=int)
|
| 54 |
-
parser.add_argument("--image_size", default=1024, type=int)
|
| 55 |
-
parser.add_argument("--cfg_scale", default=5.0, type=float)
|
| 56 |
-
parser.add_argument("--pag_scale", default=2.0, type=float)
|
| 57 |
-
parser.add_argument("--seed", default=42, type=int)
|
| 58 |
-
parser.add_argument("--step", default=-1, type=int)
|
| 59 |
-
parser.add_argument("--custom_image_size", default=None, type=int)
|
| 60 |
-
parser.add_argument("--share", action="store_true")
|
| 61 |
-
parser.add_argument(
|
| 62 |
-
"--shield_model_path",
|
| 63 |
-
type=str,
|
| 64 |
-
help="The path to shield model, we employ ShieldGemma-2B by default.",
|
| 65 |
-
default="google/shieldgemma-2b",
|
| 66 |
-
)
|
| 67 |
-
|
| 68 |
-
return parser.parse_known_args()[0]
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
args = get_args()
|
| 72 |
-
|
| 73 |
-
if torch.cuda.is_available():
|
| 74 |
-
model_path = args.model_path
|
| 75 |
-
pipe = SanaControlNetPipeline(args.config)
|
| 76 |
-
pipe.from_pretrained(model_path)
|
| 77 |
-
pipe.register_progress_bar(gr.Progress())
|
| 78 |
-
|
| 79 |
-
# safety checker
|
| 80 |
-
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
|
| 81 |
-
safety_checker_model = AutoModelForCausalLM.from_pretrained(
|
| 82 |
-
args.shield_model_path,
|
| 83 |
-
device_map="auto",
|
| 84 |
-
torch_dtype=torch.bfloat16,
|
| 85 |
-
).to(device)
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def save_image(img):
|
| 89 |
-
if isinstance(img, dict):
|
| 90 |
-
img = img["composite"]
|
| 91 |
-
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
| 92 |
-
img.save(temp_file.name)
|
| 93 |
-
return temp_file.name
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
def norm_ip(img, low, high):
|
| 97 |
-
img.clamp_(min=low, max=high)
|
| 98 |
-
img.sub_(low).div_(max(high - low, 1e-5))
|
| 99 |
-
return img
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
@torch.no_grad()
|
| 103 |
-
@torch.inference_mode()
|
| 104 |
-
def run(
|
| 105 |
-
image,
|
| 106 |
-
prompt: str,
|
| 107 |
-
prompt_template: str,
|
| 108 |
-
sketch_thickness: int,
|
| 109 |
-
guidance_scale: float,
|
| 110 |
-
inference_steps: int,
|
| 111 |
-
seed: int,
|
| 112 |
-
blend_alpha: float,
|
| 113 |
-
) -> tuple[Image, str]:
|
| 114 |
-
|
| 115 |
-
print(f"Prompt: {prompt}")
|
| 116 |
-
image_numpy = np.array(image["composite"].convert("RGB"))
|
| 117 |
-
|
| 118 |
-
if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
|
| 119 |
-
return blank_image, "Please input the prompt or draw something."
|
| 120 |
-
|
| 121 |
-
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
|
| 122 |
-
prompt = "A red heart."
|
| 123 |
-
|
| 124 |
-
prompt = prompt_template.format(prompt=prompt)
|
| 125 |
-
pipe.set_blend_alpha(blend_alpha)
|
| 126 |
-
start_time = time.time()
|
| 127 |
-
images = pipe(
|
| 128 |
-
prompt=prompt,
|
| 129 |
-
ref_image=image["composite"],
|
| 130 |
-
guidance_scale=guidance_scale,
|
| 131 |
-
num_inference_steps=inference_steps,
|
| 132 |
-
num_images_per_prompt=1,
|
| 133 |
-
sketch_thickness=sketch_thickness,
|
| 134 |
-
generator=torch.Generator(device=device).manual_seed(seed),
|
| 135 |
-
)
|
| 136 |
-
|
| 137 |
-
latency = time.time() - start_time
|
| 138 |
-
|
| 139 |
-
if latency < 1:
|
| 140 |
-
latency = latency * 1000
|
| 141 |
-
latency_str = f"{latency:.2f}ms"
|
| 142 |
-
else:
|
| 143 |
-
latency_str = f"{latency:.2f}s"
|
| 144 |
-
torch.cuda.empty_cache()
|
| 145 |
-
|
| 146 |
-
img = [
|
| 147 |
-
Image.fromarray(
|
| 148 |
-
norm_ip(img, -1, 1)
|
| 149 |
-
.mul(255)
|
| 150 |
-
.add_(0.5)
|
| 151 |
-
.clamp_(0, 255)
|
| 152 |
-
.permute(1, 2, 0)
|
| 153 |
-
.to("cpu", torch.uint8)
|
| 154 |
-
.numpy()
|
| 155 |
-
.astype(np.uint8)
|
| 156 |
-
)
|
| 157 |
-
for img in images
|
| 158 |
-
]
|
| 159 |
-
img = img[0]
|
| 160 |
-
return img, latency_str
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
model_size = "1.6" if "1600M" in args.model_path else "0.6"
|
| 164 |
-
title = f"""
|
| 165 |
-
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
|
| 166 |
-
<img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
|
| 167 |
-
</div>
|
| 168 |
-
"""
|
| 169 |
-
DESCRIPTION = f"""
|
| 170 |
-
<p><span style="font-size: 36px; font-weight: bold;">Sana-ControlNet-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
|
| 171 |
-
<p style="font-size: 18px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
|
| 172 |
-
<p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p>
|
| 173 |
-
<p style="font-size: 18px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, </p>running on node {socket.gethostname()}.
|
| 174 |
-
<p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
|
| 175 |
-
"""
|
| 176 |
-
if model_size == "0.6":
|
| 177 |
-
DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
|
| 178 |
-
if not torch.cuda.is_available():
|
| 179 |
-
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
with gr.Blocks(css_paths="asset/app_styles/controlnet_app_style.css", title=f"Sana Sketch-to-Image Demo") as demo:
|
| 183 |
-
gr.Markdown(title)
|
| 184 |
-
gr.HTML(DESCRIPTION)
|
| 185 |
-
|
| 186 |
-
with gr.Row(elem_id="main_row"):
|
| 187 |
-
with gr.Column(elem_id="column_input"):
|
| 188 |
-
gr.Markdown("## INPUT", elem_id="input_header")
|
| 189 |
-
with gr.Group():
|
| 190 |
-
canvas = gr.Sketchpad(
|
| 191 |
-
value=blank_image,
|
| 192 |
-
height=640,
|
| 193 |
-
image_mode="RGB",
|
| 194 |
-
sources=["upload", "clipboard"],
|
| 195 |
-
type="pil",
|
| 196 |
-
label="Sketch",
|
| 197 |
-
show_label=False,
|
| 198 |
-
show_download_button=True,
|
| 199 |
-
interactive=True,
|
| 200 |
-
transforms=[],
|
| 201 |
-
canvas_size=(1024, 1024),
|
| 202 |
-
scale=1,
|
| 203 |
-
brush=gr.Brush(default_size=3, colors=["#000000"], color_mode="fixed"),
|
| 204 |
-
format="png",
|
| 205 |
-
layers=False,
|
| 206 |
-
)
|
| 207 |
-
with gr.Row():
|
| 208 |
-
prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
|
| 209 |
-
run_button = gr.Button("Run", scale=1, elem_id="run_button")
|
| 210 |
-
download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
|
| 211 |
-
with gr.Row():
|
| 212 |
-
style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
|
| 213 |
-
prompt_template = gr.Textbox(
|
| 214 |
-
label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
|
| 215 |
-
)
|
| 216 |
-
|
| 217 |
-
with gr.Row():
|
| 218 |
-
sketch_thickness = gr.Slider(
|
| 219 |
-
label="Sketch Thickness",
|
| 220 |
-
minimum=1,
|
| 221 |
-
maximum=4,
|
| 222 |
-
step=1,
|
| 223 |
-
value=2,
|
| 224 |
-
)
|
| 225 |
-
with gr.Row():
|
| 226 |
-
inference_steps = gr.Slider(
|
| 227 |
-
label="Sampling steps",
|
| 228 |
-
minimum=5,
|
| 229 |
-
maximum=40,
|
| 230 |
-
step=1,
|
| 231 |
-
value=20,
|
| 232 |
-
)
|
| 233 |
-
guidance_scale = gr.Slider(
|
| 234 |
-
label="CFG Guidance scale",
|
| 235 |
-
minimum=1,
|
| 236 |
-
maximum=10,
|
| 237 |
-
step=0.1,
|
| 238 |
-
value=4.5,
|
| 239 |
-
)
|
| 240 |
-
blend_alpha = gr.Slider(
|
| 241 |
-
label="Blend Alpha",
|
| 242 |
-
minimum=0,
|
| 243 |
-
maximum=1,
|
| 244 |
-
step=0.1,
|
| 245 |
-
value=0,
|
| 246 |
-
)
|
| 247 |
-
with gr.Row():
|
| 248 |
-
seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
|
| 249 |
-
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
|
| 250 |
-
|
| 251 |
-
with gr.Column(elem_id="column_output"):
|
| 252 |
-
gr.Markdown("## OUTPUT", elem_id="output_header")
|
| 253 |
-
with gr.Group():
|
| 254 |
-
result = gr.Image(
|
| 255 |
-
format="png",
|
| 256 |
-
height=640,
|
| 257 |
-
image_mode="RGB",
|
| 258 |
-
type="pil",
|
| 259 |
-
label="Result",
|
| 260 |
-
show_label=False,
|
| 261 |
-
show_download_button=True,
|
| 262 |
-
interactive=False,
|
| 263 |
-
elem_id="output_image",
|
| 264 |
-
)
|
| 265 |
-
latency_result = gr.Text(label="Inference Latency", show_label=True)
|
| 266 |
-
|
| 267 |
-
download_result = gr.DownloadButton("Download Result", elem_id="download_result")
|
| 268 |
-
gr.Markdown("### Instructions")
|
| 269 |
-
gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
|
| 270 |
-
gr.Markdown("**2**. Start sketching or upload a reference image")
|
| 271 |
-
gr.Markdown("**3**. Change the image style using a style template")
|
| 272 |
-
gr.Markdown("**4**. Try different seeds to generate different results")
|
| 273 |
-
|
| 274 |
-
run_inputs = [canvas, prompt, prompt_template, sketch_thickness, guidance_scale, inference_steps, seed, blend_alpha]
|
| 275 |
-
run_outputs = [result, latency_result]
|
| 276 |
-
|
| 277 |
-
randomize_seed.click(
|
| 278 |
-
lambda: random.randint(0, MAX_SEED),
|
| 279 |
-
inputs=[],
|
| 280 |
-
outputs=seed,
|
| 281 |
-
api_name=False,
|
| 282 |
-
queue=False,
|
| 283 |
-
).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
|
| 284 |
-
|
| 285 |
-
style.change(
|
| 286 |
-
lambda x: STYLES[x],
|
| 287 |
-
inputs=[style],
|
| 288 |
-
outputs=[prompt_template],
|
| 289 |
-
api_name=False,
|
| 290 |
-
queue=False,
|
| 291 |
-
).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
|
| 292 |
-
gr.on(
|
| 293 |
-
triggers=[prompt.submit, run_button.click, canvas.change],
|
| 294 |
-
fn=run,
|
| 295 |
-
inputs=run_inputs,
|
| 296 |
-
outputs=run_outputs,
|
| 297 |
-
api_name=False,
|
| 298 |
-
)
|
| 299 |
-
|
| 300 |
-
download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
|
| 301 |
-
download_result.click(fn=save_image, inputs=result, outputs=download_result)
|
| 302 |
-
gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
if __name__ == "__main__":
|
| 306 |
-
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/app_sana_multithread.py
DELETED
|
@@ -1,565 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
#
|
| 16 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 17 |
-
from __future__ import annotations
|
| 18 |
-
|
| 19 |
-
import argparse
|
| 20 |
-
import os
|
| 21 |
-
import random
|
| 22 |
-
import uuid
|
| 23 |
-
from datetime import datetime
|
| 24 |
-
|
| 25 |
-
import gradio as gr
|
| 26 |
-
import numpy as np
|
| 27 |
-
import spaces
|
| 28 |
-
import torch
|
| 29 |
-
from diffusers import FluxPipeline
|
| 30 |
-
from PIL import Image
|
| 31 |
-
from torchvision.utils import make_grid, save_image
|
| 32 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 33 |
-
|
| 34 |
-
from app import safety_check
|
| 35 |
-
from app.sana_pipeline import SanaPipeline
|
| 36 |
-
|
| 37 |
-
MAX_SEED = np.iinfo(np.int32).max
|
| 38 |
-
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
| 39 |
-
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
|
| 40 |
-
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
| 41 |
-
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
| 42 |
-
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
|
| 43 |
-
os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
|
| 44 |
-
|
| 45 |
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 46 |
-
|
| 47 |
-
style_list = [
|
| 48 |
-
{
|
| 49 |
-
"name": "(No style)",
|
| 50 |
-
"prompt": "{prompt}",
|
| 51 |
-
"negative_prompt": "",
|
| 52 |
-
},
|
| 53 |
-
{
|
| 54 |
-
"name": "Cinematic",
|
| 55 |
-
"prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
|
| 56 |
-
"cinemascope, moody, epic, gorgeous, film grain, grainy",
|
| 57 |
-
"negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
| 58 |
-
},
|
| 59 |
-
{
|
| 60 |
-
"name": "Photographic",
|
| 61 |
-
"prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
| 62 |
-
"negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
| 63 |
-
},
|
| 64 |
-
{
|
| 65 |
-
"name": "Anime",
|
| 66 |
-
"prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
|
| 67 |
-
"negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
|
| 68 |
-
},
|
| 69 |
-
{
|
| 70 |
-
"name": "Manga",
|
| 71 |
-
"prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
|
| 72 |
-
"negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
| 73 |
-
},
|
| 74 |
-
{
|
| 75 |
-
"name": "Digital Art",
|
| 76 |
-
"prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
|
| 77 |
-
"negative_prompt": "photo, photorealistic, realism, ugly",
|
| 78 |
-
},
|
| 79 |
-
{
|
| 80 |
-
"name": "Pixel art",
|
| 81 |
-
"prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
|
| 82 |
-
"negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
| 83 |
-
},
|
| 84 |
-
{
|
| 85 |
-
"name": "Fantasy art",
|
| 86 |
-
"prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
|
| 87 |
-
"majestic, magical, fantasy art, cover art, dreamy",
|
| 88 |
-
"negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
|
| 89 |
-
"glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
|
| 90 |
-
"disfigured, sloppy, duplicate, mutated, black and white",
|
| 91 |
-
},
|
| 92 |
-
{
|
| 93 |
-
"name": "Neonpunk",
|
| 94 |
-
"prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
|
| 95 |
-
"detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
|
| 96 |
-
"ultra detailed, intricate, professional",
|
| 97 |
-
"negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
| 98 |
-
},
|
| 99 |
-
{
|
| 100 |
-
"name": "3D Model",
|
| 101 |
-
"prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
|
| 102 |
-
"negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
|
| 103 |
-
},
|
| 104 |
-
]
|
| 105 |
-
|
| 106 |
-
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
|
| 107 |
-
STYLE_NAMES = list(styles.keys())
|
| 108 |
-
DEFAULT_STYLE_NAME = "(No style)"
|
| 109 |
-
SCHEDULE_NAME = ["Flow_DPM_Solver"]
|
| 110 |
-
DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
|
| 111 |
-
NUM_IMAGES_PER_PROMPT = 1
|
| 112 |
-
TEST_TIMES = 0
|
| 113 |
-
FILENAME = f"output/port{DEMO_PORT}_inference_count.txt"
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
def set_env(seed=0):
|
| 117 |
-
torch.manual_seed(seed)
|
| 118 |
-
torch.set_grad_enabled(False)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def read_inference_count():
|
| 122 |
-
global TEST_TIMES
|
| 123 |
-
try:
|
| 124 |
-
with open(FILENAME) as f:
|
| 125 |
-
count = int(f.read().strip())
|
| 126 |
-
except FileNotFoundError:
|
| 127 |
-
count = 0
|
| 128 |
-
TEST_TIMES = count
|
| 129 |
-
|
| 130 |
-
return count
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
def write_inference_count(count):
|
| 134 |
-
with open(FILENAME, "w") as f:
|
| 135 |
-
f.write(str(count))
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
def run_inference(num_imgs=1):
|
| 139 |
-
TEST_TIMES = read_inference_count()
|
| 140 |
-
TEST_TIMES += int(num_imgs)
|
| 141 |
-
write_inference_count(TEST_TIMES)
|
| 142 |
-
|
| 143 |
-
return (
|
| 144 |
-
f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
|
| 145 |
-
f"16px; color:red; font-weight: bold;'>{TEST_TIMES}</span>"
|
| 146 |
-
)
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
def update_inference_count():
|
| 150 |
-
count = read_inference_count()
|
| 151 |
-
return (
|
| 152 |
-
f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
|
| 153 |
-
f"16px; color:red; font-weight: bold;'>{count}</span>"
|
| 154 |
-
)
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
|
| 158 |
-
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
| 159 |
-
if not negative:
|
| 160 |
-
negative = ""
|
| 161 |
-
return p.replace("{prompt}", positive), n + negative
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
def get_args():
|
| 165 |
-
parser = argparse.ArgumentParser()
|
| 166 |
-
parser.add_argument("--config", type=str, help="config")
|
| 167 |
-
parser.add_argument(
|
| 168 |
-
"--model_path",
|
| 169 |
-
nargs="?",
|
| 170 |
-
default="output/Sana_D20/SANA.pth",
|
| 171 |
-
type=str,
|
| 172 |
-
help="Path to the model file (positional)",
|
| 173 |
-
)
|
| 174 |
-
parser.add_argument("--output", default="./", type=str)
|
| 175 |
-
parser.add_argument("--bs", default=1, type=int)
|
| 176 |
-
parser.add_argument("--image_size", default=1024, type=int)
|
| 177 |
-
parser.add_argument("--cfg_scale", default=5.0, type=float)
|
| 178 |
-
parser.add_argument("--pag_scale", default=2.0, type=float)
|
| 179 |
-
parser.add_argument("--seed", default=42, type=int)
|
| 180 |
-
parser.add_argument("--step", default=-1, type=int)
|
| 181 |
-
parser.add_argument("--custom_image_size", default=None, type=int)
|
| 182 |
-
parser.add_argument(
|
| 183 |
-
"--shield_model_path",
|
| 184 |
-
type=str,
|
| 185 |
-
help="The path to shield model, we employ ShieldGemma-2B by default.",
|
| 186 |
-
default="google/shieldgemma-2b",
|
| 187 |
-
)
|
| 188 |
-
|
| 189 |
-
return parser.parse_args()
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
args = get_args()
|
| 193 |
-
|
| 194 |
-
if torch.cuda.is_available():
|
| 195 |
-
weight_dtype = torch.float16
|
| 196 |
-
model_path = args.model_path
|
| 197 |
-
pipe = SanaPipeline(args.config)
|
| 198 |
-
pipe.from_pretrained(model_path)
|
| 199 |
-
pipe.register_progress_bar(gr.Progress())
|
| 200 |
-
|
| 201 |
-
repo_name = "black-forest-labs/FLUX.1-dev"
|
| 202 |
-
pipe2 = FluxPipeline.from_pretrained(repo_name, torch_dtype=torch.float16).to("cuda")
|
| 203 |
-
|
| 204 |
-
# safety checker
|
| 205 |
-
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
|
| 206 |
-
safety_checker_model = AutoModelForCausalLM.from_pretrained(
|
| 207 |
-
args.shield_model_path,
|
| 208 |
-
device_map="auto",
|
| 209 |
-
torch_dtype=torch.bfloat16,
|
| 210 |
-
).to(device)
|
| 211 |
-
|
| 212 |
-
set_env(42)
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
def save_image_sana(img, seed="", save_img=False):
|
| 216 |
-
unique_name = f"{str(uuid.uuid4())}_{seed}.png"
|
| 217 |
-
save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
|
| 218 |
-
os.umask(0o000) # file permission: 666; dir permission: 777
|
| 219 |
-
os.makedirs(save_path, exist_ok=True)
|
| 220 |
-
unique_name = os.path.join(save_path, unique_name)
|
| 221 |
-
if save_img:
|
| 222 |
-
save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
|
| 223 |
-
|
| 224 |
-
return unique_name
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
| 228 |
-
if randomize_seed:
|
| 229 |
-
seed = random.randint(0, MAX_SEED)
|
| 230 |
-
return seed
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
@spaces.GPU(enable_queue=True)
|
| 234 |
-
async def generate_2(
|
| 235 |
-
prompt: str = None,
|
| 236 |
-
negative_prompt: str = "",
|
| 237 |
-
style: str = DEFAULT_STYLE_NAME,
|
| 238 |
-
use_negative_prompt: bool = False,
|
| 239 |
-
num_imgs: int = 1,
|
| 240 |
-
seed: int = 0,
|
| 241 |
-
height: int = 1024,
|
| 242 |
-
width: int = 1024,
|
| 243 |
-
flow_dpms_guidance_scale: float = 5.0,
|
| 244 |
-
flow_dpms_pag_guidance_scale: float = 2.0,
|
| 245 |
-
flow_dpms_inference_steps: int = 20,
|
| 246 |
-
randomize_seed: bool = False,
|
| 247 |
-
):
|
| 248 |
-
seed = int(randomize_seed_fn(seed, randomize_seed))
|
| 249 |
-
generator = torch.Generator(device=device).manual_seed(seed)
|
| 250 |
-
print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
|
| 251 |
-
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt):
|
| 252 |
-
prompt = "A red heart."
|
| 253 |
-
|
| 254 |
-
print(prompt)
|
| 255 |
-
|
| 256 |
-
if not use_negative_prompt:
|
| 257 |
-
negative_prompt = None # type: ignore
|
| 258 |
-
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
| 259 |
-
|
| 260 |
-
with torch.no_grad():
|
| 261 |
-
images = pipe2(
|
| 262 |
-
prompt=prompt,
|
| 263 |
-
height=height,
|
| 264 |
-
width=width,
|
| 265 |
-
guidance_scale=3.5,
|
| 266 |
-
num_inference_steps=50,
|
| 267 |
-
num_images_per_prompt=num_imgs,
|
| 268 |
-
max_sequence_length=256,
|
| 269 |
-
generator=generator,
|
| 270 |
-
).images
|
| 271 |
-
|
| 272 |
-
save_img = False
|
| 273 |
-
img = images
|
| 274 |
-
if save_img:
|
| 275 |
-
img = [save_image_sana(img, seed, save_img=save_image) for img in images]
|
| 276 |
-
print(img)
|
| 277 |
-
torch.cuda.empty_cache()
|
| 278 |
-
|
| 279 |
-
return img
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
@spaces.GPU(enable_queue=True)
|
| 283 |
-
async def generate(
|
| 284 |
-
prompt: str = None,
|
| 285 |
-
negative_prompt: str = "",
|
| 286 |
-
style: str = DEFAULT_STYLE_NAME,
|
| 287 |
-
use_negative_prompt: bool = False,
|
| 288 |
-
num_imgs: int = 1,
|
| 289 |
-
seed: int = 0,
|
| 290 |
-
height: int = 1024,
|
| 291 |
-
width: int = 1024,
|
| 292 |
-
flow_dpms_guidance_scale: float = 5.0,
|
| 293 |
-
flow_dpms_pag_guidance_scale: float = 2.0,
|
| 294 |
-
flow_dpms_inference_steps: int = 20,
|
| 295 |
-
randomize_seed: bool = False,
|
| 296 |
-
):
|
| 297 |
-
global TEST_TIMES
|
| 298 |
-
# seed = 823753551
|
| 299 |
-
seed = int(randomize_seed_fn(seed, randomize_seed))
|
| 300 |
-
generator = torch.Generator(device=device).manual_seed(seed)
|
| 301 |
-
print(f"PORT: {DEMO_PORT}, model_path: {model_path}, time_times: {TEST_TIMES}")
|
| 302 |
-
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt):
|
| 303 |
-
prompt = "A red heart."
|
| 304 |
-
|
| 305 |
-
print(prompt)
|
| 306 |
-
|
| 307 |
-
num_inference_steps = flow_dpms_inference_steps
|
| 308 |
-
guidance_scale = flow_dpms_guidance_scale
|
| 309 |
-
pag_guidance_scale = flow_dpms_pag_guidance_scale
|
| 310 |
-
|
| 311 |
-
if not use_negative_prompt:
|
| 312 |
-
negative_prompt = None # type: ignore
|
| 313 |
-
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
| 314 |
-
|
| 315 |
-
pipe.progress_fn(0, desc="Sana Start")
|
| 316 |
-
|
| 317 |
-
with torch.no_grad():
|
| 318 |
-
images = pipe(
|
| 319 |
-
prompt=prompt,
|
| 320 |
-
height=height,
|
| 321 |
-
width=width,
|
| 322 |
-
negative_prompt=negative_prompt,
|
| 323 |
-
guidance_scale=guidance_scale,
|
| 324 |
-
pag_guidance_scale=pag_guidance_scale,
|
| 325 |
-
num_inference_steps=num_inference_steps,
|
| 326 |
-
num_images_per_prompt=num_imgs,
|
| 327 |
-
generator=generator,
|
| 328 |
-
)
|
| 329 |
-
|
| 330 |
-
pipe.progress_fn(1.0, desc="Sana End")
|
| 331 |
-
|
| 332 |
-
save_img = False
|
| 333 |
-
if save_img:
|
| 334 |
-
img = [save_image_sana(img, seed, save_img=save_image) for img in images]
|
| 335 |
-
print(img)
|
| 336 |
-
else:
|
| 337 |
-
if num_imgs > 1:
|
| 338 |
-
nrow = 2
|
| 339 |
-
else:
|
| 340 |
-
nrow = 1
|
| 341 |
-
img = make_grid(images, nrow=nrow, normalize=True, value_range=(-1, 1))
|
| 342 |
-
img = img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
| 343 |
-
img = [Image.fromarray(img.astype(np.uint8))]
|
| 344 |
-
|
| 345 |
-
torch.cuda.empty_cache()
|
| 346 |
-
|
| 347 |
-
return img
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
TEST_TIMES = read_inference_count()
|
| 351 |
-
model_size = "1.6" if "D20" in args.model_path else "0.6"
|
| 352 |
-
title = f"""
|
| 353 |
-
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
|
| 354 |
-
<img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
|
| 355 |
-
</div>
|
| 356 |
-
"""
|
| 357 |
-
DESCRIPTION = f"""
|
| 358 |
-
<p><span style="font-size: 36px; font-weight: bold;">Sana-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
|
| 359 |
-
<p style="font-size: 16px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
|
| 360 |
-
<p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p>
|
| 361 |
-
<p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space</p>
|
| 362 |
-
<p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
|
| 363 |
-
"""
|
| 364 |
-
if model_size == "0.6":
|
| 365 |
-
DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
|
| 366 |
-
if not torch.cuda.is_available():
|
| 367 |
-
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
| 368 |
-
|
| 369 |
-
examples = [
|
| 370 |
-
'a cyberpunk cat with a neon sign that says "Sana"',
|
| 371 |
-
"A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
|
| 372 |
-
"Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
|
| 373 |
-
"portrait photo of a girl, photograph, highly detailed face, depth of field",
|
| 374 |
-
'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
|
| 375 |
-
"🐶 Wearing 🕶 flying on the 🌈",
|
| 376 |
-
# "👧 with 🌹 in the ❄️",
|
| 377 |
-
# "an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
|
| 378 |
-
# "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
|
| 379 |
-
# "Astronaut in a jungle, cold color palette, muted colors, detailed",
|
| 380 |
-
# "a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
|
| 381 |
-
]
|
| 382 |
-
|
| 383 |
-
css = """
|
| 384 |
-
.gradio-container{max-width: 1024px !important}
|
| 385 |
-
h1{text-align:center}
|
| 386 |
-
"""
|
| 387 |
-
with gr.Blocks(css=css) as demo:
|
| 388 |
-
gr.Markdown(title)
|
| 389 |
-
gr.Markdown(DESCRIPTION)
|
| 390 |
-
gr.DuplicateButton(
|
| 391 |
-
value="Duplicate Space for private use",
|
| 392 |
-
elem_id="duplicate-button",
|
| 393 |
-
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
|
| 394 |
-
)
|
| 395 |
-
info_box = gr.Markdown(
|
| 396 |
-
value=f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: 16px; color:red; font-weight: bold;'>{read_inference_count()}</span>"
|
| 397 |
-
)
|
| 398 |
-
demo.load(fn=update_inference_count, outputs=info_box) # update the value when re-loading the page
|
| 399 |
-
# with gr.Row(equal_height=False):
|
| 400 |
-
with gr.Group():
|
| 401 |
-
with gr.Row():
|
| 402 |
-
prompt = gr.Text(
|
| 403 |
-
label="Prompt",
|
| 404 |
-
show_label=False,
|
| 405 |
-
max_lines=1,
|
| 406 |
-
placeholder="Enter your prompt",
|
| 407 |
-
container=False,
|
| 408 |
-
)
|
| 409 |
-
run_button = gr.Button("Run-sana", scale=0)
|
| 410 |
-
run_button2 = gr.Button("Run-flux", scale=0)
|
| 411 |
-
|
| 412 |
-
with gr.Row():
|
| 413 |
-
result = gr.Gallery(label="Result from Sana", show_label=True, columns=NUM_IMAGES_PER_PROMPT, format="webp")
|
| 414 |
-
result_2 = gr.Gallery(
|
| 415 |
-
label="Result from FLUX", show_label=True, columns=NUM_IMAGES_PER_PROMPT, format="webp"
|
| 416 |
-
)
|
| 417 |
-
|
| 418 |
-
with gr.Accordion("Advanced options", open=False):
|
| 419 |
-
with gr.Group():
|
| 420 |
-
with gr.Row(visible=True):
|
| 421 |
-
height = gr.Slider(
|
| 422 |
-
label="Height",
|
| 423 |
-
minimum=256,
|
| 424 |
-
maximum=MAX_IMAGE_SIZE,
|
| 425 |
-
step=32,
|
| 426 |
-
value=1024,
|
| 427 |
-
)
|
| 428 |
-
width = gr.Slider(
|
| 429 |
-
label="Width",
|
| 430 |
-
minimum=256,
|
| 431 |
-
maximum=MAX_IMAGE_SIZE,
|
| 432 |
-
step=32,
|
| 433 |
-
value=1024,
|
| 434 |
-
)
|
| 435 |
-
with gr.Row():
|
| 436 |
-
flow_dpms_inference_steps = gr.Slider(
|
| 437 |
-
label="Sampling steps",
|
| 438 |
-
minimum=5,
|
| 439 |
-
maximum=40,
|
| 440 |
-
step=1,
|
| 441 |
-
value=18,
|
| 442 |
-
)
|
| 443 |
-
flow_dpms_guidance_scale = gr.Slider(
|
| 444 |
-
label="CFG Guidance scale",
|
| 445 |
-
minimum=1,
|
| 446 |
-
maximum=10,
|
| 447 |
-
step=0.1,
|
| 448 |
-
value=5.0,
|
| 449 |
-
)
|
| 450 |
-
flow_dpms_pag_guidance_scale = gr.Slider(
|
| 451 |
-
label="PAG Guidance scale",
|
| 452 |
-
minimum=1,
|
| 453 |
-
maximum=4,
|
| 454 |
-
step=0.5,
|
| 455 |
-
value=2.0,
|
| 456 |
-
)
|
| 457 |
-
with gr.Row():
|
| 458 |
-
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
|
| 459 |
-
negative_prompt = gr.Text(
|
| 460 |
-
label="Negative prompt",
|
| 461 |
-
max_lines=1,
|
| 462 |
-
placeholder="Enter a negative prompt",
|
| 463 |
-
visible=True,
|
| 464 |
-
)
|
| 465 |
-
style_selection = gr.Radio(
|
| 466 |
-
show_label=True,
|
| 467 |
-
container=True,
|
| 468 |
-
interactive=True,
|
| 469 |
-
choices=STYLE_NAMES,
|
| 470 |
-
value=DEFAULT_STYLE_NAME,
|
| 471 |
-
label="Image Style",
|
| 472 |
-
)
|
| 473 |
-
seed = gr.Slider(
|
| 474 |
-
label="Seed",
|
| 475 |
-
minimum=0,
|
| 476 |
-
maximum=MAX_SEED,
|
| 477 |
-
step=1,
|
| 478 |
-
value=0,
|
| 479 |
-
)
|
| 480 |
-
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
| 481 |
-
with gr.Row(visible=True):
|
| 482 |
-
schedule = gr.Radio(
|
| 483 |
-
show_label=True,
|
| 484 |
-
container=True,
|
| 485 |
-
interactive=True,
|
| 486 |
-
choices=SCHEDULE_NAME,
|
| 487 |
-
value=DEFAULT_SCHEDULE_NAME,
|
| 488 |
-
label="Sampler Schedule",
|
| 489 |
-
visible=True,
|
| 490 |
-
)
|
| 491 |
-
num_imgs = gr.Slider(
|
| 492 |
-
label="Num Images",
|
| 493 |
-
minimum=1,
|
| 494 |
-
maximum=6,
|
| 495 |
-
step=1,
|
| 496 |
-
value=1,
|
| 497 |
-
)
|
| 498 |
-
|
| 499 |
-
run_button.click(fn=run_inference, inputs=num_imgs, outputs=info_box)
|
| 500 |
-
|
| 501 |
-
gr.Examples(
|
| 502 |
-
examples=examples,
|
| 503 |
-
inputs=prompt,
|
| 504 |
-
outputs=[result],
|
| 505 |
-
fn=generate,
|
| 506 |
-
cache_examples=CACHE_EXAMPLES,
|
| 507 |
-
)
|
| 508 |
-
gr.Examples(
|
| 509 |
-
examples=examples,
|
| 510 |
-
inputs=prompt,
|
| 511 |
-
outputs=[result_2],
|
| 512 |
-
fn=generate_2,
|
| 513 |
-
cache_examples=CACHE_EXAMPLES,
|
| 514 |
-
)
|
| 515 |
-
|
| 516 |
-
use_negative_prompt.change(
|
| 517 |
-
fn=lambda x: gr.update(visible=x),
|
| 518 |
-
inputs=use_negative_prompt,
|
| 519 |
-
outputs=negative_prompt,
|
| 520 |
-
api_name=False,
|
| 521 |
-
)
|
| 522 |
-
|
| 523 |
-
run_button.click(
|
| 524 |
-
fn=generate,
|
| 525 |
-
inputs=[
|
| 526 |
-
prompt,
|
| 527 |
-
negative_prompt,
|
| 528 |
-
style_selection,
|
| 529 |
-
use_negative_prompt,
|
| 530 |
-
num_imgs,
|
| 531 |
-
seed,
|
| 532 |
-
height,
|
| 533 |
-
width,
|
| 534 |
-
flow_dpms_guidance_scale,
|
| 535 |
-
flow_dpms_pag_guidance_scale,
|
| 536 |
-
flow_dpms_inference_steps,
|
| 537 |
-
randomize_seed,
|
| 538 |
-
],
|
| 539 |
-
outputs=[result],
|
| 540 |
-
queue=True,
|
| 541 |
-
)
|
| 542 |
-
|
| 543 |
-
run_button2.click(
|
| 544 |
-
fn=generate_2,
|
| 545 |
-
inputs=[
|
| 546 |
-
prompt,
|
| 547 |
-
negative_prompt,
|
| 548 |
-
style_selection,
|
| 549 |
-
use_negative_prompt,
|
| 550 |
-
num_imgs,
|
| 551 |
-
seed,
|
| 552 |
-
height,
|
| 553 |
-
width,
|
| 554 |
-
flow_dpms_guidance_scale,
|
| 555 |
-
flow_dpms_pag_guidance_scale,
|
| 556 |
-
flow_dpms_inference_steps,
|
| 557 |
-
randomize_seed,
|
| 558 |
-
],
|
| 559 |
-
outputs=[result_2],
|
| 560 |
-
queue=True,
|
| 561 |
-
)
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
if __name__ == "__main__":
|
| 565 |
-
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=True, share=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/safety_check.py
DELETED
|
@@ -1,72 +0,0 @@
|
|
| 1 |
-
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
#
|
| 15 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
-
# This file is copy from https://github.com/mit-han-lab/hart/blob/main/hart/utils/safety_check.py
|
| 17 |
-
import json
|
| 18 |
-
from functools import lru_cache
|
| 19 |
-
from typing import Dict, Sequence
|
| 20 |
-
|
| 21 |
-
import torch
|
| 22 |
-
import transformers
|
| 23 |
-
from torch.nn.functional import softmax
|
| 24 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 25 |
-
|
| 26 |
-
__all__ = ["is_dangerous"]
|
| 27 |
-
|
| 28 |
-
__rank, __local_rank, __world_size, __device = (
|
| 29 |
-
0,
|
| 30 |
-
0,
|
| 31 |
-
1,
|
| 32 |
-
"cuda" if torch.cuda.is_available() else "cpu",
|
| 33 |
-
)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
safety_policy = """
|
| 37 |
-
The prompt must not seek to generate harmful, abusive, hateful, sexually explicit, or otherwise inappropriate content targeting individuals or protected groups.
|
| 38 |
-
"""
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def is_dangerous(tokenizer, model, user_prompt, threshold=0.5):
|
| 42 |
-
prompt = f"""You are a policy expert trying to help determine whether a user
|
| 43 |
-
prompt is in violation of the defined safety policies.
|
| 44 |
-
|
| 45 |
-
<start_of_turn>
|
| 46 |
-
Human Question: {user_prompt.strip()}
|
| 47 |
-
<end_of_turn>
|
| 48 |
-
|
| 49 |
-
Our safety principle is defined in the below:
|
| 50 |
-
|
| 51 |
-
{safety_policy.strip()}
|
| 52 |
-
|
| 53 |
-
Does the human question violate the above principle? Your answer must start
|
| 54 |
-
with 'Yes' or 'No'. And then walk through step by step to be sure we answer
|
| 55 |
-
correctly.
|
| 56 |
-
"""
|
| 57 |
-
|
| 58 |
-
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
| 59 |
-
with torch.no_grad():
|
| 60 |
-
logits = model(**inputs).logits
|
| 61 |
-
|
| 62 |
-
# Extract the logits for the Yes and No tokens
|
| 63 |
-
vocab = tokenizer.get_vocab()
|
| 64 |
-
selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]]
|
| 65 |
-
|
| 66 |
-
# Convert these logits to a probability with softmax
|
| 67 |
-
probabilities = softmax(selected_logits, dim=0)
|
| 68 |
-
|
| 69 |
-
# Return probability of 'Yes'
|
| 70 |
-
score = probabilities[0].item()
|
| 71 |
-
|
| 72 |
-
return score > threshold
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/sana_controlnet_pipeline.py
DELETED
|
@@ -1,353 +0,0 @@
|
|
| 1 |
-
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
#
|
| 15 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
-
import warnings
|
| 17 |
-
from dataclasses import dataclass, field
|
| 18 |
-
from typing import Optional, Tuple
|
| 19 |
-
|
| 20 |
-
import numpy as np
|
| 21 |
-
import pyrallis
|
| 22 |
-
import torch
|
| 23 |
-
import torch.nn as nn
|
| 24 |
-
from PIL import Image
|
| 25 |
-
|
| 26 |
-
warnings.filterwarnings("ignore") # ignore warning
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
from diffusion import DPMS, FlowEuler
|
| 30 |
-
from diffusion.data.datasets.utils import (
|
| 31 |
-
ASPECT_RATIO_512_TEST,
|
| 32 |
-
ASPECT_RATIO_1024_TEST,
|
| 33 |
-
ASPECT_RATIO_2048_TEST,
|
| 34 |
-
ASPECT_RATIO_4096_TEST,
|
| 35 |
-
)
|
| 36 |
-
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode, vae_encode
|
| 37 |
-
from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
|
| 38 |
-
from diffusion.utils.config import SanaConfig, model_init_config
|
| 39 |
-
from diffusion.utils.logger import get_root_logger
|
| 40 |
-
from tools.controlnet.utils import get_scribble_map, transform_control_signal
|
| 41 |
-
from tools.download import find_model
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def guidance_type_select(default_guidance_type, pag_scale, attn_type):
|
| 45 |
-
guidance_type = default_guidance_type
|
| 46 |
-
if not (pag_scale > 1.0 and attn_type == "linear"):
|
| 47 |
-
guidance_type = "classifier-free"
|
| 48 |
-
elif pag_scale > 1.0 and attn_type == "linear":
|
| 49 |
-
guidance_type = "classifier-free_PAG"
|
| 50 |
-
return guidance_type
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
|
| 54 |
-
"""Returns binned height and width."""
|
| 55 |
-
ar = float(height / width)
|
| 56 |
-
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
|
| 57 |
-
default_hw = ratios[closest_ratio]
|
| 58 |
-
return int(default_hw[0]), int(default_hw[1])
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def get_ar_from_ref_image(ref_image):
|
| 62 |
-
def reduce_ratio(h, w):
|
| 63 |
-
def gcd(a, b):
|
| 64 |
-
while b:
|
| 65 |
-
a, b = b, a % b
|
| 66 |
-
return a
|
| 67 |
-
|
| 68 |
-
divisor = gcd(h, w)
|
| 69 |
-
return f"{h // divisor}:{w // divisor}"
|
| 70 |
-
|
| 71 |
-
if isinstance(ref_image, str):
|
| 72 |
-
ref_image = Image.open(ref_image)
|
| 73 |
-
w, h = ref_image.size
|
| 74 |
-
return reduce_ratio(h, w)
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
@dataclass
|
| 78 |
-
class SanaControlNetInference(SanaConfig):
|
| 79 |
-
config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml" # config
|
| 80 |
-
model_path: str = field(
|
| 81 |
-
default="output/Sana_D20/SANA.pth", metadata={"help": "Path to the model file (positional)"}
|
| 82 |
-
)
|
| 83 |
-
output: str = "./output"
|
| 84 |
-
bs: int = 1
|
| 85 |
-
image_size: int = 1024
|
| 86 |
-
cfg_scale: float = 5.0
|
| 87 |
-
pag_scale: float = 2.0
|
| 88 |
-
seed: int = 42
|
| 89 |
-
step: int = -1
|
| 90 |
-
custom_image_size: Optional[int] = None
|
| 91 |
-
shield_model_path: str = field(
|
| 92 |
-
default="google/shieldgemma-2b",
|
| 93 |
-
metadata={"help": "The path to shield model, we employ ShieldGemma-2B by default."},
|
| 94 |
-
)
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
class SanaControlNetPipeline(nn.Module):
|
| 98 |
-
def __init__(
|
| 99 |
-
self,
|
| 100 |
-
config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml",
|
| 101 |
-
):
|
| 102 |
-
super().__init__()
|
| 103 |
-
config = pyrallis.load(SanaControlNetInference, open(config))
|
| 104 |
-
self.args = self.config = config
|
| 105 |
-
|
| 106 |
-
# set some hyper-parameters
|
| 107 |
-
self.image_size = self.config.model.image_size
|
| 108 |
-
|
| 109 |
-
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 110 |
-
logger = get_root_logger()
|
| 111 |
-
self.logger = logger
|
| 112 |
-
self.progress_fn = lambda progress, desc: None
|
| 113 |
-
self.thickness = 2
|
| 114 |
-
self.blend_alpha = 0.0
|
| 115 |
-
|
| 116 |
-
self.latent_size = self.image_size // config.vae.vae_downsample_rate
|
| 117 |
-
self.max_sequence_length = config.text_encoder.model_max_length
|
| 118 |
-
self.flow_shift = config.scheduler.flow_shift
|
| 119 |
-
guidance_type = "classifier-free_PAG"
|
| 120 |
-
|
| 121 |
-
weight_dtype = get_weight_dtype(config.model.mixed_precision)
|
| 122 |
-
self.weight_dtype = weight_dtype
|
| 123 |
-
self.vae_dtype = get_weight_dtype(config.vae.weight_dtype)
|
| 124 |
-
|
| 125 |
-
self.base_ratios = eval(f"ASPECT_RATIO_{self.image_size}_TEST")
|
| 126 |
-
self.vis_sampler = self.config.scheduler.vis_sampler
|
| 127 |
-
logger.info(f"Sampler {self.vis_sampler}, flow_shift: {self.flow_shift}")
|
| 128 |
-
self.guidance_type = guidance_type_select(guidance_type, self.args.pag_scale, config.model.attn_type)
|
| 129 |
-
logger.info(f"Inference with {self.weight_dtype}, PAG guidance layer: {self.config.model.pag_applied_layers}")
|
| 130 |
-
|
| 131 |
-
# 1. build vae and text encoder
|
| 132 |
-
self.vae = self.build_vae(config.vae)
|
| 133 |
-
self.tokenizer, self.text_encoder = self.build_text_encoder(config.text_encoder)
|
| 134 |
-
|
| 135 |
-
# 2. build Sana model
|
| 136 |
-
self.model = self.build_sana_model(config).to(self.device)
|
| 137 |
-
|
| 138 |
-
# 3. pre-compute null embedding
|
| 139 |
-
with torch.no_grad():
|
| 140 |
-
null_caption_token = self.tokenizer(
|
| 141 |
-
"", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
|
| 142 |
-
).to(self.device)
|
| 143 |
-
self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
|
| 144 |
-
0
|
| 145 |
-
]
|
| 146 |
-
|
| 147 |
-
def build_vae(self, config):
|
| 148 |
-
vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.vae_dtype)
|
| 149 |
-
return vae
|
| 150 |
-
|
| 151 |
-
def build_text_encoder(self, config):
|
| 152 |
-
tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder_name, device=self.device)
|
| 153 |
-
return tokenizer, text_encoder
|
| 154 |
-
|
| 155 |
-
def build_sana_model(self, config):
|
| 156 |
-
# model setting
|
| 157 |
-
model_kwargs = model_init_config(config, latent_size=self.latent_size)
|
| 158 |
-
model = build_model(
|
| 159 |
-
config.model.model,
|
| 160 |
-
use_fp32_attention=config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
|
| 161 |
-
**model_kwargs,
|
| 162 |
-
)
|
| 163 |
-
self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
|
| 164 |
-
self.logger.info(
|
| 165 |
-
f"{model.__class__.__name__}:{config.model.model},"
|
| 166 |
-
f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}"
|
| 167 |
-
)
|
| 168 |
-
return model
|
| 169 |
-
|
| 170 |
-
def from_pretrained(self, model_path):
|
| 171 |
-
state_dict = find_model(model_path)
|
| 172 |
-
state_dict = state_dict.get("state_dict", state_dict)
|
| 173 |
-
if "pos_embed" in state_dict:
|
| 174 |
-
del state_dict["pos_embed"]
|
| 175 |
-
missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
|
| 176 |
-
self.model.eval().to(self.weight_dtype)
|
| 177 |
-
|
| 178 |
-
self.logger.info("Generating sample from ckpt: %s" % model_path)
|
| 179 |
-
self.logger.warning(f"Missing keys: {missing}")
|
| 180 |
-
self.logger.warning(f"Unexpected keys: {unexpected}")
|
| 181 |
-
|
| 182 |
-
def register_progress_bar(self, progress_fn=None):
|
| 183 |
-
self.progress_fn = progress_fn if progress_fn is not None else self.progress_fn
|
| 184 |
-
|
| 185 |
-
def set_blend_alpha(self, blend_alpha):
|
| 186 |
-
self.blend_alpha = blend_alpha
|
| 187 |
-
|
| 188 |
-
@torch.inference_mode()
|
| 189 |
-
def forward(
|
| 190 |
-
self,
|
| 191 |
-
prompt=None,
|
| 192 |
-
ref_image=None,
|
| 193 |
-
negative_prompt="",
|
| 194 |
-
num_inference_steps=20,
|
| 195 |
-
guidance_scale=5,
|
| 196 |
-
pag_guidance_scale=2.5,
|
| 197 |
-
num_images_per_prompt=1,
|
| 198 |
-
sketch_thickness=2,
|
| 199 |
-
generator=torch.Generator().manual_seed(42),
|
| 200 |
-
latents=None,
|
| 201 |
-
):
|
| 202 |
-
self.ori_height, self.ori_width = ref_image.height, ref_image.width
|
| 203 |
-
self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)
|
| 204 |
-
|
| 205 |
-
# 1. pre-compute negative embedding
|
| 206 |
-
if negative_prompt != "":
|
| 207 |
-
null_caption_token = self.tokenizer(
|
| 208 |
-
negative_prompt,
|
| 209 |
-
max_length=self.max_sequence_length,
|
| 210 |
-
padding="max_length",
|
| 211 |
-
truncation=True,
|
| 212 |
-
return_tensors="pt",
|
| 213 |
-
).to(self.device)
|
| 214 |
-
self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
|
| 215 |
-
0
|
| 216 |
-
]
|
| 217 |
-
|
| 218 |
-
if prompt is None:
|
| 219 |
-
prompt = [""]
|
| 220 |
-
prompts = prompt if isinstance(prompt, list) else [prompt]
|
| 221 |
-
samples = []
|
| 222 |
-
|
| 223 |
-
for prompt in prompts:
|
| 224 |
-
# data prepare
|
| 225 |
-
prompts, hw, ar = (
|
| 226 |
-
[],
|
| 227 |
-
torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
|
| 228 |
-
num_images_per_prompt, 1
|
| 229 |
-
),
|
| 230 |
-
torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
|
| 231 |
-
)
|
| 232 |
-
|
| 233 |
-
ar = get_ar_from_ref_image(ref_image)
|
| 234 |
-
prompt += f" --ar {ar}"
|
| 235 |
-
for _ in range(num_images_per_prompt):
|
| 236 |
-
prompt_clean, _, hw, ar, custom_hw = prepare_prompt_ar(
|
| 237 |
-
prompt, self.base_ratios, device=self.device, show=False
|
| 238 |
-
)
|
| 239 |
-
prompts.append(prompt_clean.strip())
|
| 240 |
-
|
| 241 |
-
self.latent_size_h, self.latent_size_w = (
|
| 242 |
-
int(hw[0, 0] // self.config.vae.vae_downsample_rate),
|
| 243 |
-
int(hw[0, 1] // self.config.vae.vae_downsample_rate),
|
| 244 |
-
)
|
| 245 |
-
|
| 246 |
-
with torch.no_grad():
|
| 247 |
-
# prepare text feature
|
| 248 |
-
if not self.config.text_encoder.chi_prompt:
|
| 249 |
-
max_length_all = self.config.text_encoder.model_max_length
|
| 250 |
-
prompts_all = prompts
|
| 251 |
-
else:
|
| 252 |
-
chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
|
| 253 |
-
prompts_all = [chi_prompt + prompt for prompt in prompts]
|
| 254 |
-
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
|
| 255 |
-
max_length_all = (
|
| 256 |
-
num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
|
| 257 |
-
) # magic number 2: [bos], [_]
|
| 258 |
-
|
| 259 |
-
caption_token = self.tokenizer(
|
| 260 |
-
prompts_all,
|
| 261 |
-
max_length=max_length_all,
|
| 262 |
-
padding="max_length",
|
| 263 |
-
truncation=True,
|
| 264 |
-
return_tensors="pt",
|
| 265 |
-
).to(device=self.device)
|
| 266 |
-
select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
|
| 267 |
-
caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
|
| 268 |
-
:, :, select_index
|
| 269 |
-
].to(self.weight_dtype)
|
| 270 |
-
emb_masks = caption_token.attention_mask[:, select_index]
|
| 271 |
-
null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
|
| 272 |
-
|
| 273 |
-
n = len(prompts)
|
| 274 |
-
if latents is None:
|
| 275 |
-
z = torch.randn(
|
| 276 |
-
n,
|
| 277 |
-
self.config.vae.vae_latent_dim,
|
| 278 |
-
self.latent_size_h,
|
| 279 |
-
self.latent_size_w,
|
| 280 |
-
generator=generator,
|
| 281 |
-
device=self.device,
|
| 282 |
-
)
|
| 283 |
-
else:
|
| 284 |
-
z = latents.to(self.device)
|
| 285 |
-
model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
|
| 286 |
-
|
| 287 |
-
# control signal
|
| 288 |
-
if isinstance(ref_image, str):
|
| 289 |
-
ref_image = cv2.imread(ref_image)
|
| 290 |
-
elif isinstance(ref_image, Image.Image):
|
| 291 |
-
ref_image = np.array(ref_image)
|
| 292 |
-
control_signal = get_scribble_map(
|
| 293 |
-
input_image=ref_image,
|
| 294 |
-
det="Scribble_HED",
|
| 295 |
-
detect_resolution=int(hw.min()),
|
| 296 |
-
thickness=sketch_thickness,
|
| 297 |
-
)
|
| 298 |
-
|
| 299 |
-
control_signal = transform_control_signal(control_signal, hw).to(self.device).to(self.weight_dtype)
|
| 300 |
-
|
| 301 |
-
control_signal_latent = vae_encode(
|
| 302 |
-
self.config.vae.vae_type, self.vae, control_signal, self.config.vae.sample_posterior, self.device
|
| 303 |
-
)
|
| 304 |
-
|
| 305 |
-
model_kwargs["control_signal"] = control_signal_latent
|
| 306 |
-
|
| 307 |
-
if self.vis_sampler == "flow_euler":
|
| 308 |
-
flow_solver = FlowEuler(
|
| 309 |
-
self.model,
|
| 310 |
-
condition=caption_embs,
|
| 311 |
-
uncondition=null_y,
|
| 312 |
-
cfg_scale=guidance_scale,
|
| 313 |
-
model_kwargs=model_kwargs,
|
| 314 |
-
)
|
| 315 |
-
sample = flow_solver.sample(
|
| 316 |
-
z,
|
| 317 |
-
steps=num_inference_steps,
|
| 318 |
-
)
|
| 319 |
-
elif self.vis_sampler == "flow_dpm-solver":
|
| 320 |
-
scheduler = DPMS(
|
| 321 |
-
self.model.forward_with_dpmsolver,
|
| 322 |
-
condition=caption_embs,
|
| 323 |
-
uncondition=null_y,
|
| 324 |
-
guidance_type=self.guidance_type,
|
| 325 |
-
cfg_scale=guidance_scale,
|
| 326 |
-
model_type="flow",
|
| 327 |
-
model_kwargs=model_kwargs,
|
| 328 |
-
schedule="FLOW",
|
| 329 |
-
)
|
| 330 |
-
scheduler.register_progress_bar(self.progress_fn)
|
| 331 |
-
sample = scheduler.sample(
|
| 332 |
-
z,
|
| 333 |
-
steps=num_inference_steps,
|
| 334 |
-
order=2,
|
| 335 |
-
skip_type="time_uniform_flow",
|
| 336 |
-
method="multistep",
|
| 337 |
-
flow_shift=self.flow_shift,
|
| 338 |
-
)
|
| 339 |
-
|
| 340 |
-
sample = sample.to(self.vae_dtype)
|
| 341 |
-
with torch.no_grad():
|
| 342 |
-
sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
|
| 343 |
-
|
| 344 |
-
if self.blend_alpha > 0:
|
| 345 |
-
print(f"blend image and mask with alpha: {self.blend_alpha}")
|
| 346 |
-
sample = sample * (1 - self.blend_alpha) + control_signal * self.blend_alpha
|
| 347 |
-
|
| 348 |
-
sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
|
| 349 |
-
samples.append(sample)
|
| 350 |
-
|
| 351 |
-
return sample
|
| 352 |
-
|
| 353 |
-
return samples
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/sana_pipeline.py
DELETED
|
@@ -1,304 +0,0 @@
|
|
| 1 |
-
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
#
|
| 15 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
-
import argparse
|
| 17 |
-
import warnings
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from typing import Optional, Tuple
|
| 20 |
-
|
| 21 |
-
import pyrallis
|
| 22 |
-
import torch
|
| 23 |
-
import torch.nn as nn
|
| 24 |
-
|
| 25 |
-
warnings.filterwarnings("ignore") # ignore warning
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
from diffusion import DPMS, FlowEuler
|
| 29 |
-
from diffusion.data.datasets.utils import (
|
| 30 |
-
ASPECT_RATIO_512_TEST,
|
| 31 |
-
ASPECT_RATIO_1024_TEST,
|
| 32 |
-
ASPECT_RATIO_2048_TEST,
|
| 33 |
-
ASPECT_RATIO_4096_TEST,
|
| 34 |
-
)
|
| 35 |
-
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
|
| 36 |
-
from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
|
| 37 |
-
from diffusion.utils.config import SanaConfig, model_init_config
|
| 38 |
-
from diffusion.utils.logger import get_root_logger
|
| 39 |
-
|
| 40 |
-
# from diffusion.utils.misc import read_config
|
| 41 |
-
from tools.download import find_model
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def guidance_type_select(default_guidance_type, pag_scale, attn_type):
|
| 45 |
-
guidance_type = default_guidance_type
|
| 46 |
-
if not (pag_scale > 1.0 and attn_type == "linear"):
|
| 47 |
-
guidance_type = "classifier-free"
|
| 48 |
-
elif pag_scale > 1.0 and attn_type == "linear":
|
| 49 |
-
guidance_type = "classifier-free_PAG"
|
| 50 |
-
return guidance_type
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
|
| 54 |
-
"""Returns binned height and width."""
|
| 55 |
-
ar = float(height / width)
|
| 56 |
-
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
|
| 57 |
-
default_hw = ratios[closest_ratio]
|
| 58 |
-
return int(default_hw[0]), int(default_hw[1])
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
@dataclass
|
| 62 |
-
class SanaInference(SanaConfig):
|
| 63 |
-
config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml" # config
|
| 64 |
-
model_path: str = field(
|
| 65 |
-
default="output/Sana_D20/SANA.pth", metadata={"help": "Path to the model file (positional)"}
|
| 66 |
-
)
|
| 67 |
-
output: str = "./output"
|
| 68 |
-
bs: int = 1
|
| 69 |
-
image_size: int = 1024
|
| 70 |
-
cfg_scale: float = 5.0
|
| 71 |
-
pag_scale: float = 2.0
|
| 72 |
-
seed: int = 42
|
| 73 |
-
step: int = -1
|
| 74 |
-
custom_image_size: Optional[int] = None
|
| 75 |
-
shield_model_path: str = field(
|
| 76 |
-
default="google/shieldgemma-2b",
|
| 77 |
-
metadata={"help": "The path to shield model, we employ ShieldGemma-2B by default."},
|
| 78 |
-
)
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
class SanaPipeline(nn.Module):
|
| 82 |
-
def __init__(
|
| 83 |
-
self,
|
| 84 |
-
config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml",
|
| 85 |
-
):
|
| 86 |
-
super().__init__()
|
| 87 |
-
config = pyrallis.load(SanaInference, open(config))
|
| 88 |
-
self.args = self.config = config
|
| 89 |
-
|
| 90 |
-
# set some hyper-parameters
|
| 91 |
-
self.image_size = self.config.model.image_size
|
| 92 |
-
|
| 93 |
-
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 94 |
-
logger = get_root_logger()
|
| 95 |
-
self.logger = logger
|
| 96 |
-
self.progress_fn = lambda progress, desc: None
|
| 97 |
-
|
| 98 |
-
self.latent_size = self.image_size // config.vae.vae_downsample_rate
|
| 99 |
-
self.max_sequence_length = config.text_encoder.model_max_length
|
| 100 |
-
self.flow_shift = config.scheduler.flow_shift
|
| 101 |
-
guidance_type = "classifier-free_PAG"
|
| 102 |
-
|
| 103 |
-
weight_dtype = get_weight_dtype(config.model.mixed_precision)
|
| 104 |
-
self.weight_dtype = weight_dtype
|
| 105 |
-
self.vae_dtype = get_weight_dtype(config.vae.weight_dtype)
|
| 106 |
-
|
| 107 |
-
self.base_ratios = eval(f"ASPECT_RATIO_{self.image_size}_TEST")
|
| 108 |
-
self.vis_sampler = self.config.scheduler.vis_sampler
|
| 109 |
-
logger.info(f"Sampler {self.vis_sampler}, flow_shift: {self.flow_shift}")
|
| 110 |
-
self.guidance_type = guidance_type_select(guidance_type, self.args.pag_scale, config.model.attn_type)
|
| 111 |
-
logger.info(f"Inference with {self.weight_dtype}, PAG guidance layer: {self.config.model.pag_applied_layers}")
|
| 112 |
-
|
| 113 |
-
# 1. build vae and text encoder
|
| 114 |
-
self.vae = self.build_vae(config.vae)
|
| 115 |
-
self.tokenizer, self.text_encoder = self.build_text_encoder(config.text_encoder)
|
| 116 |
-
|
| 117 |
-
# 2. build Sana model
|
| 118 |
-
self.model = self.build_sana_model(config).to(self.device)
|
| 119 |
-
|
| 120 |
-
# 3. pre-compute null embedding
|
| 121 |
-
with torch.no_grad():
|
| 122 |
-
null_caption_token = self.tokenizer(
|
| 123 |
-
"", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
|
| 124 |
-
).to(self.device)
|
| 125 |
-
self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
|
| 126 |
-
0
|
| 127 |
-
]
|
| 128 |
-
|
| 129 |
-
def build_vae(self, config):
|
| 130 |
-
vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.vae_dtype)
|
| 131 |
-
return vae
|
| 132 |
-
|
| 133 |
-
def build_text_encoder(self, config):
|
| 134 |
-
tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder_name, device=self.device)
|
| 135 |
-
return tokenizer, text_encoder
|
| 136 |
-
|
| 137 |
-
def build_sana_model(self, config):
|
| 138 |
-
# model setting
|
| 139 |
-
model_kwargs = model_init_config(config, latent_size=self.latent_size)
|
| 140 |
-
model = build_model(
|
| 141 |
-
config.model.model,
|
| 142 |
-
use_fp32_attention=config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
|
| 143 |
-
**model_kwargs,
|
| 144 |
-
)
|
| 145 |
-
self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
|
| 146 |
-
self.logger.info(
|
| 147 |
-
f"{model.__class__.__name__}:{config.model.model},"
|
| 148 |
-
f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}"
|
| 149 |
-
)
|
| 150 |
-
return model
|
| 151 |
-
|
| 152 |
-
def from_pretrained(self, model_path):
|
| 153 |
-
state_dict = find_model(model_path)
|
| 154 |
-
state_dict = state_dict.get("state_dict", state_dict)
|
| 155 |
-
if "pos_embed" in state_dict:
|
| 156 |
-
del state_dict["pos_embed"]
|
| 157 |
-
missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
|
| 158 |
-
self.model.eval().to(self.weight_dtype)
|
| 159 |
-
|
| 160 |
-
self.logger.info("Generating sample from ckpt: %s" % model_path)
|
| 161 |
-
self.logger.warning(f"Missing keys: {missing}")
|
| 162 |
-
self.logger.warning(f"Unexpected keys: {unexpected}")
|
| 163 |
-
|
| 164 |
-
def register_progress_bar(self, progress_fn=None):
|
| 165 |
-
self.progress_fn = progress_fn if progress_fn is not None else self.progress_fn
|
| 166 |
-
|
| 167 |
-
@torch.inference_mode()
|
| 168 |
-
def forward(
|
| 169 |
-
self,
|
| 170 |
-
prompt=None,
|
| 171 |
-
height=1024,
|
| 172 |
-
width=1024,
|
| 173 |
-
negative_prompt="",
|
| 174 |
-
num_inference_steps=20,
|
| 175 |
-
guidance_scale=5,
|
| 176 |
-
pag_guidance_scale=2.5,
|
| 177 |
-
num_images_per_prompt=1,
|
| 178 |
-
generator=torch.Generator().manual_seed(42),
|
| 179 |
-
latents=None,
|
| 180 |
-
):
|
| 181 |
-
self.ori_height, self.ori_width = height, width
|
| 182 |
-
self.height, self.width = classify_height_width_bin(height, width, ratios=self.base_ratios)
|
| 183 |
-
self.latent_size_h, self.latent_size_w = (
|
| 184 |
-
self.height // self.config.vae.vae_downsample_rate,
|
| 185 |
-
self.width // self.config.vae.vae_downsample_rate,
|
| 186 |
-
)
|
| 187 |
-
self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)
|
| 188 |
-
|
| 189 |
-
# 1. pre-compute negative embedding
|
| 190 |
-
if negative_prompt != "":
|
| 191 |
-
null_caption_token = self.tokenizer(
|
| 192 |
-
negative_prompt,
|
| 193 |
-
max_length=self.max_sequence_length,
|
| 194 |
-
padding="max_length",
|
| 195 |
-
truncation=True,
|
| 196 |
-
return_tensors="pt",
|
| 197 |
-
).to(self.device)
|
| 198 |
-
self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
|
| 199 |
-
0
|
| 200 |
-
]
|
| 201 |
-
|
| 202 |
-
if prompt is None:
|
| 203 |
-
prompt = [""]
|
| 204 |
-
prompts = prompt if isinstance(prompt, list) else [prompt]
|
| 205 |
-
samples = []
|
| 206 |
-
|
| 207 |
-
for prompt in prompts:
|
| 208 |
-
# data prepare
|
| 209 |
-
prompts, hw, ar = (
|
| 210 |
-
[],
|
| 211 |
-
torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
|
| 212 |
-
num_images_per_prompt, 1
|
| 213 |
-
),
|
| 214 |
-
torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
|
| 215 |
-
)
|
| 216 |
-
|
| 217 |
-
for _ in range(num_images_per_prompt):
|
| 218 |
-
prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip())
|
| 219 |
-
|
| 220 |
-
with torch.no_grad():
|
| 221 |
-
# prepare text feature
|
| 222 |
-
if not self.config.text_encoder.chi_prompt:
|
| 223 |
-
max_length_all = self.config.text_encoder.model_max_length
|
| 224 |
-
prompts_all = prompts
|
| 225 |
-
else:
|
| 226 |
-
chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
|
| 227 |
-
prompts_all = [chi_prompt + prompt for prompt in prompts]
|
| 228 |
-
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
|
| 229 |
-
max_length_all = (
|
| 230 |
-
num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
|
| 231 |
-
) # magic number 2: [bos], [_]
|
| 232 |
-
|
| 233 |
-
caption_token = self.tokenizer(
|
| 234 |
-
prompts_all,
|
| 235 |
-
max_length=max_length_all,
|
| 236 |
-
padding="max_length",
|
| 237 |
-
truncation=True,
|
| 238 |
-
return_tensors="pt",
|
| 239 |
-
).to(device=self.device)
|
| 240 |
-
select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
|
| 241 |
-
caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
|
| 242 |
-
:, :, select_index
|
| 243 |
-
].to(self.weight_dtype)
|
| 244 |
-
emb_masks = caption_token.attention_mask[:, select_index]
|
| 245 |
-
null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
|
| 246 |
-
|
| 247 |
-
n = len(prompts)
|
| 248 |
-
if latents is None:
|
| 249 |
-
z = torch.randn(
|
| 250 |
-
n,
|
| 251 |
-
self.config.vae.vae_latent_dim,
|
| 252 |
-
self.latent_size_h,
|
| 253 |
-
self.latent_size_w,
|
| 254 |
-
generator=generator,
|
| 255 |
-
device=self.device,
|
| 256 |
-
)
|
| 257 |
-
else:
|
| 258 |
-
z = latents.to(self.device)
|
| 259 |
-
model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
|
| 260 |
-
if self.vis_sampler == "flow_euler":
|
| 261 |
-
flow_solver = FlowEuler(
|
| 262 |
-
self.model,
|
| 263 |
-
condition=caption_embs,
|
| 264 |
-
uncondition=null_y,
|
| 265 |
-
cfg_scale=guidance_scale,
|
| 266 |
-
model_kwargs=model_kwargs,
|
| 267 |
-
)
|
| 268 |
-
sample = flow_solver.sample(
|
| 269 |
-
z,
|
| 270 |
-
steps=num_inference_steps,
|
| 271 |
-
)
|
| 272 |
-
elif self.vis_sampler == "flow_dpm-solver":
|
| 273 |
-
scheduler = DPMS(
|
| 274 |
-
self.model,
|
| 275 |
-
condition=caption_embs,
|
| 276 |
-
uncondition=null_y,
|
| 277 |
-
guidance_type=self.guidance_type,
|
| 278 |
-
cfg_scale=guidance_scale,
|
| 279 |
-
pag_scale=pag_guidance_scale,
|
| 280 |
-
pag_applied_layers=self.config.model.pag_applied_layers,
|
| 281 |
-
model_type="flow",
|
| 282 |
-
model_kwargs=model_kwargs,
|
| 283 |
-
schedule="FLOW",
|
| 284 |
-
)
|
| 285 |
-
scheduler.register_progress_bar(self.progress_fn)
|
| 286 |
-
sample = scheduler.sample(
|
| 287 |
-
z,
|
| 288 |
-
steps=num_inference_steps,
|
| 289 |
-
order=2,
|
| 290 |
-
skip_type="time_uniform_flow",
|
| 291 |
-
method="multistep",
|
| 292 |
-
flow_shift=self.flow_shift,
|
| 293 |
-
)
|
| 294 |
-
|
| 295 |
-
sample = sample.to(self.vae_dtype)
|
| 296 |
-
with torch.no_grad():
|
| 297 |
-
sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
|
| 298 |
-
|
| 299 |
-
sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
|
| 300 |
-
samples.append(sample)
|
| 301 |
-
|
| 302 |
-
return sample
|
| 303 |
-
|
| 304 |
-
return samples
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|