Diffusers
Safetensors
File size: 3,118 Bytes
b67c0bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""Batch prompt inference for Z-Image."""

import os
from pathlib import Path
import time

import torch

from inference import ensure_weights
from utils import AttentionBackend, load_from_local_dir, set_attention_backend
from zimage import generate


def read_prompts(path: str) -> list[str]:
    """Load the prompt list stored in the text file at *path*.

    Every line is stripped of surrounding whitespace; lines that are empty
    after stripping are ignored.

    Args:
        path: Location of a UTF-8 text file holding one prompt per line.

    Returns:
        The non-empty, stripped lines in file order.

    Raises:
        FileNotFoundError: If *path* does not exist.
        ValueError: If the file yields no non-empty line.
    """

    source = Path(path)
    if not source.exists():
        raise FileNotFoundError(f"Prompt file not found: {source}")

    collected = []
    with source.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            cleaned = raw_line.strip()
            if cleaned:
                collected.append(cleaned)

    if collected:
        return collected
    raise ValueError(f"No prompts found in {source}")


# Read once at import time; the PROMPTS_FILE environment variable overrides
# the default prompt file location.
PROMPTS = read_prompts(os.getenv("PROMPTS_FILE", "prompts/prompt1.txt"))


def slugify(text: str, max_len: int = 60) -> str:
    """Turn a prompt into a lowercase, dash-separated slug for file names.

    Runs of alphanumeric characters are lowercased and kept as words; every
    other character acts as a separator. The words are joined with single
    dashes, the result is cut to *max_len* characters, and any dash left
    dangling by the cut is removed.

    Args:
        text: Arbitrary prompt text.
        max_len: Upper bound on the slug length before trailing dashes are
            trimmed.

    Returns:
        The slug, or ``"prompt"`` when nothing usable remains.
    """

    words = []
    current = []
    for ch in text:
        if ch.isalnum():
            current.append(ch.lower())
            continue
        # A non-alphanumeric character closes the word collected so far.
        if current:
            words.append("".join(current))
            current = []
    if current:
        words.append("".join(current))

    trimmed = "-".join(words)[:max_len].rstrip("-")
    if trimmed:
        return trimmed
    return "prompt"


def select_device() -> "str | torch.device":
    """Choose the best available device without repeating detection logic.

    Backends are probed in priority order: CUDA, then XLA (through
    ``torch_xla``), then MPS, with CPU as the final fallback. The choice is
    printed to stdout.

    Returns:
        ``"cuda"``, ``"mps"`` or ``"cpu"`` as a string, or the device object
        returned by ``xm.xla_device()`` when an XLA device is usable. The
        XLA case is why the return type is not plain ``str``.
    """

    if torch.cuda.is_available():
        print("Chosen device: cuda")
        return "cuda"

    # Keep the try body limited to the two calls that can actually fail.
    try:
        import torch_xla.core.xla_model as xm

        xla_device = xm.xla_device()
    except (ImportError, RuntimeError):
        # torch_xla is not installed or no XLA device could be acquired;
        # this is an expected situation, so fall through to MPS / CPU.
        xla_device = None

    if xla_device is not None:
        print("Chosen device: tpu")
        return xla_device

    if torch.backends.mps.is_available():
        print("Chosen device: mps")
        return "mps"

    print("Chosen device: cpu")
    return "cpu"


def main() -> None:
    """Generate one image per entry in ``PROMPTS`` and save it under ``outputs/``.

    The model components are loaded once from the path returned by
    ``ensure_weights("ckpts/Z-Image-Turbo")``. The attention backend name is
    read from the ``ZIMAGE_ATTENTION`` environment variable (default
    ``"_native_flash"``). Each prompt is then passed to ``generate`` with its
    own seed (42 for the first prompt, incrementing by one), and the first
    returned image is written to ``outputs/prompt-<NN>-<slug>.png``.
    """
    model_path = ensure_weights("ckpts/Z-Image-Turbo")
    dtype = torch.bfloat16
    # Deliberately not called `compile`, which would shadow the builtin.
    compile_model = False
    height = 1024
    width = 1024
    num_inference_steps = 8
    guidance_scale = 0.0
    base_seed = 42
    attn_backend = os.environ.get("ZIMAGE_ATTENTION", "_native_flash")
    output_dir = Path("outputs")
    output_dir.mkdir(exist_ok=True)

    device = select_device()

    components = load_from_local_dir(
        model_path, device=device, dtype=dtype, compile=compile_model
    )
    AttentionBackend.print_available_backends()
    set_attention_backend(attn_backend)
    print(f"Chosen attention backend: {attn_backend}")

    total = len(PROMPTS)
    for idx, prompt in enumerate(PROMPTS, start=1):
        output_path = output_dir / f"prompt-{idx:02d}-{slugify(prompt)}.png"
        # Distinct but reproducible seed per prompt: 42, 43, 44, ...
        seed = base_seed + idx - 1
        # NOTE(review): select_device() may return an XLA device object;
        # confirm torch.Generator accepts it on that backend.
        generator = torch.Generator(device).manual_seed(seed)

        # perf_counter is monotonic, so the measured interval cannot be
        # distorted by wall-clock adjustments.
        start_time = time.perf_counter()
        images = generate(
            prompt=prompt,
            **components,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        elapsed = time.perf_counter() - start_time
        images[0].save(output_path)
        print(f"[{idx}/{total}] Saved {output_path} in {elapsed:.2f} seconds")

    print("Done.")


# Script entry point: start the batch run only when this file is executed
# directly; importing the module does not call main().
if __name__ == "__main__":
    main()