| """Batch prompt inference for Z-Image.""" | |
| import os | |
| from pathlib import Path | |
| import time | |
| import torch | |
| from inference import ensure_weights | |
| from utils import AttentionBackend, load_from_local_dir, set_attention_backend | |
| from zimage import generate | |
def read_prompts(path: str) -> list[str]:
    """Load prompts from a text file, one per line.

    Leading/trailing whitespace is stripped and blank lines are dropped.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if the file yields no non-empty lines.
    """
    prompt_path = Path(path)
    if not prompt_path.exists():
        raise FileNotFoundError(f"Prompt file not found: {prompt_path}")
    raw_lines = prompt_path.read_text(encoding="utf-8").splitlines()
    prompts = [stripped for line in raw_lines if (stripped := line.strip())]
    if not prompts:
        raise ValueError(f"No prompts found in {prompt_path}")
    return prompts
| PROMPTS = read_prompts(os.environ.get("PROMPTS_FILE", "prompts/prompt1.txt")) | |
def slugify(text: str, max_len: int = 60) -> str:
    """Build a lowercase, hyphen-separated, filesystem-safe slug from *text*.

    Runs of non-alphanumeric characters collapse to a single hyphen; the
    result is truncated to *max_len* characters. Falls back to "prompt"
    when nothing usable remains.
    """
    words: list[str] = []
    run: list[str] = []
    for ch in text:
        if ch.isalnum():
            run.append(ch.lower())
        elif run:
            words.append("".join(run))
            run = []
    if run:
        words.append("".join(run))
    slug = "-".join(words)
    return slug[:max_len].rstrip("-") or "prompt"
def select_device() -> "str | torch.device":
    """Pick the best available compute device.

    Preference order: CUDA, then an XLA/TPU device when torch_xla is
    installed and usable, then Apple MPS, then CPU.

    Returns:
        "cuda", "mps", or "cpu" as a plain string, or a torch.device
        object on the TPU path — xm.xla_device() returns a device
        object, not a string, so the old `-> str` annotation was wrong.
    """
    if torch.cuda.is_available():
        print("Chosen device: cuda")
        return "cuda"
    try:
        # Optional dependency: only present on TPU-enabled installs.
        import torch_xla.core.xla_model as xm

        device = xm.xla_device()
    except (ImportError, RuntimeError):
        # No usable TPU stack; fall back to MPS, then CPU.
        if torch.backends.mps.is_available():
            print("Chosen device: mps")
            return "mps"
        print("Chosen device: cpu")
        return "cpu"
    print("Chosen device: tpu")
    return device
def main():
    """Generate one image per prompt in PROMPTS and save them under outputs/."""
    model_path = ensure_weights("ckpts/Z-Image-Turbo")

    # Generation settings — Turbo checkpoint: few steps, no CFG.
    dtype = torch.bfloat16
    compile_model = False  # renamed from `compile` to stop shadowing the builtin
    height = 1024
    width = 1024
    num_inference_steps = 8
    guidance_scale = 0.0
    attn_backend = os.environ.get("ZIMAGE_ATTENTION", "_native_flash")

    output_dir = Path("outputs")
    # parents=True so a missing parent directory doesn't crash the run.
    output_dir.mkdir(parents=True, exist_ok=True)

    device = select_device()
    components = load_from_local_dir(model_path, device=device, dtype=dtype, compile=compile_model)

    AttentionBackend.print_available_backends()
    set_attention_backend(attn_backend)
    print(f"Chosen attention backend: {attn_backend}")

    for idx, prompt in enumerate(PROMPTS, start=1):
        output_path = output_dir / f"prompt-{idx:02d}-{slugify(prompt)}.png"
        # Deterministic per-prompt seeds: 42, 43, 44, ...
        seed = 41 + idx
        generator = torch.Generator(device).manual_seed(seed)
        # perf_counter is monotonic — time.time() can jump with clock changes.
        start_time = time.perf_counter()
        images = generate(
            prompt=prompt,
            **components,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        elapsed = time.perf_counter() - start_time
        images[0].save(output_path)
        print(f"[{idx}/{len(PROMPTS)}] Saved {output_path} in {elapsed:.2f} seconds")

    print("Done.")
# Script entry point: run batch generation only when executed directly.
if __name__ == "__main__":
    main()