yingzhac-research committed on
Commit
9fc4b86
·
1 Parent(s): f89593c

Lazy-load pipeline in ZeroGPU worker

Browse files
Files changed (1) hide show
  1. app.py +24 -11
app.py CHANGED
@@ -2,25 +2,34 @@ import torch
2
  import spaces
3
  import gradio as gr
4
  from diffusers import DiffusionPipeline
 
5
 
6
  MAX_SEED = 2**32 - 1
7
 
8
- # Load the pipeline once at startup
9
- print("Loading Z-Image-Turbo pipeline...")
10
- pipe = DiffusionPipeline.from_pretrained(
11
- "Tongyi-MAI/Z-Image-Turbo",
12
- torch_dtype=torch.bfloat16,
13
- low_cpu_mem_usage=False,
14
- )
15
- pipe.to("cuda")
 
 
 
 
 
 
 
 
 
 
16
 
17
  '# ======== AoTI compilation + FA3 ======== (disabled on HF to avoid outdated AOTI/FA3 package errors)'
18
  # pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
19
  # spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
20
 
21
- print("Pipeline loaded!")
22
-
23
- @spaces.GPU
24
  def generate_image(
25
  prompt,
26
  negative_prompt,
@@ -33,6 +42,10 @@ def generate_image(
33
  progress=gr.Progress(track_tqdm=True),
34
  ):
35
  """Generate 4 images with seeds: seed, 2x, 3x, 4x (mod MAX_SEED)."""
 
 
 
 
36
  if randomize_seed:
37
  seed = torch.randint(0, MAX_SEED, (1,)).item()
38
 
 
2
  import spaces
3
  import gradio as gr
4
  from diffusers import DiffusionPipeline
5
+ from threading import Lock
6
 
7
  MAX_SEED = 2**32 - 1
8
 
9
pipe = None        # lazily-initialized DiffusionPipeline singleton (loaded on first use)
pipe_lock = Lock()  # guards first-time initialization across concurrent requests

def get_pipe():
    """Return the shared Z-Image-Turbo pipeline, loading it on first use.

    Uses double-checked locking: a lock-free fast path once the pipeline
    exists, and a re-check under the lock so concurrent first callers do
    not both trigger the (expensive) model load. Loading lazily here means
    it happens inside the ZeroGPU worker, where CUDA is available.

    Returns:
        The module-level ``DiffusionPipeline`` instance, moved to "cuda".
    """
    global pipe
    # Fast path: already initialized; skip the lock entirely.
    if pipe is not None:
        return pipe
    with pipe_lock:
        # Re-check under the lock: another thread may have won the race.
        if pipe is None:
            # Load the pipeline lazily inside the ZeroGPU worker
            print("Loading Z-Image-Turbo pipeline...")
            pipe = DiffusionPipeline.from_pretrained(
                "Tongyi-MAI/Z-Image-Turbo",
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=False,
            ).to("cuda")
            print("Pipeline loaded!")
    return pipe
27
 
28
  '# ======== AoTI compilation + FA3 ======== (disabled on HF to avoid outdated AOTI/FA3 package errors)'
29
  # pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
30
  # spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
31
 
32
+ @spaces.GPU(duration=120)
 
 
33
  def generate_image(
34
  prompt,
35
  negative_prompt,
 
42
  progress=gr.Progress(track_tqdm=True),
43
  ):
44
  """Generate 4 images with seeds: seed, 2x, 3x, 4x (mod MAX_SEED)."""
45
+ if not torch.cuda.is_available():
46
+ raise RuntimeError("CUDA is not available inside the ZeroGPU worker.")
47
+
48
+ pipe = get_pipe()
49
  if randomize_seed:
50
  seed = torch.randint(0, MAX_SEED, (1,)).item()
51