"""
Benchmark for Flux2-klein 4B and 9B models on AWS Neuron.

Usage:

    torchrun --nproc_per_node=4 flux2-klein/benchmark.py \\
        --no-random-weights --model-id black-forest-labs/FLUX.2-klein-9B \\
        --num-runs 3 --num-steps 4
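
The same invocation can exercise the torch.compile path by adding
--mode compile (the default is --mode eager, the lazy-XLA path), e.g.:

    torchrun --nproc_per_node=4 flux2-klein/benchmark.py \\
        --mode compile --no-random-weights \\
        --model-id black-forest-labs/FLUX.2-klein-9B --num-runs 3 --num-steps 4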

Results (trn2.3xlarge, 4 NeuronCores, 512×512, 4 steps, bfloat16):

FLUX.2-klein-4B (3.88B params) — eager mode
    Run  Type    step01   step02  step03  step04     total
    1    COLD    9.348s   0.844s  0.835s  0.860s   11.888s
    2    WARM    0.831s   0.835s  0.838s  0.837s    3.342s
    3    WARM    0.830s   0.835s  0.831s  0.834s    3.330s
    4    WARM    0.836s   0.831s  0.840s  0.838s    3.345s
    Cold first call (XLA compilation): 9.348s
    Warm avg/step: 0.835s | 1.198 steps/s | speedup vs cold: 11.2×

FLUX.2-klein-9B (9.08B params) — eager mode
    Run  Type    step01   step02  step03  step04     total
    1    COLD  129.651s   1.276s  1.264s  1.270s  133.461s
    2    WARM    1.277s   1.264s  1.267s  1.264s    5.071s
    3    WARM    1.265s   1.262s  1.270s  1.263s    5.061s
    4    WARM    1.258s   1.274s  1.267s  1.266s    5.065s
    Cold first call (XLA compilation): 129.651s
    Warm avg/step: 1.266s | 0.790 steps/s | speedup vs cold: 102.4×

FLUX.2-klein-9B (9.08B params) — compile mode (torch.compile, Dynamo+NEFF)
    Run  Type    step01   step02  step03  step04     total
    1    COLD  264.514s   5.677s  5.675s  5.673s  281.539s
    2    WARM    5.676s   5.677s  5.677s  5.673s   22.703s
    3    WARM    5.672s   5.676s  5.679s  5.676s   22.702s
    4    WARM    5.671s   5.673s  5.673s  5.677s   22.695s
    Cold first call (Dynamo+NEFF compilation): 264.514s
    Warm avg/step: 5.675s | 0.176 steps/s

Comparison — FLUX.2-klein-9B warm throughput:
    eager:   1.284s/step (0.779 steps/s)  ← 4.4× faster
    compile: 5.675s/step (0.176 steps/s)

Note: compile mode is slower because torch.compile/Dynamo uses the NKI flash
attention decomposition (training=True path) and does not benefit from the
XLA-level fusions that the lazy-XLA path applies automatically.
"""

import argparse
import gc
import logging
import os
import sys
import time

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from diffusers import Flux2Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.flux2.pipeline_flux2_klein import (
    Flux2KleinPipeline,
    compute_empirical_mu,
)

# Make the sibling pipeline module importable regardless of the launch directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from pipeline import (
    apply_tp_flux2_transformer,
    apply_tp_text_encoder,
    _encode_prompt_tp,
    load_text_encoder,
    load_transformer,
    _snapshot,
)

import torch_neuronx
from torch_neuronx.neuron_dynamo_backend import set_model_name

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
logger = logging.getLogger(__name__)

DEFAULT_MODEL_ID = "black-forest-labs/FLUX.2-klein-4B"


def _prepare_inputs(transformer, height, width, batch_size, text_seq_len, device, seed):
    """
    Compute initial latents, latent position IDs, and text position IDs.

    Returns (latents_dev, latent_ids_dev, text_ids_dev, latents_cpu), where
    latents_cpu is kept so each run can reset latents to the same starting
    values.
    """
    generator = torch.Generator().manual_seed(seed)

    # The VAE downsamples by 8x and latents are packed 2x2 into tokens.
    vae_scale = 8
    lh = 2 * (height // (vae_scale * 2))
    lw = 2 * (width // (vae_scale * 2))
    seq_len = (lh // 2) * (lw // 2)
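    # Example: at 512x512, lh = lw = 2 * (512 // 16) = 64, so
    # seq_len = (64 // 2) * (64 // 2) = 1024 latent tokens.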

    latents_cpu = torch.randn(
        batch_size, seq_len, transformer.config.in_channels,
        dtype=torch.bfloat16, generator=generator,
    )

    # 4-component position IDs: image tokens vary over the (lh//2, lw//2)
    # grid, text tokens vary over sequence position only.
    latent_ids_cpu = (
        torch.cartesian_prod(
            torch.arange(1), torch.arange(lh // 2),
            torch.arange(lw // 2), torch.arange(1),
        )
        .unsqueeze(0).expand(batch_size, -1, -1).contiguous().float()
    )

    text_ids_cpu = (
        torch.cartesian_prod(
            torch.arange(1), torch.arange(1),
            torch.arange(1), torch.arange(text_seq_len),
        )
        .unsqueeze(0).expand(batch_size, -1, -1).contiguous().float()
    )

    return (
        latents_cpu.to(device),
        latent_ids_cpu.to(device),
        text_ids_cpu.to(device),
        latents_cpu,
    )


def _run_one(
    run_idx, num_runs, transformer, scheduler,
    prompt_embeds, latents_init_cpu, latent_ids_dev, text_ids_dev,
    ts_tensor, num_steps, batch_size, device, rank,
):
    """
    Execute one complete denoising loop and return per-step wall-clock times.

    Latents are always reset to `latents_init_cpu` at the start so every run
    is independent. The scheduler step runs on rank 0's CPU; the updated
    latents are broadcast to all ranks.

    Returns:
        step_times: list[float] — elapsed seconds for each transformer forward.
    """
    is_cold = (run_idx == 0)
    label = f"Run {run_idx + 1}/{num_runs} ({'COLD' if is_cold else 'WARM':4s})"

    latents_dev = latents_init_cpu.to(device)

    # Reset the scheduler's internal step counter between runs.
    if rank == 0:
        scheduler._step_index = None

    step_times = []

    if rank == 0:
        logger.info(f" --- {label} ---")

    dist.barrier()
    t_run = time.time()

    with torch.no_grad():
        for step_idx in range(num_steps):
            t_val = ts_tensor[step_idx]
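            # The transformer conditions on t in [0, 1] while the scheduler's
            # timesteps run from ~1000 down to 0, hence the division by 1000.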
            timestep = t_val.expand(batch_size).to(torch.bfloat16).to(device) / 1000.0

            dist.barrier()
            t0 = time.time()

            noise_pred = transformer(
                hidden_states=latents_dev,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                img_ids=latent_ids_dev,
                txt_ids=text_ids_dev,
                guidance=None,
                return_dict=False,
            )[0]

            dist.barrier()
            elapsed = time.time() - t0
            step_times.append(elapsed)

            if rank == 0:
                logger.info(
                    f" step {step_idx + 1:2d}/{num_steps}"
                    f" t={t_val.item():7.1f}"
                    f" elapsed={elapsed:.3f}s"
                )

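            # Scheduler math is cheap, so run it once on rank 0's CPU and
            # broadcast the updated latents to keep all ranks in sync.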
            if rank == 0:
                lat_new = scheduler.step(
                    noise_pred.to("cpu"), t_val.cpu(), latents_dev.to("cpu"),
                    return_dict=False,
                )[0]
                latents_dev.copy_(lat_new.to(device))
            dist.broadcast(latents_dev, src=0)

    if rank == 0:
        total = time.time() - t_run
        logger.info(f" run {run_idx + 1} total: {total:.3f}s")

    return step_times


def _print_summary(mode, model_id, height, width, num_steps, num_runs, all_step_times):
    """Print a formatted latency table and key metrics to the log."""
    SEP = "=" * 72
    HSEP = "-" * 72

    cold_label = (
        "Dynamo trace + NEFF compilation" if mode == "compile"
        else "XLA compilation"
    )

    logger.info(SEP)
    logger.info(f"BENCHMARK RESULTS | {model_id} | mode={mode}")
    logger.info(f" {height}x{width} · {num_steps} steps/run · {num_runs} runs")
    logger.info(HSEP)

    step_hdrs = " ".join(f"step{i + 1:02d}" for i in range(num_steps))
    logger.info(f"{'Run':<5} {'Type':<5} {step_hdrs} total")
    logger.info(HSEP)
    for run_idx, times in enumerate(all_step_times):
        rtype = "COLD" if run_idx == 0 else "WARM"
        cells = " ".join(f"{t:6.3f}s" for t in times)
        logger.info(f"{run_idx + 1:<5} {rtype:<5} {cells} {sum(times):.3f}s")

    logger.info(HSEP)

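    # The first step of run 1 includes one-time compilation, so report it
    # separately from the steady-state step times below.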
    cold_step1 = all_step_times[0][0]
    logger.info(f" Cold first call (incl. {cold_label}): {cold_step1:.3f}s")

    if num_steps > 1:
        cold_rest = all_step_times[0][1:]
        avg_cold_rest = sum(cold_rest) / len(cold_rest)
        logger.info(
            f" Cold run steps 2-{num_steps} avg : {avg_cold_rest:.3f}s/step"
        )

    if num_runs > 1:
        warm_times = [t for times in all_step_times[1:] for t in times]
        avg_warm = sum(warm_times) / len(warm_times)
        warm_step1_times = [times[0] for times in all_step_times[1:]]
        avg_warm_step1 = sum(warm_step1_times) / len(warm_step1_times)
        logger.info(
            f" Warm runs — first step avg : {avg_warm_step1:.3f}s/step"
        )
        logger.info(
            f" Warm runs — all steps avg : {avg_warm:.3f}s/step"
        )
        logger.info(
            f" Throughput (warm, all steps) : {1.0 / avg_warm:.3f} steps/s"
        )
        logger.info(
            f" Speedup vs cold first call : {cold_step1 / avg_warm:.1f}x"
        )

    logger.info(SEP)


def benchmark(
    mode, model_id, prompt, height, width, num_steps, batch_size,
    num_runs, random_weights, seed, fuse_qkv=False, flash_attn=False,
):
    assert mode in ("eager", "compile"), f"--mode must be 'eager' or 'compile', got {mode!r}"

    dist.init_process_group(backend="neuron")
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    device = torch.neuron.current_device()

    tp_mesh = DeviceMesh("neuron", list(range(world_size)))

    if rank == 0:
        logger.info(f"{'=' * 72}")
        logger.info(f"Flux2-klein benchmark | {model_id} | mode={mode}")
        logger.info(
            f" {height}x{width} · {num_steps} steps · {num_runs} runs "
            f"· batch={batch_size} · random_weights={random_weights}"
        )
        logger.info(f"{'=' * 72}")

    xfmr_cfg = Flux2Transformer2DModel.load_config(model_id, subfolder="transformer")
    joint_attention_dim = xfmr_cfg["joint_attention_dim"]
    text_seq_len = 512

    # Encode the prompt first and free the text encoder before the transformer
    # loads, so both models are never resident at the same time. With random
    # weights, rank 0 generates random embeddings and broadcasts them instead.
    if not random_weights:
        t0 = time.time()
        text_encoder, tokenizer = load_text_encoder(model_id, random_weights=False)
        logger.info(
            f"Rank {rank}: text encoder loaded in {time.time() - t0:.1f}s "
            f"({sum(p.numel() for p in text_encoder.parameters()) / 1e9:.2f}B params)"
        )
        text_encoder = apply_tp_text_encoder(text_encoder, tp_mesh)
        text_encoder = text_encoder.to(device)
        text_encoder.eval()
        if mode == "compile":
            set_model_name(f"qwen3_text_encoder_rank{rank}")
            from transformers.utils.output_capturing import install_all_output_capturing_hooks
            install_all_output_capturing_hooks(text_encoder)
            text_encoder = torch.compile(text_encoder, backend="neuron", fullgraph=True)
            logger.info(f"Rank {rank}: text encoder compiled")
        gc.collect()

        prompt_embeds = _encode_prompt_tp(
            text_encoder, tokenizer, prompt, batch_size, device)
        if rank == 0:
            logger.info(f"Prompt encoded shape={prompt_embeds.shape}")

        del text_encoder, tokenizer
        gc.collect()
    else:
        prompt_embeds = torch.zeros(
            batch_size, text_seq_len, joint_attention_dim,
            dtype=torch.bfloat16, device=device,
        )
        if rank == 0:
            prompt_embeds.copy_(
                torch.randn(batch_size, text_seq_len, joint_attention_dim,
                            dtype=torch.bfloat16).to(device))
        dist.broadcast(prompt_embeds, src=0)

    t0 = time.time()
    transformer = load_transformer(model_id, random_weights)
    logger.info(
        f"Rank {rank}: transformer loaded in {time.time() - t0:.1f}s "
        f"({sum(p.numel() for p in transformer.parameters()) / 1e9:.2f}B params)"
    )

    transformer = apply_tp_flux2_transformer(transformer, tp_mesh,
                                             fuse_qkv=fuse_qkv, flash_attn=flash_attn)
    transformer = transformer.to(device)
    transformer.eval()
    if mode == "compile":
        set_model_name(f"flux2_transformer_rank{rank}")
        transformer = torch.compile(transformer, backend="neuron", fullgraph=True)
        logger.info(f"Rank {rank}: transformer compiled (NEFF will build on first call)")
    gc.collect()

    vae_scale = 8
    lh = 2 * (height // (vae_scale * 2))
    lw = 2 * (width // (vae_scale * 2))
    image_seq_len = (lh // 2) * (lw // 2)

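    # Rank 0 computes the empirical-mu-shifted timestep schedule; the other
    # ranks receive it via broadcast so every rank sees identical timesteps.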
    if rank == 0:
        mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_steps)
        scheduler = FlowMatchEulerDiscreteScheduler()
        scheduler.set_timesteps(num_steps, mu=mu)
        ts_float = scheduler.timesteps.float()
        logger.info(f"Timesteps: {scheduler.timesteps.tolist()}")
    else:
        scheduler = FlowMatchEulerDiscreteScheduler()
        ts_float = torch.zeros(num_steps, dtype=torch.float32)

    ts_dev = ts_float.to(device)
    dist.broadcast(ts_dev, src=0)

    latents_dev, latent_ids_dev, text_ids_dev, latents_init_cpu = _prepare_inputs(
        transformer, height, width, batch_size, text_seq_len, device, seed,
    )

    dist.barrier()
    if rank == 0:
        compile_note = " (run 1 triggers Dynamo+NEFF compile)" if mode == "compile" else ""
        logger.info(
            f"Starting {num_runs} benchmark runs ({num_steps} steps each){compile_note} ..."
        )

    all_step_times = []

    for run_idx in range(num_runs):
        step_times = _run_one(
            run_idx, num_runs, transformer, scheduler,
            prompt_embeds, latents_init_cpu, latent_ids_dev, text_ids_dev,
            ts_dev, num_steps, batch_size, device, rank,
        )
        all_step_times.append(step_times)

    if rank == 0:
        _print_summary(mode, model_id, height, width, num_steps, num_runs, all_step_times)

    dist.barrier()
    dist.destroy_process_group()


def parse_args():
    p = argparse.ArgumentParser(
        description="Flux2-klein latency benchmark (4B / 9B) on Neuron",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--mode", choices=["eager", "compile"], default="eager",
                   help="eager: lazy-XLA path. compile: torch.compile Dynamo path.")
    p.add_argument("--model-id", default=DEFAULT_MODEL_ID,
                   help="HuggingFace model ID (4B or 9B variant)")
    p.add_argument("--prompt", default="a cat sitting on a Neuron chip, photorealistic")
    p.add_argument("--height", type=int, default=512)
    p.add_argument("--width", type=int, default=512)
    p.add_argument("--num-steps", type=int, default=4,
                   help="Denoising steps per run")
    p.add_argument("--num-runs", type=int, default=4,
                   help="Total runs: run 1=COLD (compilation), runs 2+=WARM (benchmarked)")
    p.add_argument("--batch-size", type=int, default=1)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--random-weights", action="store_true", default=True,
                   help="Use randomly initialized weights instead of the checkpoint (default)")
    p.add_argument("--no-random-weights", action="store_false", dest="random_weights",
                   help="Load real checkpoint weights and encode the prompt")
    p.add_argument("--fused-qkv", action="store_true", default=False,
                   help="Use NKI fused QKV kernel for double-stream blocks.")
    p.add_argument("--flash-attn", action="store_true", default=False,
                   help="Use NKI flash attention kernel for all blocks.")
    p.add_argument(
        "--cache-dir",
        default=None,
        help=(
            "Persistent NEFF cache directory (sets TORCH_NEURONX_NEFF_CACHE_DIR). "
            "Applies to both eager and compile modes. "
            "NEFFs are saved on the first run and reloaded on subsequent runs. "
            "Example: --cache-dir /home/ubuntu/neff_cache"
        ),
    )
    return p.parse_args()


if __name__ == "__main__":
    args = parse_args()
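
    # Point TORCH_NEURONX_NEFF_CACHE_DIR at a persistent location before any
    # compilation happens so NEFFs from earlier runs can be reused.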
    cache_dir = args.cache_dir or os.environ.get("TORCH_NEURONX_NEFF_CACHE_DIR", "/tmp/neff_cache")
    os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir
    os.makedirs(cache_dir, exist_ok=True)
    logger.info(f"NEFF cache dir: {cache_dir}")
    benchmark(
        mode=args.mode,
        model_id=args.model_id,
        prompt=args.prompt,
        height=args.height,
        width=args.width,
        num_steps=args.num_steps,
        batch_size=args.batch_size,
        num_runs=args.num_runs,
        random_weights=args.random_weights,
        seed=args.seed,
        fuse_qkv=args.fused_qkv,
        flash_attn=args.flash_attn,
    )