# Z-Image-Special-Edition / batch_inference.py
# Hugging Face repo metadata (scrape residue): Diffusers / Safetensors,
# "Upload 8 files" by leewheel, commit b67c0bb (verified).
"""Batch prompt inference for Z-Image."""
import os
from pathlib import Path
import time
import torch
from inference import ensure_weights
from utils import AttentionBackend, load_from_local_dir, set_attention_backend
from zimage import generate
def read_prompts(path: str) -> list[str]:
    """Load prompts from a text file, one per line.

    Whitespace-only and empty lines are ignored.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if the file yields no non-empty prompts.
    """
    prompt_path = Path(path)
    if not prompt_path.exists():
        raise FileNotFoundError(f"Prompt file not found: {prompt_path}")
    collected: list[str] = []
    with prompt_path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if stripped:
                collected.append(stripped)
    if not collected:
        raise ValueError(f"No prompts found in {prompt_path}")
    return collected
# Prompts are loaded once at import time; the source file can be overridden
# via the PROMPTS_FILE environment variable (default: prompts/prompt1.txt).
PROMPTS = read_prompts(os.environ.get("PROMPTS_FILE", "prompts/prompt1.txt"))
def slugify(text: str, max_len: int = 60) -> str:
    """Derive a filesystem-safe, hyphen-separated slug from *text*.

    Alphanumeric runs are lowercased and joined with single hyphens; the
    result is truncated to *max_len* characters. Returns "prompt" when
    nothing usable remains.
    """
    runs: list[str] = []
    current: list[str] = []
    for ch in text:
        if ch.isalnum():
            current.append(ch.lower())
        elif current:
            runs.append("".join(current))
            current = []
    if current:
        runs.append("".join(current))
    slug = "-".join(runs)
    return slug[:max_len].rstrip("-") or "prompt"
def select_device() -> str:
    """Pick the best available compute device, preferring cuda > tpu > mps > cpu."""
    if torch.cuda.is_available():
        print("Chosen device: cuda")
        return "cuda"
    try:
        # Both the import and the device query stay inside the try: either
        # can fail when no TPU runtime is present.
        import torch_xla.core.xla_model as xm
        xla_device = xm.xla_device()
    except (ImportError, RuntimeError):
        if torch.backends.mps.is_available():
            print("Chosen device: mps")
            return "mps"
    else:
        print("Chosen device: tpu")
        return xla_device
    print("Chosen device: cpu")
    return "cpu"
def main():
    """Generate one image per prompt in PROMPTS and save each under outputs/.

    Seeds are deterministic per prompt (42 for the first, incrementing by 1),
    so reruns reproduce the same images.
    """
    model_path = ensure_weights("ckpts/Z-Image-Turbo")

    # Generation settings.
    dtype = torch.bfloat16
    compile_model = False  # renamed from `compile` to avoid shadowing the builtin
    height = 1024
    width = 1024
    num_inference_steps = 8
    guidance_scale = 0.0
    attn_backend = os.environ.get("ZIMAGE_ATTENTION", "_native_flash")

    output_dir = Path("outputs")
    output_dir.mkdir(parents=True, exist_ok=True)

    device = select_device()
    components = load_from_local_dir(model_path, device=device, dtype=dtype, compile=compile_model)

    AttentionBackend.print_available_backends()
    set_attention_backend(attn_backend)
    print(f"Chosen attention backend: {attn_backend}")

    for idx, prompt in enumerate(PROMPTS, start=1):
        output_path = output_dir / f"prompt-{idx:02d}-{slugify(prompt)}.png"
        seed = 41 + idx  # 42, 43, ... — same sequence as the original 42 + idx - 1
        # NOTE(review): torch.Generator expects a device string/torch.device;
        # select_device() may return an XLA device object on TPU — confirm.
        generator = torch.Generator(device).manual_seed(seed)

        # perf_counter is monotonic, so elapsed time is immune to wall-clock
        # adjustments (time.time is not).
        start_time = time.perf_counter()
        images = generate(
            prompt=prompt,
            **components,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        elapsed = time.perf_counter() - start_time

        images[0].save(output_path)
        print(f"[{idx}/{len(PROMPTS)}] Saved {output_path} in {elapsed:.2f} seconds")
    print("Done.")


if __name__ == "__main__":
    main()