Update app.py
app.py CHANGED
@@ -1,568 +1,358 @@
-
-
-
-
+# ============================================================
+# IMPORTANT: imports order matters for Hugging Face Spaces
+# ============================================================
+
 import os
+import gc
 import random
-import re
-import sys
 import warnings
+import logging
+
+# ---- Spaces GPU decorator (must be imported early) ----------
+try:
+    import spaces  # noqa: F401
+    SPACES_AVAILABLE = True
+except Exception:
+    SPACES_AVAILABLE = False
 
-from PIL import Image
-from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 import gradio as gr
+import numpy as np
+from PIL import Image
+
 import torch
+from huggingface_hub import login
+
+from diffusers import (
+    ZImagePipeline,
+    ZImageImg2ImgPipeline,
+    AutoencoderKL,
+    FlowMatchEulerDiscreteScheduler,
+)
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-
-
-
-
-from diffusers import ZImagePipeline
-from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
+# ============================================================
+# Config
+# ============================================================
 
-
-
-
-MODEL_PATH = os.environ.get("MODEL_PATH", "telcom/dee-z-image")
-ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
-ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
-ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
-UNSAFE_MAX_NEW_TOKEN = int(os.environ.get("UNSAFE_MAX_NEW_TOKEN", "10"))
-DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
-HF_TOKEN = os.environ.get("HF_TOKEN")
-UNSAFE_PROMPT_CHECK = os.environ.get("UNSAFE_PROMPT_CHECK")
-# =============================================================================
+MODEL_PATH = os.environ.get("MODEL_PATH", "telcom/dee-z-image").strip()
+ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3").strip()  # try: flash_3, flash, sdpa
+ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "false").lower() == "true"
 
+HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
+if HF_TOKEN:
+    login(token=HF_TOKEN)
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 warnings.filterwarnings("ignore")
 logging.getLogger("transformers").setLevel(logging.ERROR)
 
-
-    "1024": [
-        "1024x1024 ( 1:1 )",
-        "1152x896 ( 9:7 )",
-        "896x1152 ( 7:9 )",
-        "1152x864 ( 4:3 )",
-        "864x1152 ( 3:4 )",
-        "1248x832 ( 3:2 )",
-        "832x1248 ( 2:3 )",
-        "1280x720 ( 16:9 )",
-        "720x1280 ( 9:16 )",
-        "1344x576 ( 21:9 )",
-        "576x1344 ( 9:21 )",
-    ],
-    "1280": [
-        "1280x1280 ( 1:1 )",
-        "1440x1120 ( 9:7 )",
-        "1120x1440 ( 7:9 )",
-        "1472x1104 ( 4:3 )",
-        "1104x1472 ( 3:4 )",
-        "1536x1024 ( 3:2 )",
-        "1024x1536 ( 2:3 )",
-        "1536x864 ( 16:9 )",
-        "864x1536 ( 9:16 )",
-        "1680x720 ( 21:9 )",
-        "720x1680 ( 9:21 )",
-    ],
-    "1536": [
-        "1536x1536 ( 1:1 )",
-        "1728x1344 ( 9:7 )",
-        "1344x1728 ( 7:9 )",
-        "1728x1296 ( 4:3 )",
-        "1296x1728 ( 3:4 )",
-        "1872x1248 ( 3:2 )",
-        "1248x1872 ( 2:3 )",
-        "2048x1152 ( 16:9 )",
-        "1152x2048 ( 9:16 )",
-        "2016x864 ( 21:9 )",
-        "864x2016 ( 9:21 )",
-    ],
-}
+MAX_SEED = np.iinfo(np.int32).max
 
-
-
-
+# ============================================================
+# Device & dtype
+# ============================================================
 
-
-
-        "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
-    ],
-    [
-        '''A vertical digital illustration depicting a serene and majestic Chinese landscape, rendered in a style reminiscent of traditional Shanshui painting but with a modern, clean aesthetic. The scene is dominated by towering, steep cliffs in various shades of blue and teal, which frame a central valley. In the distance, layers of mountains fade into a light blue and white mist, creating a strong sense of atmospheric perspective and depth. A calm, turquoise river flows through the center of the composition, with a small, traditional Chinese boat, possibly a sampan, navigating its waters. The boat has a bright yellow canopy and a red hull, and it leaves a gentle wake behind it. It carries several indistinct figures of people. Sparse vegetation, including green trees and some bare-branched trees, clings to the rocky ledges and peaks. The overall lighting is soft and diffused, casting a tranquil glow over the entire scene. Centered in the image is overlaid text. At the top of the text block is a small, red, circular seal-like logo containing stylized characters. Below it, in a smaller, black, sans-serif font, are the words 'Zao-Xiang * East Beauty & West Fashion * Z-Image'. Directly beneath this, in a larger, elegant black serif font, is the word 'SHOW & SHARE CREATIVITY WITH THE WORLD'. Among them, there are "SHOW & SHARE", "CREATIVITY", and "WITH THE WORLD"'''
-    ],
-
-]
+cuda_available = torch.cuda.is_available()
+device = torch.device("cuda" if cuda_available else "cpu")
 
+if cuda_available and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
+    dtype = torch.bfloat16
+elif cuda_available:
+    dtype = torch.float16
+else:
+    dtype = torch.float32
 
-
-
-    if match:
-        return int(match.group(1)), int(match.group(2))
-    return 1024, 1024
+# A conservative max for most Spaces GPUs. Increase if you know you have headroom.
+MAX_IMAGE_SIZE = 1536 if cuda_available else 768
 
+fallback_msg = ""
+if not cuda_available:
+    fallback_msg = "GPU unavailable. Running in CPU fallback mode (slow)."
 
-
-
+# ============================================================
+# Load pipelines
+# ============================================================
 
-
+pipe_txt2img = None
+pipe_img2img = None
+model_loaded = False
+load_error = None
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    text_encoder = AutoModelForCausalLM.from_pretrained(
-        os.path.join(model_path, "text_encoder"),
-        torch_dtype=torch.bfloat16,
-        device_map="cuda",
-    ).eval()
-
-    tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
+def _try_load_with_from_pretrained():
+    """
+    Preferred path: load everything via Diffusers from_pretrained.
+    Works when the repo is structured as a standard Diffusers pipeline repo.
+    """
+    kwargs = {
+        "torch_dtype": dtype,
+        "use_safetensors": True,
+    }
+    if HF_TOKEN:
+        kwargs["token"] = HF_TOKEN
+
+    p_txt = ZImagePipeline.from_pretrained(MODEL_PATH, **kwargs)
+    p_img = ZImageImg2ImgPipeline(**p_txt.components)
+    return p_txt, p_img
+
+def _fallback_manual_load():
+    """
+    Fallback path: load subfolders manually, similar to many Z-Image demos.
+    Works when MODEL_PATH points to a repo with subfolders:
+        vae/, transformer/, text_encoder/, tokenizer/
+    """
+    use_auth_token = HF_TOKEN if HF_TOKEN else True
 
+    vae = AutoencoderKL.from_pretrained(
+        MODEL_PATH,
+        subfolder="vae",
+        torch_dtype=dtype,
+        use_auth_token=use_auth_token,
+    )
+    text_encoder = AutoModelForCausalLM.from_pretrained(
+        MODEL_PATH,
+        subfolder="text_encoder",
+        torch_dtype=dtype,
+        use_auth_token=use_auth_token,
+    ).eval()
+    tokenizer = AutoTokenizer.from_pretrained(
+        MODEL_PATH,
+        subfolder="tokenizer",
+        use_auth_token=use_auth_token,
+    )
     tokenizer.padding_side = "left"
 
-
-
-    torch._inductor.config.conv_1x1_as_mm = True
-    torch._inductor.config.coordinate_descent_tuning = True
-    torch._inductor.config.epilogue_fusion = False
-    torch._inductor.config.coordinate_descent_check_all_directions = True
-    torch._inductor.config.max_autotune_gemm = True
-    torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN"
-    torch._inductor.config.triton.cudagraphs = False
-
-    pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None)
-
-    if enable_compile:
-        pipe.vae.disable_tiling()
-
-    if not os.path.exists(model_path):
-        transformer = ZImageTransformer2DModel.from_pretrained(
-            f"{model_path}", subfolder="transformer", use_auth_token=use_auth_token
-        ).to("cuda", torch.bfloat16)
-    else:
-        transformer = ZImageTransformer2DModel.from_pretrained(os.path.join(model_path, "transformer")).to(
-            "cuda", torch.bfloat16
-        )
-
-    pipe.transformer = transformer
-    pipe.transformer.set_attention_backend(attention_backend)
-
-    if enable_compile:
-        print("Compiling transformer...")
-        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)
-
-    pipe.to("cuda", torch.bfloat16)
-
-    from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
-    from transformers import CLIPImageProcessor
-
-    safety_model_id = "CompVis/stable-diffusion-safety-checker"
-    safety_feature_extractor = CLIPImageProcessor.from_pretrained(safety_model_id)
-    safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, torch_dtype=torch.float16).to("cuda")
-
-    pipe.safety_feature_extractor = safety_feature_extractor
-    pipe.safety_checker = safety_checker
-    return pipe
-
-
-def generate_image(
-    pipe,
-    prompt,
-    resolution="1024x1024",
-    seed=42,
-    guidance_scale=5.0,
-    num_inference_steps=50,
-    shift=3.0,
-    max_sequence_length=512,
-    progress=gr.Progress(track_tqdm=True),
-):
-    width, height = get_resolution(resolution)
-
-    generator = torch.Generator("cuda").manual_seed(seed)
-
-    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift)
-    pipe.scheduler = scheduler
-
-    image = pipe(
-        prompt=prompt,
-        height=height,
-        width=width,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        generator=generator,
-        max_sequence_length=max_sequence_length,
-    ).images[0]
-
-    return image
-
-
-def warmup_model(pipe, resolutions):
-    print("Starting warmup phase...")
-
-    dummy_prompt = "warmup"
-
-    for res_str in resolutions:
-        print(f"Warming up for resolution: {res_str}")
-        try:
-            for i in range(3):
-                generate_image(
-                    pipe,
-                    prompt=dummy_prompt,
-                    resolution=res_str,
-                    num_inference_steps=9,
-                    guidance_scale=0.0,
-                    seed=42 + i,
-                )
-        except Exception as e:
-            print(f"Warmup failed for {res_str}: {e}")
-
-    print("Warmup completed.")
-
-
-# ==================== Prompt Expander ====================
-@dataclass
-class PromptOutput:
-    status: bool
-    prompt: str
-    seed: int
-    system_prompt: str
-    message: str
-
-
-class PromptExpander:
-    def __init__(self, backend="api", **kwargs):
-        self.backend = backend
-
-    def decide_system_prompt(self, template_name=None):
-        return prompt_template
-
-
-class APIPromptExpander(PromptExpander):
-    def __init__(self, api_config=None, **kwargs):
-        super().__init__(backend="api", **kwargs)
-        self.api_config = api_config or {}
-        self.client = self._init_api_client()
-
-    def _init_api_client(self):
-        try:
-            from openai import OpenAI
-
-            api_key = self.api_config.get("api_key") or DASHSCOPE_API_KEY
-            base_url = self.api_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
-
-            if not api_key:
-                print("Warning: DASHSCOPE_API_KEY not found.")
-                return None
-
-            return OpenAI(api_key=api_key, base_url=base_url)
-        except ImportError:
-            print("Please install openai: pip install openai")
-            return None
-        except Exception as e:
-            print(f"Failed to initialize API client: {e}")
-            return None
-
-    def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):
-        return self.extend(prompt, system_prompt, seed, **kwargs)
-
-    def extend(self, prompt, system_prompt=None, seed=-1, **kwargs):
-        if self.client is None:
-            return PromptOutput(False, "", seed, system_prompt, "API client not initialized")
-
-        if system_prompt is None:
-            system_prompt = self.decide_system_prompt()
-
-        if "{prompt}" in system_prompt:
-            system_prompt = system_prompt.format(prompt=prompt)
-            prompt = " "
-
-        try:
-            model = self.api_config.get("model", "qwen3-max-preview")
-            response = self.client.chat.completions.create(
-                model=model,
-                messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
-                temperature=0.7,
-                top_p=0.8,
-            )
-
-            content = response.choices[0].message.content
-            json_start = content.find("```json")
-            if json_start != -1:
-                json_end = content.find("```", json_start + 7)
-                try:
-                    json_str = content[json_start + 7 : json_end].strip()
-                    data = json.loads(json_str)
-                    expanded_prompt = data.get("revised_prompt", content)
-                except:
-                    expanded_prompt = content
-            else:
-                expanded_prompt = content
-
-            return PromptOutput(
-                status=True, prompt=expanded_prompt, seed=seed, system_prompt=system_prompt, message=content
-            )
-        except Exception as e:
-            return PromptOutput(False, "", seed, system_prompt, str(e))
-
-
-def create_prompt_expander(backend="api", **kwargs):
-    if backend == "api":
-        return APIPromptExpander(**kwargs)
-    raise ValueError("Only 'api' backend is supported.")
-
-
-pipe = None
-prompt_expander = None
+    # ZImageTransformer2DModel lives inside diffusers; importing lazily avoids import issues on older versions.
+    from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
 
-
-
+    transformer = ZImageTransformer2DModel.from_pretrained(
+        MODEL_PATH,
+        subfolder="transformer",
+        torch_dtype=dtype,
+        use_auth_token=use_auth_token,
+    )
 
-
-
+    p_txt = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer)
+    p_img = ZImageImg2ImgPipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer)
+    return p_txt, p_img
 
+try:
+    pipe_txt2img, pipe_img2img = _try_load_with_from_pretrained()
+    model_loaded = True
+except Exception as e1:
     try:
-
-
-
-
-
-        for cat in RES_CHOICES.values():
-            all_resolutions.extend(cat)
-        warmup_model(pipe, all_resolutions)
+        pipe_txt2img, pipe_img2img = _fallback_manual_load()
+        model_loaded = True
+    except Exception as e2:
+        load_error = f"from_pretrained error: {repr(e1)}\nmanual_load error: {repr(e2)}"
+        model_loaded = False
 
-
-
-
+if model_loaded:
+    pipe_txt2img = pipe_txt2img.to(device)
+    pipe_img2img = pipe_img2img.to(device)
 
+    # Try attention backend (best-effort)
    try:
-
-
-
-
-        prompt_expander = None
-
-
-def prompt_enhance(prompt, enable_enhance):
-    if not enable_enhance or not prompt_expander:
-        return prompt, "Enhancement disabled or not available."
-
-    if not prompt.strip():
-        return "", "Please enter a prompt."
-
-    try:
-        result = prompt_expander(prompt)
-        if result.status:
-            return result.prompt, result.message
-        else:
-            return prompt, f"Enhancement failed: {result.message}"
-    except Exception as e:
-        return prompt, f"Error: {str(e)}"
-
-
-@spaces.GPU
-def generate(
-    prompt,
-    resolution="1024x1024 ( 1:1 )",
-    seed=42,
-    steps=9,
-    shift=3.0,
-    random_seed=True,
-    gallery_images=None,
-    enhance=False,
-    progress=gr.Progress(track_tqdm=True),
-):
-    """
-    Generate an image using the Z-Image model based on the provided prompt and settings.
-
-    This function is triggered when the user clicks the "Generate" button. It processes
-    the input prompt (optionally enhancing it), configures generation parameters, and
-    produces an image using the Z-Image diffusion transformer pipeline.
-
-    Args:
-        prompt (str): Text prompt describing the desired image content
-        resolution (str): Output resolution in format "WIDTHxHEIGHT ( RATIO )" (e.g., "1024x1024 ( 1:1 )")
-        seed (int): Seed for reproducible generation
-        steps (int): Number of inference steps for the diffusion process
-        shift (float): Time shift parameter for the flow matching scheduler
-        random_seed (bool): Whether to generate a new random seed, if True will ignore the seed input
-        gallery_images (list): List of previously generated images to append to (only needed for the Gradio UI)
-        enhance (bool): This was Whether to enhance the prompt (DISABLED! Do not use)
-        progress (gr.Progress): Gradio progress tracker for displaying generation progress (only needed for the Gradio UI)
-
-    Returns:
-        tuple: (gallery_images, seed_str, seed_int)
-            - gallery_images: Updated list of generated images including the new image
-            - seed_str: String representation of the seed used for generation
-            - seed_int: Integer representation of the seed used for generation
-    """
-
-    if random_seed:
-        new_seed = random.randint(1, 1000000)
-    else:
-        new_seed = seed if seed != -1 else random.randint(1, 1000000)
-
-    class UnsafeContentError(Exception):
+        if hasattr(pipe_txt2img, "transformer") and hasattr(pipe_txt2img.transformer, "set_attention_backend"):
+            pipe_txt2img.transformer.set_attention_backend(ATTENTION_BACKEND)
+            pipe_img2img.transformer.set_attention_backend(ATTENTION_BACKEND)
+    except Exception:
        pass
 
-
-
-    raise gr.Error("Model not loaded.")
-
-    has_unsafe_concept = is_unsafe_prompt(
-        pipe.text_encoder,
-        pipe.tokenizer,
-        system_prompt=UNSAFE_PROMPT_CHECK,
-        user_prompt=prompt,
-        max_new_token=UNSAFE_MAX_NEW_TOKEN,
-    )
-    if has_unsafe_concept:
-        raise UnsafeContentError("Input unsafe")
-
-    final_prompt = prompt
-
-    if enhance:
-        final_prompt, _ = prompt_enhance(prompt, True)
-        print(f"Enhanced prompt: {final_prompt}")
-
+    # Optional compile (best-effort, can break on some setups)
+    if ENABLE_COMPILE and device.type == "cuda":
        try:
-
-
-
-
-            image = generate_image(
-                pipe=pipe,
-                prompt=final_prompt,
-                resolution=resolution_str,
-                seed=new_seed,
-                guidance_scale=0.0,
-                num_inference_steps=int(steps + 1),
-                shift=shift,
-            )
-
-            safety_checker_input = pipe.safety_feature_extractor([image], return_tensors="pt").pixel_values.cuda()
-            _, has_nsfw_concept = pipe.safety_checker(images=[torch.zeros(1)], clip_input=safety_checker_input)
-            has_nsfw_concept = has_nsfw_concept[0]
-            if has_nsfw_concept:
-                print("input unsafe")
+            pipe_txt2img.transformer = torch.compile(pipe_txt2img.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)
+            pipe_img2img.transformer = pipe_txt2img.transformer
+        except Exception:
+            pass
 
-
-
+    # Disable diffusers progress bars
+    try:
+        pipe_txt2img.set_progress_bar_config(disable=True)
+        pipe_img2img.set_progress_bar_config(disable=True)
+    except Exception:
+        pass
 
-
-
-
-
+# ============================================================
+# Utility: error image
+# ============================================================
+
+def make_error_image(w, h):
+    return Image.new("RGB", (w, h), (18, 18, 22))
+
+def _prep_init_image(init_image, width, height):
+    if init_image is None:
+        return None
+    if not isinstance(init_image, Image.Image):
+        return None
+    init_image = init_image.convert("RGB")
+    if init_image.size != (width, height):
+        init_image = init_image.resize((width, height), Image.LANCZOS)
+    return init_image
+
+# ============================================================
+# Inference
+# ============================================================
+
+def _infer_impl(
+    prompt: str,
+    negative_prompt: str,
+    seed: int,
+    randomize_seed: bool,
+    width: int,
+    height: int,
+    guidance_scale: float,
+    num_inference_steps: int,
+    shift: float,
+    max_sequence_length: int,
+    init_image,
+    strength: float,
+):
+    width = int(width)
+    height = int(height)
+    seed = int(seed)
 
-
+    if not model_loaded:
+        return make_error_image(width, height), f"Model load failed:\n\n{load_error}"
 
+    prompt = (prompt or "").strip()
+    if not prompt:
+        return make_error_image(width, height), "Error: Prompt is empty."
 
-
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
 
-
+    init_image = _prep_init_image(init_image, width, height)
 
-
-
+    generator = torch.Generator(device=device)
+    generator = generator.manual_seed(seed)
 
-
-
-
+    status = f"Seed: {seed}"
+    if fallback_msg:
+        status += f" | {fallback_msg}"
 
-    #
+    # Set scheduler per-run because shift can change
+    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=float(shift))
+    pipe_txt2img.scheduler = scheduler
+    pipe_img2img.scheduler = scheduler
 
-
+    try:
+        common_kwargs = dict(
+            prompt=prompt,
+            negative_prompt=(negative_prompt or "").strip() if (guidance_scale and float(guidance_scale) > 1.0) else None,
+            guidance_scale=float(guidance_scale),
+            num_inference_steps=int(num_inference_steps),
+            generator=generator,
+            height=height,
+            width=width,
+            max_sequence_length=int(max_sequence_length),
+        )
 
-
+        with torch.inference_mode():
+            if device.type == "cuda":
+                with torch.autocast("cuda", dtype=dtype):
+                    if init_image is not None:
+                        out = pipe_img2img(
+                            image=init_image,
+                            strength=float(strength),
+                            **common_kwargs,
+                        )
+                    else:
+                        out = pipe_txt2img(**common_kwargs)
+            else:
+                if init_image is not None:
+                    out = pipe_img2img(
+                        image=init_image,
+                        strength=float(strength),
+                        **common_kwargs,
+                    )
+                else:
+                    out = pipe_txt2img(**common_kwargs)
+
+        image = out.images[0]
+        return image, status
 
-
-
+    except Exception as e:
+        return make_error_image(width, height), f"Error: {type(e).__name__}: {e}"
+
+    finally:
+        gc.collect()
+        if device.type == "cuda":
+            torch.cuda.empty_cache()
+
+# IMPORTANT: decorator must be explicit
+if SPACES_AVAILABLE:
+    @spaces.GPU
+    def infer(*args, **kwargs):
+        return _infer_impl(*args, **kwargs)
+else:
+    def infer(*args, **kwargs):
+        return _infer_impl(*args, **kwargs)
+
+# ============================================================
+# UI
+# ============================================================
+
+CSS = """
+body {
+    background: #000;
+    color: #fff;
+}
+"""
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    # PE enhancement button (Temporarily disabled)
-    # enhance_btn.click(
-    #     prompt_enhance,
-    #     inputs=[prompt_input, enable_enhance],
-    #     outputs=[prompt_input, final_prompt_output]
-    # )
-
-    generate_btn.click(
-        generate,
-        inputs=[prompt_input, resolution, seed, steps, shift, random_seed, output_gallery],
-        outputs=[output_gallery, used_seed, seed],
-        api_visibility="public",
+with gr.Blocks(title="Z-Image txt2img + img2img") as demo:
+    gr.HTML(f"<style>{CSS}</style>")
+
+    if fallback_msg:
+        gr.Markdown(f"**{fallback_msg}**")
+
+    if not model_loaded:
+        gr.Markdown(f"⚠️ Model failed to load:\n\n{load_error}")
+
+    gr.Markdown("## Z-Image Generator (txt2img + img2img)")
+
+    prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Describe what you want...")
+    init_image = gr.Image(label="Initial image (optional)", type="pil")
+
+    run_button = gr.Button("Generate")
+    result = gr.Image(label="Result")
+    status = gr.Markdown("")
+
+    with gr.Accordion("Advanced Settings", open=False):
+        negative_prompt = gr.Textbox(label="Negative prompt (only used if Guidance > 1)")
+        seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
+        randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
+
+        width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=1024, label="Width")
+        height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=1024, label="Height")
+
+        guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=0.0, label="Guidance scale")
+        num_inference_steps = gr.Slider(1, 50, step=1, value=8, label="Steps")
+        shift = gr.Slider(1.0, 10.0, step=0.1, value=3.0, label="Time shift")
+
+        max_sequence_length = gr.Slider(64, 512, step=64, value=512, label="Max sequence length")
+
+        strength = gr.Slider(0.0, 1.0, step=0.05, value=0.6, label="Image strength (img2img)")
+
+    run_button.click(
+        fn=infer,
+        inputs=[
+            prompt,
+            negative_prompt,
+            seed,
+            randomize_seed,
+            width,
+            height,
+            guidance_scale,
+            num_inference_steps,
+            shift,
+            max_sequence_length,
+            init_image,
+            strength,
+        ],
+        outputs=[result, status],
    )
 
-css = """
-.fillable{max-width: 1230px !important}
-"""
 if __name__ == "__main__":
-
+    # Keep the same launch feel as your first script
+    demo.queue().launch(ssr_mode=False)
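
For anyone exercising the updated script outside the Space: every knob is read from environment variables at import time, and the model load plus the Gradio Blocks construction run at module level, so simply importing the module is a useful smoke test. Below is a minimal sketch of such a test; it is hypothetical and not part of the commit, and the sdpa backend is suggested only because the script's own comment lists it as an alternative to flash_3, which assumes FlashAttention-3 kernels most local machines lack.

# Hypothetical local smoke test for the updated app.py -- not part of the commit.
# The script reads its configuration from env vars at import time, so set them first.
import os

os.environ["MODEL_PATH"] = "telcom/dee-z-image"  # same default the script uses
os.environ["ATTENTION_BACKEND"] = "sdpa"         # assumption: safer than flash_3 off-Space
os.environ["ENABLE_COMPILE"] = "false"           # the torch.compile path is best-effort anyway
# os.environ["HF_TOKEN"] = "hf_..."              # only needed for gated/private repos

import app  # module-level code loads the pipelines and builds the Blocks UI

print("model_loaded:", app.model_loaded)
print("load_error:", app.load_error)

If model_loaded prints False, load_error carries both the from_pretrained exception and the manual-load exception, which is the same string the in-UI warning displays.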
|