#!/usr/bin/env python3 """ Benchmark for Flux2-klein 4B and 9B models on AWS Neuron. Usage: torchrun --nproc_per_node=4 flux2-klein/benchmark.py \\ --no-random-weights --model-id black-forest-labs/FLUX.2-klein-9B \\ --num-runs 3 --num-steps 4 Results (trn2.3xlarge, 4 NeuronCores, 512×512, 4 steps, bfloat16): FLUX.2-klein-4B (3.88B params) — eager mode Run Type step01 step02 step03 step04 total 1 COLD 9.348s 0.844s 0.835s 0.860s 11.888s 2 WARM 0.831s 0.835s 0.838s 0.837s 3.342s 3 WARM 0.830s 0.835s 0.831s 0.834s 3.330s 4 WARM 0.836s 0.831s 0.840s 0.838s 3.345s Cold first call (XLA compilation): 9.348s Warm avg/step: 0.835s | 1.198 steps/s | speedup vs cold: 11.2× FLUX.2-klein-9B (9.08B params) — eager mode Run Type step01 step02 step03 step04 total 1 COLD 129.651s 1.276s 1.264s 1.270s 133.461s 2 WARM 1.277s 1.264s 1.267s 1.264s 5.071s 3 WARM 1.265s 1.262s 1.270s 1.263s 5.061s 4 WARM 1.258s 1.274s 1.267s 1.266s 5.065s Cold first call (XLA compilation): 129.651s Warm avg/step: 1.266s | 0.790 steps/s | speedup vs cold: 102.4× FLUX.2-klein-9B (9.08B params) — compile mode (torch.compile, Dynamo+NEFF) Run Type step01 step02 step03 step04 total 1 COLD 264.514s 5.677s 5.675s 5.673s 281.539s 2 WARM 5.676s 5.677s 5.677s 5.673s 22.703s 3 WARM 5.672s 5.676s 5.679s 5.676s 22.702s 4 WARM 5.671s 5.673s 5.673s 5.677s 22.695s Cold first call (Dynamo+NEFF compilation): 264.514s Warm avg/step: 5.675s | 0.176 steps/s Comparison — FLUX.2-klein-9B warm throughput: eager: 1.284s/step (0.779 steps/s) ← 4.4× faster compile: 5.675s/step (0.176 steps/s) Note: compile mode is slower because torch.compile/Dynamo uses the NKI flash attention decomposition (training=True path) and does not benefit from the XLA-level fusions that the lazy-XLA path applies automatically. """ import argparse import gc import logging import os import sys import time import torch import torch.distributed as dist from torch.distributed.device_mesh import DeviceMesh from diffusers import Flux2Transformer2DModel, FlowMatchEulerDiscreteScheduler from diffusers.pipelines.flux2.pipeline_flux2_klein import ( Flux2KleinPipeline, compute_empirical_mu, ) # Import loading/TP helpers from pipeline.py in the same directory sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from pipeline import ( # noqa: E402 apply_tp_flux2_transformer, apply_tp_text_encoder, _encode_prompt_tp, load_text_encoder, load_transformer, _snapshot, ) import torch_neuronx # noqa: F401, E402 — registers neuron backend from torch_neuronx.neuron_dynamo_backend import set_model_name logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s") logger = logging.getLogger(__name__) DEFAULT_MODEL_ID = "black-forest-labs/FLUX.2-klein-4B" # --------------------------------------------------------------------------- # Latent / position ID preparation # --------------------------------------------------------------------------- def _prepare_inputs(transformer, height, width, batch_size, text_seq_len, device, seed): """ Compute initial latents, latent position IDs, and text position IDs. Returns (latents_dev, latent_ids_dev, text_ids_dev, latents_cpu) where latents_cpu is kept to reset latents to their original values each run. """ generator = torch.Generator().manual_seed(seed) vae_scale = 8 lh = 2 * (height // (vae_scale * 2)) lw = 2 * (width // (vae_scale * 2)) seq_len = (lh // 2) * (lw // 2) latents_cpu = torch.randn( batch_size, seq_len, transformer.config.in_channels, dtype=torch.bfloat16, generator=generator, ) latent_ids_cpu = ( torch.cartesian_prod( torch.arange(1), torch.arange(lh // 2), torch.arange(lw // 2), torch.arange(1), ) .unsqueeze(0).expand(batch_size, -1, -1).contiguous().float() ) text_ids_cpu = ( torch.cartesian_prod( torch.arange(1), torch.arange(1), torch.arange(1), torch.arange(text_seq_len), ) .unsqueeze(0).expand(batch_size, -1, -1).contiguous().float() ) return ( latents_cpu.to(device), latent_ids_cpu.to(device), text_ids_cpu.to(device), latents_cpu, # kept on CPU for resetting between runs ) # --------------------------------------------------------------------------- # Single denoising run # --------------------------------------------------------------------------- def _run_one( run_idx, num_runs, transformer, scheduler, prompt_embeds, latents_init_cpu, latent_ids_dev, text_ids_dev, ts_tensor, num_steps, batch_size, device, rank, ): """ Execute one complete denoising loop and return per-step wall-clock times. Latents are always reset to `latents_init_cpu` at the start so every run is independent. Scheduler step is on rank 0 CPU; updated latents are broadcast to all ranks. Returns: step_times: list[float] — elapsed seconds for each transformer forward. """ is_cold = (run_idx == 0) label = f"Run {run_idx + 1}/{num_runs} ({'COLD' if is_cold else 'WARM':4s})" latents_dev = latents_init_cpu.to(device) # Reset scheduler's internal step counter (avoids IndexError on run 2+) if rank == 0: scheduler._step_index = None step_times = [] if rank == 0: logger.info(f" --- {label} ---") dist.barrier() t_run = time.time() with torch.no_grad(): for step_idx in range(num_steps): t_val = ts_tensor[step_idx] timestep = t_val.expand(batch_size).to(torch.bfloat16).to(device) / 1000.0 dist.barrier() t0 = time.time() noise_pred = transformer( hidden_states=latents_dev, encoder_hidden_states=prompt_embeds, timestep=timestep, img_ids=latent_ids_dev, txt_ids=text_ids_dev, guidance=None, return_dict=False, )[0] dist.barrier() elapsed = time.time() - t0 step_times.append(elapsed) if rank == 0: logger.info( f" step {step_idx + 1:2d}/{num_steps}" f" t={t_val.item():7.1f}" f" elapsed={elapsed:.3f}s" ) if rank == 0: lat_new = scheduler.step( noise_pred.to("cpu"), t_val.cpu(), latents_dev.to("cpu"), return_dict=False, )[0] latents_dev.copy_(lat_new.to(device)) dist.broadcast(latents_dev, src=0) if rank == 0: total = time.time() - t_run logger.info(f" run {run_idx + 1} total: {total:.3f}s") return step_times # --------------------------------------------------------------------------- # Summary reporting # --------------------------------------------------------------------------- def _print_summary(mode, model_id, height, width, num_steps, num_runs, all_step_times): """Print a formatted latency table and key metrics to the log.""" SEP = "=" * 72 HSEP = "-" * 72 cold_label = ( "Dynamo trace + NEFF compilation" if mode == "compile" else "XLA compilation" ) logger.info(SEP) logger.info(f"BENCHMARK RESULTS | {model_id} | mode={mode}") logger.info(f" {height}x{width} · {num_steps} steps/run · {num_runs} runs") logger.info(HSEP) step_hdrs = " ".join(f"step{i + 1:02d}" for i in range(num_steps)) logger.info(f"{'Run':<5} {'Type':<5} {step_hdrs} total") logger.info(HSEP) for run_idx, times in enumerate(all_step_times): rtype = "COLD" if run_idx == 0 else "WARM" cells = " ".join(f"{t:6.3f}s" for t in times) logger.info(f"{run_idx + 1:<5} {rtype:<5} {cells} {sum(times):.3f}s") logger.info(HSEP) cold_step1 = all_step_times[0][0] logger.info(f" Cold first call (incl. {cold_label}): {cold_step1:.3f}s") if num_steps > 1: cold_rest = all_step_times[0][1:] avg_cold_rest = sum(cold_rest) / len(cold_rest) logger.info( f" Cold run steps 2-{num_steps} avg : {avg_cold_rest:.3f}s/step" ) if num_runs > 1: warm_times = [t for times in all_step_times[1:] for t in times] avg_warm = sum(warm_times) / len(warm_times) warm_step1_times = [times[0] for times in all_step_times[1:]] avg_warm_step1 = sum(warm_step1_times) / len(warm_step1_times) logger.info( f" Warm runs — first step avg : {avg_warm_step1:.3f}s/step" ) logger.info( f" Warm runs — all steps avg : {avg_warm:.3f}s/step" ) logger.info( f" Throughput (warm, all steps) : {1.0 / avg_warm:.3f} steps/s" ) logger.info( f" Speedup vs cold first call : {cold_step1 / avg_warm:.1f}x" ) logger.info(SEP) # --------------------------------------------------------------------------- # Main benchmark entry point # --------------------------------------------------------------------------- def benchmark( mode, model_id, prompt, height, width, num_steps, batch_size, num_runs, random_weights, seed, fuse_qkv=False, flash_attn=False, ): assert mode in ("eager", "compile"), f"--mode must be 'eager' or 'compile', got {mode!r}" dist.init_process_group(backend="neuron") world_size = dist.get_world_size() rank = dist.get_rank() device = torch.neuron.current_device() tp_mesh = DeviceMesh("neuron", list(range(world_size))) if rank == 0: logger.info(f"{'=' * 72}") logger.info(f"Flux2-klein benchmark | {model_id} | mode={mode}") logger.info( f" {height}x{width} · {num_steps} steps · {num_runs} runs " f"· batch={batch_size} · random_weights={random_weights}" ) logger.info(f"{'=' * 72}") xfmr_cfg = Flux2Transformer2DModel.load_config(model_id, subfolder="transformer") joint_attention_dim = xfmr_cfg["joint_attention_dim"] text_seq_len = 512 # ------------------------------------------------------------------ # 1. Text encoder: all ranks load & TP-encode, then free # ------------------------------------------------------------------ if not random_weights: t0 = time.time() text_encoder, tokenizer = load_text_encoder(model_id, random_weights=False) logger.info( f"Rank {rank}: text encoder loaded in {time.time() - t0:.1f}s " f"({sum(p.numel() for p in text_encoder.parameters()) / 1e9:.2f}B params)" ) text_encoder = apply_tp_text_encoder(text_encoder, tp_mesh) text_encoder = text_encoder.to(device) text_encoder.eval() if mode == "compile": set_model_name(f"qwen3_text_encoder_rank{rank}") # Pre-install output-capturing hooks so _output_capturing_hooks_installed=True; # the maybe_install_capturing_hooks early-return fires before the threading.Lock # that torch.compile(fullgraph=True) cannot trace. See pipeline.py for full note. from transformers.utils.output_capturing import install_all_output_capturing_hooks install_all_output_capturing_hooks(text_encoder) text_encoder = torch.compile(text_encoder, backend="neuron", fullgraph=True) logger.info(f"Rank {rank}: text encoder compiled") gc.collect() prompt_embeds = _encode_prompt_tp( text_encoder, tokenizer, prompt, batch_size, device) if rank == 0: logger.info(f"Prompt encoded shape={prompt_embeds.shape}") del text_encoder, tokenizer gc.collect() else: prompt_embeds = torch.zeros( batch_size, text_seq_len, joint_attention_dim, dtype=torch.bfloat16, device=device, ) if rank == 0: prompt_embeds.copy_( torch.randn(batch_size, text_seq_len, joint_attention_dim, dtype=torch.bfloat16).to(device)) dist.broadcast(prompt_embeds, src=0) # ------------------------------------------------------------------ # 2. Transformer: all ranks load, TP, move to Neuron [+ compile] # ------------------------------------------------------------------ t0 = time.time() transformer = load_transformer(model_id, random_weights) logger.info( f"Rank {rank}: transformer loaded in {time.time() - t0:.1f}s " f"({sum(p.numel() for p in transformer.parameters()) / 1e9:.2f}B params)" ) transformer = apply_tp_flux2_transformer(transformer, tp_mesh, fuse_qkv=fuse_qkv, flash_attn=flash_attn) transformer = transformer.to(device) transformer.eval() if mode == "compile": set_model_name(f"flux2_transformer_rank{rank}") transformer = torch.compile(transformer, backend="neuron", fullgraph=True) logger.info(f"Rank {rank}: transformer compiled (NEFF will build on first call)") gc.collect() # ------------------------------------------------------------------ # 3. Scheduler timesteps (computed once, reused for all runs) # ------------------------------------------------------------------ vae_scale = 8 lh = 2 * (height // (vae_scale * 2)) lw = 2 * (width // (vae_scale * 2)) image_seq_len = (lh // 2) * (lw // 2) if rank == 0: mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_steps) scheduler = FlowMatchEulerDiscreteScheduler() scheduler.set_timesteps(num_steps, mu=mu) ts_float = scheduler.timesteps.float() logger.info(f"Timesteps: {scheduler.timesteps.tolist()}") else: scheduler = FlowMatchEulerDiscreteScheduler() ts_float = torch.zeros(num_steps, dtype=torch.float32) ts_dev = ts_float.to(device) dist.broadcast(ts_dev, src=0) # ------------------------------------------------------------------ # 4. Initial latents and position IDs # ------------------------------------------------------------------ latents_dev, latent_ids_dev, text_ids_dev, latents_init_cpu = _prepare_inputs( transformer, height, width, batch_size, text_seq_len, device, seed, ) # ------------------------------------------------------------------ # 5. Benchmark loop # Run 1 (COLD): triggers compilation (XLA or Dynamo+NEFF) # Runs 2+ (WARM): reuse compiled graph # ------------------------------------------------------------------ dist.barrier() if rank == 0: compile_note = " (run 1 triggers Dynamo+NEFF compile)" if mode == "compile" else "" logger.info( f"Starting {num_runs} benchmark runs ({num_steps} steps each){compile_note} ..." ) all_step_times = [] for run_idx in range(num_runs): step_times = _run_one( run_idx, num_runs, transformer, scheduler, prompt_embeds, latents_init_cpu, latent_ids_dev, text_ids_dev, ts_dev, num_steps, batch_size, device, rank, ) all_step_times.append(step_times) # ------------------------------------------------------------------ # 6. Summary (rank 0 only) # ------------------------------------------------------------------ if rank == 0: _print_summary(mode, model_id, height, width, num_steps, num_runs, all_step_times) dist.barrier() dist.destroy_process_group() # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def parse_args(): p = argparse.ArgumentParser( description="Flux2-klein latency benchmark (4B / 9B) on Neuron", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) p.add_argument("--mode", choices=["eager", "compile"], default="eager", help="eager: lazy-XLA path. compile: torch.compile Dynamo path.") p.add_argument("--model-id", default=DEFAULT_MODEL_ID, help="HuggingFace model ID (4B or 9B variant)") p.add_argument("--prompt", default="a cat sitting on a Neuron chip, photorealistic") p.add_argument("--height", type=int, default=512) p.add_argument("--width", type=int, default=512) p.add_argument("--num-steps", type=int, default=4, help="Denoising steps per run") p.add_argument("--num-runs", type=int, default=4, help="Total runs: run 1=COLD (compilation), runs 2+=WARM (benchmarked)") p.add_argument("--batch-size", type=int, default=1) p.add_argument("--seed", type=int, default=42) p.add_argument("--random-weights", action="store_true", default=True) p.add_argument("--no-random-weights", action="store_false", dest="random_weights") p.add_argument("--fused-qkv", action="store_true", default=False, help="Use NKI fused QKV kernel for double-stream blocks.") p.add_argument("--flash-attn", action="store_true", default=False, help="Use NKI flash attention kernel for all blocks.") p.add_argument( "--cache-dir", default=None, help=( "Persistent NEFF cache directory (sets TORCH_NEURONX_NEFF_CACHE_DIR). " "Applies to both eager and compile modes. " "NEFFs saved on first run, reloaded on subsequent runs. " "Example: --cache-dir /home/ubuntu/neff_cache" ), ) return p.parse_args() if __name__ == "__main__": args = parse_args() # Always set the NEFF cache dir regardless of mode — both eager (lazy-XLA) # and compile (Dynamo) paths use TORCH_NEURONX_NEFF_CACHE_DIR to persist # compiled NEFFs across runs. Default /tmp/neff_cache is lost on reboot. cache_dir = args.cache_dir or os.environ.get("TORCH_NEURONX_NEFF_CACHE_DIR", "/tmp/neff_cache") os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir os.makedirs(cache_dir, exist_ok=True) logger.info(f"NEFF cache dir: {cache_dir}") benchmark( mode=args.mode, model_id=args.model_id, prompt=args.prompt, height=args.height, width=args.width, num_steps=args.num_steps, batch_size=args.batch_size, num_runs=args.num_runs, random_weights=args.random_weights, seed=args.seed, fuse_qkv=args.fused_qkv, flash_attn=args.flash_attn, )