treadon commited on
Commit
3aa8acd
·
verified ·
1 Parent(s): 669b8a8

Upload t2i.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. t2i.py +174 -0
t2i.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Text-to-Image — hybrid MLX + PyTorch.
2
+
3
+ Phase 1 (MLX): block-diffusion VQ token generation with CFG.
4
+ Phase 2 (PyTorch): SigVQ + ZImageTransformer2DModel + VAE → pixel image.
5
+
6
+ The MLX backbone is released before the PyTorch decoder loads to fit in 64 GB
7
+ unified memory (decoder is ~12 GB, backbone is ~32 GB).
8
+ """
9
+ import argparse
10
+ import gc
11
+ import json
12
+ import os
13
+ import sys
14
+ import time
15
+ from pathlib import Path
16
+
17
+ import mlx.core as mx
18
+ from huggingface_hub import snapshot_download
19
+ from transformers import AutoTokenizer
20
+
21
+ REPO_ROOT = Path(__file__).resolve().parent.parent / "llada2-uni-repo"
22
+ sys.path.insert(0, str(REPO_ROOT))
23
+
24
+ # Stub out flash_attn (not available on Apple Silicon). The decoder has a
25
+ # dispatch_attention_fn fallback via diffusers that we use instead.
26
+ import types as _types, importlib.machinery as _im
27
+ if "flash_attn" not in sys.modules:
28
+ _stub = _types.ModuleType("flash_attn")
29
+ _stub.__spec__ = _im.ModuleSpec(name="flash_attn", loader=None)
30
+ _stub.__version__ = "0.0.0-stub"
31
+ _stub.flash_attn_func = lambda *a, **k: (_ for _ in ()).throw(
32
+ RuntimeError("flash_attn unavailable"))
33
+ sys.modules["flash_attn"] = _stub
34
+
35
+ from llada2.model import LLaDA2Config, LLaDA2Model
36
+ from llada2.weights import load_weights_into_model
37
+ from llada2.generate_image import generate_image_tokens, extract_vq_tokens
38
+
39
+
40
+ def build_t2i_prompt(tokenizer, prompt_text: str, image_h: int, image_w: int):
41
+ """Return (cond_ids, uncond_ids) — prompt id lists for CFG."""
42
+ sys_tmpl = "You are a text-to-image generation assistant."
43
+ # _build_chat equivalent
44
+ sys_ids = tokenizer(f"<role>SYSTEM</role> {sys_tmpl} <role>HUMAN</role>").input_ids
45
+ asst_ids = tokenizer("<role>ASSISTANT</role>").input_ids
46
+
47
+ soi = tokenizer("<|image|>").input_ids
48
+ boi = tokenizer("<boi>").input_ids
49
+ h_tok = tokenizer(f"<|reserved_token_{image_h}|>").input_ids
50
+ w_tok = tokenizer(f"<|reserved_token_{image_w}|>").input_ids
51
+ img_header = soi + h_tok + w_tok + boi
52
+
53
+ cond_ids = sys_ids + tokenizer(prompt_text).input_ids + asst_ids + img_header
54
+ uncond_ids = sys_ids + tokenizer("<uncondition>").input_ids + asst_ids + img_header
55
+ return cond_ids, uncond_ids
56
+
57
+
58
+ def decode_to_pixels(token_ids: list[int], h: int, w: int, model_path: Path,
59
+ decoder_steps: int, resolution_multiplier: int,
60
+ decode_mode: str = "decoder-turbo"):
61
+ """Call the official decoder to render pixels."""
62
+ import torch
63
+ from decoder import decode_vq_tokens
64
+ device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
65
+ return decode_vq_tokens(
66
+ token_ids, h, w, str(model_path), device,
67
+ resolution_multiplier=resolution_multiplier,
68
+ num_steps=decoder_steps, decode_mode=decode_mode,
69
+ )
70
+
71
+
72
+ def main():
73
+ ap = argparse.ArgumentParser()
74
+ ap.add_argument("--prompt", required=True, type=str)
75
+ ap.add_argument("--image-h", default=512, type=int)
76
+ ap.add_argument("--image-w", default=512, type=int)
77
+ ap.add_argument("--steps", default=16, type=int)
78
+ ap.add_argument("--block-length", default=32, type=int)
79
+ ap.add_argument("--cfg-scale", default=4.0, type=float)
80
+ ap.add_argument("--decoder-steps", default=50, type=int)
81
+ ap.add_argument("--decode-mode", default="normal",
82
+ choices=["decoder-turbo", "normal"],
83
+ help="'normal' = full 50-step decoder (cleaner, ~8 min), "
84
+ "'decoder-turbo' = 8-step distilled (faster but brittle ≈ striping)")
85
+ ap.add_argument("--resolution-multiplier", default=2, type=int)
86
+ ap.add_argument("--output", default="t2i_output.png", type=str)
87
+ ap.add_argument("--repo-id", default="inclusionAI/LLaDA2.0-Uni", type=str)
88
+ ap.add_argument("--save-vq", default=None, type=str, help="Save intermediate VQ tokens to .json")
89
+ ap.add_argument("--load-vq", default=None, type=str, help="Skip phase 1, load VQ tokens from .json")
90
+ args = ap.parse_args()
91
+
92
+ print("[t2i] fetching model files…")
93
+ snap = Path(snapshot_download(
94
+ args.repo_id,
95
+ allow_patterns=[
96
+ "model-*.safetensors", "model.safetensors.index.json",
97
+ "config.json", "tokenizer*", "special_tokens_map.json",
98
+ "decoder-turbo/*", "decoder/*", "image_tokenizer/*", "vae/*",
99
+ ],
100
+ ))
101
+
102
+ # Generate image: LLaDA2 divides H and W by 2 internally before computing grid.
103
+ # Net result: grid = (image_h // 2 // 16) x (image_w // 2 // 16)
104
+ grid_h = args.image_h // 2 // 16
105
+ grid_w = args.image_w // 2 // 16
106
+ gen_length = grid_h * grid_w
107
+
108
+ if args.load_vq:
109
+ with open(args.load_vq) as f:
110
+ cached = json.load(f)
111
+ vq_tokens = cached["token_ids"]
112
+ grid_h, grid_w = cached["h"], cached["w"]
113
+ print(f"[t2i] loaded {len(vq_tokens)} VQ tokens from {args.load_vq}")
114
+ else:
115
+ # ---------- Phase 1: MLX VQ-token generation ----------
116
+ tokenizer = AutoTokenizer.from_pretrained(str(snap), trust_remote_code=True)
117
+ config = LLaDA2Config.from_hf(json.loads((snap / "config.json").read_text()))
118
+
119
+ cond_ids, uncond_ids = build_t2i_prompt(tokenizer, args.prompt, grid_h, grid_w)
120
+ print(f"[t2i] prompt tokens: {len(cond_ids)} | grid: {grid_h}x{grid_w} ({gen_length} VQ tokens)")
121
+
122
+ print("[t2i] building model + loading backbone…")
123
+ model = LLaDA2Model(config)
124
+ t0 = time.time()
125
+ load_weights_into_model(model, snap, dtype=mx.bfloat16, verbose=False)
126
+ print(f"[t2i] backbone loaded in {time.time()-t0:.1f}s")
127
+
128
+ prompt_ids = mx.array([cond_ids], dtype=mx.int32)
129
+ uc_ids = mx.array([uncond_ids], dtype=mx.int32)
130
+
131
+ t0 = time.time()
132
+ out = generate_image_tokens(
133
+ model, prompt_ids, uc_ids,
134
+ gen_length=gen_length,
135
+ block_length=args.block_length,
136
+ steps_per_block=args.steps,
137
+ cfg_scale=args.cfg_scale,
138
+ mask_token_id=config.mask_token_id,
139
+ image_token_offset=config.image_token_offset,
140
+ vocab_size=config.vocab_size,
141
+ )
142
+ mx.eval(out)
143
+ vq_tokens = (out[0, len(cond_ids):len(cond_ids) + gen_length] - config.image_token_offset).tolist()
144
+ print(f"[t2i] VQ generation in {time.time()-t0:.1f}s, {len(vq_tokens)} tokens, "
145
+ f"range [{min(vq_tokens)}, {max(vq_tokens)}]")
146
+
147
+ if args.save_vq:
148
+ with open(args.save_vq, "w") as f:
149
+ json.dump({"token_ids": vq_tokens, "h": grid_h, "w": grid_w,
150
+ "prompt": args.prompt}, f)
151
+ print(f"[t2i] saved VQ tokens → {args.save_vq}")
152
+
153
+ # ---------- Free MLX backbone before PyTorch decoder loads ----------
154
+ del model, out
155
+ gc.collect()
156
+ mx.clear_cache()
157
+
158
+ # ---------- Phase 2: PyTorch decode → pixels ----------
159
+ print(f"[t2i] decoding VQ tokens → pixels ({args.decoder_steps} steps)…")
160
+ t0 = time.time()
161
+ img = decode_to_pixels(
162
+ vq_tokens, grid_h, grid_w, snap,
163
+ decoder_steps=args.decoder_steps,
164
+ resolution_multiplier=args.resolution_multiplier,
165
+ decode_mode=args.decode_mode,
166
+ )
167
+ print(f"[t2i] decoded in {time.time()-t0:.1f}s")
168
+
169
+ img.save(args.output)
170
+ print(f"[t2i] wrote {args.output}")
171
+
172
+
173
+ if __name__ == "__main__":
174
+ main()