Azure99 committed on
Commit
7e3d1b9
·
verified ·
1 Parent(s): aa76f48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -56
app.py CHANGED
@@ -14,8 +14,6 @@ import gradio as gr
14
  import torch
15
  from transformers import AutoModelForCausalLM, AutoTokenizer
16
 
17
- from prompt_check import is_unsafe_prompt
18
-
19
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
20
 
21
  from diffusers import ZImagePipeline
@@ -28,10 +26,8 @@ MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
28
  ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
29
  ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
30
  ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
31
- UNSAFE_MAX_NEW_TOKEN = int(os.environ.get("UNSAFE_MAX_NEW_TOKEN", "10"))
32
- DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
33
  HF_TOKEN = os.environ.get("HF_TOKEN")
34
- UNSAFE_PROMPT_CHECK = os.environ.get("UNSAFE_PROMPT_CHECK")
35
  # =============================================================================
36
 
37
 
@@ -280,11 +276,11 @@ class APIPromptExpander(PromptExpander):
280
  try:
281
  from openai import OpenAI
282
 
283
- api_key = self.api_config.get("api_key") or DASHSCOPE_API_KEY
284
- base_url = self.api_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
285
 
286
  if not api_key:
287
- print("Warning: DASHSCOPE_API_KEY not found.")
288
  return None
289
 
290
  return OpenAI(api_key=api_key, base_url=base_url)
@@ -310,12 +306,10 @@ class APIPromptExpander(PromptExpander):
310
  prompt = " "
311
 
312
  try:
313
- model = self.api_config.get("model", "qwen3-max-preview")
314
  response = self.client.chat.completions.create(
315
  model=model,
316
  messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
317
- temperature=0.7,
318
- top_p=0.8,
319
  )
320
 
321
  content = response.choices[0].message.content
@@ -331,6 +325,8 @@ class APIPromptExpander(PromptExpander):
331
  else:
332
  expanded_prompt = content
333
 
 
 
334
  return PromptOutput(
335
  status=True, prompt=expanded_prompt, seed=seed, system_prompt=system_prompt, message=content
336
  )
@@ -366,7 +362,7 @@ def init_app():
366
  pipe = None
367
 
368
  try:
369
- prompt_expander = create_prompt_expander(backend="api", api_config={"model": "qwen3-max-preview"})
370
  print("Prompt expander initialized.")
371
  except Exception as e:
372
  print(f"Error initializing prompt expander: {e}")
@@ -432,52 +428,29 @@ def generate(
432
  else:
433
  new_seed = seed if seed != -1 else random.randint(1, 1000000)
434
 
435
- class UnsafeContentError(Exception):
436
- pass
437
-
438
- try:
439
- if pipe is None:
440
- raise gr.Error("Model not loaded.")
441
-
442
- has_unsafe_concept = is_unsafe_prompt(
443
- pipe.text_encoder,
444
- pipe.tokenizer,
445
- system_prompt=UNSAFE_PROMPT_CHECK,
446
- user_prompt=prompt,
447
- max_new_token=UNSAFE_MAX_NEW_TOKEN,
448
- )
449
- if has_unsafe_concept:
450
- raise UnsafeContentError("Input unsafe")
451
-
452
- final_prompt = prompt
453
 
454
- if enhance:
455
- final_prompt, _ = prompt_enhance(prompt, True)
456
- print(f"Enhanced prompt: {final_prompt}")
457
 
458
- try:
459
- resolution_str = resolution.split(" ")[0]
460
- except:
461
- resolution_str = "1024x1024"
462
-
463
- image = generate_image(
464
- pipe=pipe,
465
- prompt=final_prompt,
466
- resolution=resolution_str,
467
- seed=new_seed,
468
- guidance_scale=0.0,
469
- num_inference_steps=int(steps + 1),
470
- shift=shift,
471
- )
472
 
473
- safety_checker_input = pipe.safety_feature_extractor([image], return_tensors="pt").pixel_values.cuda()
474
- _, has_nsfw_concept = pipe.safety_checker(images=[torch.zeros(1)], clip_input=safety_checker_input)
475
- has_nsfw_concept = has_nsfw_concept[0]
476
- if has_nsfw_concept:
477
- raise UnsafeContentError("input unsafe")
478
 
479
- except UnsafeContentError:
480
- image = Image.open("nsfw.png")
 
 
 
 
 
 
 
 
 
 
 
 
481
 
482
  if gallery_images is None:
483
  gallery_images = []
@@ -491,8 +464,8 @@ init_app()
491
 
492
  # ==================== AoTI (Ahead of Time Inductor compilation) ====================
493
 
494
- pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
495
- spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
496
 
497
  with gr.Blocks(title="Z-Image Demo") as demo:
498
  gr.Markdown(
 
14
  import torch
15
  from transformers import AutoModelForCausalLM, AutoTokenizer
16
 
 
 
17
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
18
 
19
  from diffusers import ZImagePipeline
 
26
  ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
27
  ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
28
  ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
29
+ OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY")
 
30
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
31
  # =============================================================================
32
 
33
 
 
276
  try:
277
  from openai import OpenAI
278
 
279
+ api_key = self.api_config.get("api_key") or OPENROUTER_API_KEY
280
+ base_url = self.api_config.get("base_url", "https://openrouter.ai/api/v1")
281
 
282
  if not api_key:
283
+ print("Warning: OPENROUTER_API_KEY not found.")
284
  return None
285
 
286
  return OpenAI(api_key=api_key, base_url=base_url)
 
306
  prompt = " "
307
 
308
  try:
309
+ model = self.api_config.get("model", "google/gemini-2.5-flash")
310
  response = self.client.chat.completions.create(
311
  model=model,
312
  messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
 
 
313
  )
314
 
315
  content = response.choices[0].message.content
 
325
  else:
326
  expanded_prompt = content
327
 
328
+ print(f"Original prompt: {prompt}\nFinal prompt: {expanded_prompt}")
329
+
330
  return PromptOutput(
331
  status=True, prompt=expanded_prompt, seed=seed, system_prompt=system_prompt, message=content
332
  )
 
362
  pipe = None
363
 
364
  try:
365
+ prompt_expander = create_prompt_expander(backend="api", api_config={"model": "google/gemini-2.5-flash"})
366
  print("Prompt expander initialized.")
367
  except Exception as e:
368
  print(f"Error initializing prompt expander: {e}")
 
428
  else:
429
  new_seed = seed if seed != -1 else random.randint(1, 1000000)
430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
 
432
+ if pipe is None:
433
+ raise gr.Error("Model not loaded.")
 
434
 
435
+ final_prompt = prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
436
 
437
+ if enhance:
438
+ final_prompt, msg = prompt_enhance(prompt, True)
 
 
 
439
 
440
+ try:
441
+ resolution_str = resolution.split(" ")[0]
442
+ except:
443
+ resolution_str = "1024x1024"
444
+
445
+ image = generate_image(
446
+ pipe=pipe,
447
+ prompt=final_prompt,
448
+ resolution=resolution_str,
449
+ seed=new_seed,
450
+ guidance_scale=0.0,
451
+ num_inference_steps=int(steps + 1),
452
+ shift=shift,
453
+ )
454
 
455
  if gallery_images is None:
456
  gallery_images = []
 
464
 
465
  # ==================== AoTI (Ahead of Time Inductor compilation) ====================
466
 
467
+ #pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
468
+ #spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
469
 
470
  with gr.Blocks(title="Z-Image Demo") as demo:
471
  gr.Markdown(