Spaces:
Running on Zero
Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,8 +14,6 @@ import gradio as gr
|
|
| 14 |
import torch
|
| 15 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 16 |
|
| 17 |
-
from prompt_check import is_unsafe_prompt
|
| 18 |
-
|
| 19 |
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 20 |
|
| 21 |
from diffusers import ZImagePipeline
|
|
@@ -28,10 +26,8 @@ MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
|
|
| 28 |
ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
|
| 29 |
ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
|
| 30 |
ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
|
| 31 |
-
|
| 32 |
-
DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
|
| 33 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 34 |
-
UNSAFE_PROMPT_CHECK = os.environ.get("UNSAFE_PROMPT_CHECK")
|
| 35 |
# =============================================================================
|
| 36 |
|
| 37 |
|
|
@@ -280,11 +276,11 @@ class APIPromptExpander(PromptExpander):
|
|
| 280 |
try:
|
| 281 |
from openai import OpenAI
|
| 282 |
|
| 283 |
-
api_key = self.api_config.get("api_key") or
|
| 284 |
-
base_url = self.api_config.get("base_url", "https://
|
| 285 |
|
| 286 |
if not api_key:
|
| 287 |
-
print("Warning:
|
| 288 |
return None
|
| 289 |
|
| 290 |
return OpenAI(api_key=api_key, base_url=base_url)
|
|
@@ -310,12 +306,10 @@ class APIPromptExpander(PromptExpander):
|
|
| 310 |
prompt = " "
|
| 311 |
|
| 312 |
try:
|
| 313 |
-
model = self.api_config.get("model", "
|
| 314 |
response = self.client.chat.completions.create(
|
| 315 |
model=model,
|
| 316 |
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
|
| 317 |
-
temperature=0.7,
|
| 318 |
-
top_p=0.8,
|
| 319 |
)
|
| 320 |
|
| 321 |
content = response.choices[0].message.content
|
|
@@ -331,6 +325,8 @@ class APIPromptExpander(PromptExpander):
|
|
| 331 |
else:
|
| 332 |
expanded_prompt = content
|
| 333 |
|
|
|
|
|
|
|
| 334 |
return PromptOutput(
|
| 335 |
status=True, prompt=expanded_prompt, seed=seed, system_prompt=system_prompt, message=content
|
| 336 |
)
|
|
@@ -366,7 +362,7 @@ def init_app():
|
|
| 366 |
pipe = None
|
| 367 |
|
| 368 |
try:
|
| 369 |
-
prompt_expander = create_prompt_expander(backend="api", api_config={"model": "
|
| 370 |
print("Prompt expander initialized.")
|
| 371 |
except Exception as e:
|
| 372 |
print(f"Error initializing prompt expander: {e}")
|
|
@@ -432,52 +428,29 @@ def generate(
|
|
| 432 |
else:
|
| 433 |
new_seed = seed if seed != -1 else random.randint(1, 1000000)
|
| 434 |
|
| 435 |
-
class UnsafeContentError(Exception):
|
| 436 |
-
pass
|
| 437 |
-
|
| 438 |
-
try:
|
| 439 |
-
if pipe is None:
|
| 440 |
-
raise gr.Error("Model not loaded.")
|
| 441 |
-
|
| 442 |
-
has_unsafe_concept = is_unsafe_prompt(
|
| 443 |
-
pipe.text_encoder,
|
| 444 |
-
pipe.tokenizer,
|
| 445 |
-
system_prompt=UNSAFE_PROMPT_CHECK,
|
| 446 |
-
user_prompt=prompt,
|
| 447 |
-
max_new_token=UNSAFE_MAX_NEW_TOKEN,
|
| 448 |
-
)
|
| 449 |
-
if has_unsafe_concept:
|
| 450 |
-
raise UnsafeContentError("Input unsafe")
|
| 451 |
-
|
| 452 |
-
final_prompt = prompt
|
| 453 |
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
print(f"Enhanced prompt: {final_prompt}")
|
| 457 |
|
| 458 |
-
|
| 459 |
-
resolution_str = resolution.split(" ")[0]
|
| 460 |
-
except:
|
| 461 |
-
resolution_str = "1024x1024"
|
| 462 |
-
|
| 463 |
-
image = generate_image(
|
| 464 |
-
pipe=pipe,
|
| 465 |
-
prompt=final_prompt,
|
| 466 |
-
resolution=resolution_str,
|
| 467 |
-
seed=new_seed,
|
| 468 |
-
guidance_scale=0.0,
|
| 469 |
-
num_inference_steps=int(steps + 1),
|
| 470 |
-
shift=shift,
|
| 471 |
-
)
|
| 472 |
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
has_nsfw_concept = has_nsfw_concept[0]
|
| 476 |
-
if has_nsfw_concept:
|
| 477 |
-
raise UnsafeContentError("input unsafe")
|
| 478 |
|
| 479 |
-
|
| 480 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
|
| 482 |
if gallery_images is None:
|
| 483 |
gallery_images = []
|
|
@@ -491,8 +464,8 @@ init_app()
|
|
| 491 |
|
| 492 |
# ==================== AoTI (Ahead of Time Inductor compilation) ====================
|
| 493 |
|
| 494 |
-
pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
|
| 495 |
-
spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
|
| 496 |
|
| 497 |
with gr.Blocks(title="Z-Image Demo") as demo:
|
| 498 |
gr.Markdown(
|
|
|
|
| 14 |
import torch
|
| 15 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 16 |
|
|
|
|
|
|
|
| 17 |
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 18 |
|
| 19 |
from diffusers import ZImagePipeline
|
|
|
|
| 26 |
ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
|
| 27 |
ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
|
| 28 |
ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
|
| 29 |
+
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY")
|
|
|
|
| 30 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
|
|
| 31 |
# =============================================================================
|
| 32 |
|
| 33 |
|
|
|
|
| 276 |
try:
|
| 277 |
from openai import OpenAI
|
| 278 |
|
| 279 |
+
api_key = self.api_config.get("api_key") or OPENROUTER_API_KEY
|
| 280 |
+
base_url = self.api_config.get("base_url", "https://openrouter.ai/api/v1")
|
| 281 |
|
| 282 |
if not api_key:
|
| 283 |
+
print("Warning: OPENROUTER_API_KEY not found.")
|
| 284 |
return None
|
| 285 |
|
| 286 |
return OpenAI(api_key=api_key, base_url=base_url)
|
|
|
|
| 306 |
prompt = " "
|
| 307 |
|
| 308 |
try:
|
| 309 |
+
model = self.api_config.get("model", "google/gemini-2.5-flash")
|
| 310 |
response = self.client.chat.completions.create(
|
| 311 |
model=model,
|
| 312 |
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
|
|
|
|
|
|
|
| 313 |
)
|
| 314 |
|
| 315 |
content = response.choices[0].message.content
|
|
|
|
| 325 |
else:
|
| 326 |
expanded_prompt = content
|
| 327 |
|
| 328 |
+
print(f"Original prompt: {prompt}\nFinal prompt: {expanded_prompt}")
|
| 329 |
+
|
| 330 |
return PromptOutput(
|
| 331 |
status=True, prompt=expanded_prompt, seed=seed, system_prompt=system_prompt, message=content
|
| 332 |
)
|
|
|
|
| 362 |
pipe = None
|
| 363 |
|
| 364 |
try:
|
| 365 |
+
prompt_expander = create_prompt_expander(backend="api", api_config={"model": "google/gemini-2.5-flash"})
|
| 366 |
print("Prompt expander initialized.")
|
| 367 |
except Exception as e:
|
| 368 |
print(f"Error initializing prompt expander: {e}")
|
|
|
|
| 428 |
else:
|
| 429 |
new_seed = seed if seed != -1 else random.randint(1, 1000000)
|
| 430 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
|
| 432 |
+
if pipe is None:
|
| 433 |
+
raise gr.Error("Model not loaded.")
|
|
|
|
| 434 |
|
| 435 |
+
final_prompt = prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
|
| 437 |
+
if enhance:
|
| 438 |
+
final_prompt, msg = prompt_enhance(prompt, True)
|
|
|
|
|
|
|
|
|
|
| 439 |
|
| 440 |
+
try:
|
| 441 |
+
resolution_str = resolution.split(" ")[0]
|
| 442 |
+
except:
|
| 443 |
+
resolution_str = "1024x1024"
|
| 444 |
+
|
| 445 |
+
image = generate_image(
|
| 446 |
+
pipe=pipe,
|
| 447 |
+
prompt=final_prompt,
|
| 448 |
+
resolution=resolution_str,
|
| 449 |
+
seed=new_seed,
|
| 450 |
+
guidance_scale=0.0,
|
| 451 |
+
num_inference_steps=int(steps + 1),
|
| 452 |
+
shift=shift,
|
| 453 |
+
)
|
| 454 |
|
| 455 |
if gallery_images is None:
|
| 456 |
gallery_images = []
|
|
|
|
| 464 |
|
| 465 |
# ==================== AoTI (Ahead of Time Inductor compilation) ====================
|
| 466 |
|
| 467 |
+
#pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
|
| 468 |
+
#spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
|
| 469 |
|
| 470 |
with gr.Blocks(title="Z-Image Demo") as demo:
|
| 471 |
gr.Markdown(
|