#!/usr/bin/env python3 """ Full end-to-end Flux2 klein pipeline on AWS Neuron. Supports two execution modes selected with ``--mode``: eager (default) — lazy-XLA path. Models are moved to the Neuron device with ``.to(device)``; the XLA compiler traces and emits a NEFF on the first forward call. No explicit compile step needed; graph breaks are handled transparently. compile — torch.compile path (Dynamo backend="neuron"). TorchDynamo captures the full FX graph and compiles a NEFF on the first forward call. ``fullgraph=True`` disallows graph breaks. All three models (text encoder, transformer, VAE) are compiled. NEFF caching (compile mode only): Compiled NEFFs are stored in two caches: 1. In-process C++ cache (NeuronResourceManager) — keyed by StableHLO hash. Within a single torchrun the NEFF is compiled once (cold step 1) and all subsequent steps reuse it (warm). 2. Persistent file cache (TORCH_NEURONX_NEFF_CACHE_DIR, default /tmp/neff_cache). Pass ``--cache-dir /persistent/path`` to keep NEFFs across reboots / torchrun invocations so the cold step is skipped entirely on re-runs. Components and placement: Text encoder (Qwen3ForCausalLM) — Neuron TP [+ torch.compile] Scheduler (FlowMatchEulerDiscreteScheduler) — CPU Transformer (Flux2Transformer2DModel) — Neuron TP [+ torch.compile] VAE decoder (AutoencoderKLFlux2) — Neuron rank-0 [+ torch.compile] Usage: # eager mode (default) torchrun --nproc_per_node=4 flux2-klein/pipeline.py --random-weights --num-steps 4 # compile mode — first run compiles and caches NEFFs torchrun --nproc_per_node=4 flux2-klein/pipeline.py --mode compile \\ --cache-dir /home/ubuntu/neff_cache --random-weights --num-steps 4 # compile mode — second run loads cached NEFFs, no recompilation torchrun --nproc_per_node=4 flux2-klein/pipeline.py --mode compile \\ --cache-dir /home/ubuntu/neff_cache --no-random-weights \\ --model-id black-forest-labs/FLUX.2-klein-9B \\ --prompt "a cat sitting on a Neuron chip" --height 512 --width 512 """ import argparse import gc import logging import os import time import torch import torch.distributed as dist from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.parallel import ( ColwiseParallel, RowwiseParallel, parallelize_module, ) from transformers import Qwen3ForCausalLM, Qwen2TokenizerFast from diffusers import AutoencoderKLFlux2, Flux2Transformer2DModel, FlowMatchEulerDiscreteScheduler from diffusers.models.embeddings import apply_rotary_emb from diffusers.models.attention_dispatch import dispatch_attention_fn from diffusers.models.transformers.transformer_flux2 import ( Flux2AttnProcessor, Flux2ParallelSelfAttnProcessor, _get_qkv_projections, ) from diffusers.pipelines.flux2.pipeline_flux2_klein import ( Flux2KleinPipeline, compute_empirical_mu, ) import torch_neuronx # noqa: F401 — registers neuron backend from torch_neuronx.neuron_dynamo_backend import set_model_name logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s") logger = logging.getLogger(__name__) DEFAULT_MODEL_ID = "black-forest-labs/FLUX.2-klein-4B" # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _to_local(t: torch.Tensor) -> torch.Tensor: """Extract local shard from a DTensor, or return tensor unchanged.""" return t.to_local() if hasattr(t, "to_local") else t def _build_fused_qkv_weight(q_w, k_w, v_w) -> torch.Tensor: """ Concatenate 3 ColwisePar local shards into one fused QKV weight [H, 3*(H//tp)]. Each shard has shape [H//tp, H]; transposed → [H, H//tp]; cat → [H, 3*(H//tp)]. """ return torch.cat([_to_local(q_w).T, _to_local(k_w).T, _to_local(v_w).T], dim=1).contiguous() # --------------------------------------------------------------------------- # TP-aware attention processors # --------------------------------------------------------------------------- class Flux2AttnProcessorFlashAttn(Flux2AttnProcessor): """ TP-aware double-stream attention processor using a NKI flash attention kernel. Replaces dispatch_attention_fn with flux2_flash_attn (BSHD in, BSHD out). """ def __init__(self, tp_size: int): super().__init__() self.tp_size = tp_size self._flash_attn = None def _init_nki(self): if self._flash_attn is None: import sys, os _dir = os.path.dirname(os.path.abspath(__file__)) if _dir not in sys.path: sys.path.insert(0, _dir) from nki_flash_attn import flux2_flash_attn self._flash_attn = flux2_flash_attn def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None, **kwargs): self._init_nki() local_heads = attn.heads // self.tp_size head_dim = attn.head_dim query, key, value, encoder_query, encoder_key, encoder_value = \ _get_qkv_projections(attn, hidden_states, encoder_hidden_states) query = query.unflatten(-1, (local_heads, head_dim)) key = key.unflatten(-1, (local_heads, head_dim)) value = value.unflatten(-1, (local_heads, head_dim)) query = attn.norm_q(query) key = attn.norm_k(key) if attn.added_kv_proj_dim is not None: encoder_query = encoder_query.unflatten(-1, (local_heads, head_dim)) encoder_key = encoder_key.unflatten(-1, (local_heads, head_dim)) encoder_value = encoder_value.unflatten(-1, (local_heads, head_dim)) encoder_query = attn.norm_added_q(encoder_query) encoder_key = attn.norm_added_k(encoder_key) query = torch.cat([encoder_query, query], dim=1) key = torch.cat([encoder_key, key], dim=1) value = torch.cat([encoder_value, value], dim=1) if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) hidden_states = self._flash_attn(query, key, value) # (B, S, H_local, D) BSHD hidden_states = hidden_states.flatten(2, 3).to(query.dtype) if encoder_hidden_states is not None: encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1) encoder_hidden_states = attn.to_add_out(encoder_hidden_states) hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) if encoder_hidden_states is not None: return hidden_states, encoder_hidden_states return hidden_states class Flux2AttnProcessorTP(Flux2AttnProcessor): def __init__(self, tp_size: int): super().__init__() self.tp_size = tp_size def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None, **kwargs): local_heads = attn.heads // self.tp_size head_dim = attn.head_dim query, key, value, encoder_query, encoder_key, encoder_value = \ _get_qkv_projections(attn, hidden_states, encoder_hidden_states) query = query.unflatten(-1, (local_heads, head_dim)) key = key.unflatten(-1, (local_heads, head_dim)) value = value.unflatten(-1, (local_heads, head_dim)) query = attn.norm_q(query) key = attn.norm_k(key) if attn.added_kv_proj_dim is not None: encoder_query = encoder_query.unflatten(-1, (local_heads, head_dim)) encoder_key = encoder_key.unflatten(-1, (local_heads, head_dim)) encoder_value = encoder_value.unflatten(-1, (local_heads, head_dim)) encoder_query = attn.norm_added_q(encoder_query) encoder_key = attn.norm_added_k(encoder_key) query = torch.cat([encoder_query, query], dim=1) key = torch.cat([encoder_key, key], dim=1) value = torch.cat([encoder_value, value], dim=1) if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) hidden_states = dispatch_attention_fn( query, key, value, attn_mask=attention_mask, backend=self._attention_backend, parallel_config=self._parallel_config, ) hidden_states = hidden_states.flatten(2, 3).to(query.dtype) if encoder_hidden_states is not None: encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 ) encoder_hidden_states = attn.to_add_out(encoder_hidden_states) hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) if encoder_hidden_states is not None: return hidden_states, encoder_hidden_states return hidden_states class Flux2AttnProcessorFusedQKV(Flux2AttnProcessor): """ TP-aware double-stream attention processor using a single NKI fused QKV kernel. Replaces 6 separate ColwisePar matmuls (to_q/to_k/to_v for image, add_q_proj/add_k_proj/add_v_proj for text) with 2 NKI qkv kernel calls — one per stream — each computing Q+K+V in a single fused HBM pass. Fused weight [H, 3*(H//tp)] is built lazily on the first forward call by concatenating the TP-sharded local weight tensors, then cached as attn._fused_img_qkv_w and attn._fused_txt_qkv_w. Eager mode only: the lazy weight cache relies on plain tensor attributes that torch.compile(fullgraph=True) would not trace through. """ def __init__(self, tp_size: int): super().__init__() self.tp_size = tp_size self._nki_qkv = None def _init_nki(self): if self._nki_qkv is not None: return import sys _dir = os.path.dirname(os.path.abspath(__file__)) if _dir not in sys.path: sys.path.insert(0, _dir) from nki_qkv import fused_qkv_kernel self._nki_qkv = fused_qkv_kernel def _fused_proj(self, attn, q_attr, k_attr, v_attr, cache_attr, x): """Build (once) and apply the fused [H, 3*(H//tp)] QKV weight.""" fused_w = getattr(attn, cache_attr, None) if fused_w is None: fused_w = _build_fused_qkv_weight( getattr(attn, q_attr).weight, getattr(attn, k_attr).weight, getattr(attn, v_attr).weight, ).to(x.device) setattr(attn, cache_attr, fused_w) local_dim = attn.inner_dim // self.tp_size B, S, H = x.shape x_T = x.reshape(B * S, H).T.contiguous() # [H, B*S] — pre-transposed for kernel out = self._nki_qkv(x_T, fused_w) # [B*S, 3*local_dim] return out.reshape(B, S, -1).split(local_dim, dim=-1) # q, k, v each [B, S, local_dim] def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None, **kwargs): self._init_nki() local_heads = attn.heads // self.tp_size head_dim = attn.head_dim # Image stream: fused QKV query, key, value = self._fused_proj( attn, "to_q", "to_k", "to_v", "_fused_img_qkv_w", hidden_states) query = query.unflatten(-1, (local_heads, head_dim)) key = key.unflatten(-1, (local_heads, head_dim)) value = value.unflatten(-1, (local_heads, head_dim)) query = attn.norm_q(query) key = attn.norm_k(key) if attn.added_kv_proj_dim is not None: # Text stream: fused QKV encoder_query, encoder_key, encoder_value = self._fused_proj( attn, "add_q_proj", "add_k_proj", "add_v_proj", "_fused_txt_qkv_w", encoder_hidden_states) encoder_query = encoder_query.unflatten(-1, (local_heads, head_dim)) encoder_key = encoder_key.unflatten(-1, (local_heads, head_dim)) encoder_value = encoder_value.unflatten(-1, (local_heads, head_dim)) encoder_query = attn.norm_added_q(encoder_query) encoder_key = attn.norm_added_k(encoder_key) query = torch.cat([encoder_query, query], dim=1) key = torch.cat([encoder_key, key], dim=1) value = torch.cat([encoder_value, value], dim=1) if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) hidden_states = dispatch_attention_fn( query, key, value, attn_mask=attention_mask, backend=self._attention_backend, parallel_config=self._parallel_config, ) hidden_states = hidden_states.flatten(2, 3).to(query.dtype) if encoder_hidden_states is not None: encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1, ) encoder_hidden_states = attn.to_add_out(encoder_hidden_states) hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) if encoder_hidden_states is not None: return hidden_states, encoder_hidden_states return hidden_states class Flux2AttnProcessorFusedQKVFlashAttn(Flux2AttnProcessor): """ TP-aware double-stream processor: NKI fused QKV projection + NKI flash attention. Combines Flux2AttnProcessorFusedQKV (QKV side) with Flux2AttnProcessorFlashAttn (attention side). Activated by --fused-qkv --flash-attn together. """ def __init__(self, tp_size: int): super().__init__() self.tp_size = tp_size self._nki_qkv = None self._flash_attn = None def _init_nki(self): if self._nki_qkv is None: import sys _dir = os.path.dirname(os.path.abspath(__file__)) if _dir not in sys.path: sys.path.insert(0, _dir) from nki_qkv import fused_qkv_kernel self._nki_qkv = fused_qkv_kernel if self._flash_attn is None: import sys _dir = os.path.dirname(os.path.abspath(__file__)) if _dir not in sys.path: sys.path.insert(0, _dir) from nki_flash_attn import flux2_flash_attn self._flash_attn = flux2_flash_attn def _fused_proj(self, attn, q_attr, k_attr, v_attr, cache_attr, x): fused_w = getattr(attn, cache_attr, None) if fused_w is None: fused_w = _build_fused_qkv_weight( getattr(attn, q_attr).weight, getattr(attn, k_attr).weight, getattr(attn, v_attr).weight, ).to(x.device) setattr(attn, cache_attr, fused_w) local_dim = attn.inner_dim // self.tp_size B, S, H = x.shape x_T = x.reshape(B * S, H).T.contiguous() out = self._nki_qkv(x_T, fused_w) return out.reshape(B, S, -1).split(local_dim, dim=-1) def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None, **kwargs): self._init_nki() local_heads = attn.heads // self.tp_size head_dim = attn.head_dim query, key, value = self._fused_proj( attn, "to_q", "to_k", "to_v", "_fused_img_qkv_w", hidden_states) query = query.unflatten(-1, (local_heads, head_dim)) key = key.unflatten(-1, (local_heads, head_dim)) value = value.unflatten(-1, (local_heads, head_dim)) query = attn.norm_q(query) key = attn.norm_k(key) if attn.added_kv_proj_dim is not None: encoder_query, encoder_key, encoder_value = self._fused_proj( attn, "add_q_proj", "add_k_proj", "add_v_proj", "_fused_txt_qkv_w", encoder_hidden_states) encoder_query = encoder_query.unflatten(-1, (local_heads, head_dim)) encoder_key = encoder_key.unflatten(-1, (local_heads, head_dim)) encoder_value = encoder_value.unflatten(-1, (local_heads, head_dim)) encoder_query = attn.norm_added_q(encoder_query) encoder_key = attn.norm_added_k(encoder_key) query = torch.cat([encoder_query, query], dim=1) key = torch.cat([encoder_key, key], dim=1) value = torch.cat([encoder_value, value], dim=1) if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) hidden_states = self._flash_attn(query, key, value) hidden_states = hidden_states.flatten(2, 3).to(query.dtype) if encoder_hidden_states is not None: encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1) encoder_hidden_states = attn.to_add_out(encoder_hidden_states) hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) if encoder_hidden_states is not None: return hidden_states, encoder_hidden_states return hidden_states class Flux2ParallelSelfAttnProcessorTP(Flux2ParallelSelfAttnProcessor): def __init__(self, tp_size: int): super().__init__() self.tp_size = tp_size def __call__(self, attn, hidden_states, attention_mask=None, image_rotary_emb=None, **kwargs): local_heads = attn.heads // self.tp_size head_dim = attn.head_dim local_inner = attn.inner_dim // self.tp_size local_mlp_gate = attn.mlp_hidden_dim * attn.mlp_mult_factor // self.tp_size hidden_states = attn.to_qkv_mlp_proj(hidden_states) qkv, mlp_hidden_states = torch.split( hidden_states, [3 * local_inner, local_mlp_gate], dim=-1) query, key, value = qkv.chunk(3, dim=-1) query = query.unflatten(-1, (local_heads, head_dim)) key = key.unflatten(-1, (local_heads, head_dim)) value = value.unflatten(-1, (local_heads, head_dim)) query = attn.norm_q(query) key = attn.norm_k(key) if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) hidden_states = dispatch_attention_fn( query, key, value, attn_mask=attention_mask, backend=self._attention_backend, parallel_config=self._parallel_config, ) hidden_states = hidden_states.flatten(2, 3).to(query.dtype) mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states) hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1) return attn.to_out(hidden_states) class Flux2ParallelSelfAttnProcessorFlashAttn(Flux2ParallelSelfAttnProcessor): """ TP-aware single-stream attention processor using a NKI flash attention kernel. """ def __init__(self, tp_size: int): super().__init__() self.tp_size = tp_size self._flash_attn = None def _init_nki(self): if self._flash_attn is None: import sys, os _dir = os.path.dirname(os.path.abspath(__file__)) if _dir not in sys.path: sys.path.insert(0, _dir) from nki_flash_attn import flux2_flash_attn self._flash_attn = flux2_flash_attn def __call__(self, attn, hidden_states, attention_mask=None, image_rotary_emb=None, **kwargs): self._init_nki() local_heads = attn.heads // self.tp_size head_dim = attn.head_dim local_inner = attn.inner_dim // self.tp_size local_mlp_gate = attn.mlp_hidden_dim * attn.mlp_mult_factor // self.tp_size hidden_states = attn.to_qkv_mlp_proj(hidden_states) qkv, mlp_hidden_states = torch.split( hidden_states, [3 * local_inner, local_mlp_gate], dim=-1) query, key, value = qkv.chunk(3, dim=-1) query = query.unflatten(-1, (local_heads, head_dim)) key = key.unflatten(-1, (local_heads, head_dim)) value = value.unflatten(-1, (local_heads, head_dim)) query = attn.norm_q(query) key = attn.norm_k(key) if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) hidden_states = self._flash_attn(query, key, value) # (B, S, H_local, D) BSHD hidden_states = hidden_states.flatten(2, 3).to(query.dtype) mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states) hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1) return attn.to_out(hidden_states) # --------------------------------------------------------------------------- # Tensor parallelism # --------------------------------------------------------------------------- def _permute_swiglu_for_tp(weight: torch.Tensor, tp_size: int) -> torch.Tensor: """Permute SwiGLU linear_in rows so ColwisePar gives each rank [gate_i|linear_i].""" with torch.no_grad(): total = weight.shape[0] inner = total // 2 chunk = inner // tp_size gate = weight[:inner] linear = weight[inner:] parts = [] for i in range(tp_size): parts.append(gate[i * chunk : (i + 1) * chunk]) parts.append(linear[i * chunk : (i + 1) * chunk]) return torch.cat(parts, dim=0) def _permute_qkv_mlp_for_tp( weight: torch.Tensor, tp_size: int, inner_dim: int, mlp_hidden_dim: int, ) -> torch.Tensor: """Permute fused QKV+SwiGLU rows so ColwisePar gives rank i [q_i|k_i|v_i|gate_i|lin_i].""" with torch.no_grad(): q = weight[:inner_dim] k = weight[inner_dim : 2 * inner_dim] v = weight[2 * inner_dim : 3 * inner_dim] mlp_gate = weight[3 * inner_dim : 3 * inner_dim + mlp_hidden_dim] mlp_lin = weight[3 * inner_dim + mlp_hidden_dim :] qkv_chunk = inner_dim // tp_size mlp_chunk = mlp_hidden_dim // tp_size parts = [] for i in range(tp_size): parts += [ q[i * qkv_chunk : (i + 1) * qkv_chunk], k[i * qkv_chunk : (i + 1) * qkv_chunk], v[i * qkv_chunk : (i + 1) * qkv_chunk], mlp_gate[i * mlp_chunk : (i + 1) * mlp_chunk], mlp_lin [i * mlp_chunk : (i + 1) * mlp_chunk], ] return torch.cat(parts, dim=0) def _permute_out_for_tp( weight: torch.Tensor, tp_size: int, attn_dim: int, mlp_dim: int, ) -> torch.Tensor: """Permute to_out columns so RowwisePar rank i gets W[:, attn_i|mlp_i].""" with torch.no_grad(): attn_part = weight[:, :attn_dim] mlp_part = weight[:, attn_dim:] attn_chunk = attn_dim // tp_size mlp_chunk = mlp_dim // tp_size parts = [] for i in range(tp_size): parts.append(attn_part[:, i * attn_chunk : (i + 1) * attn_chunk]) parts.append(mlp_part[:, i * mlp_chunk : (i + 1) * mlp_chunk]) return torch.cat(parts, dim=1) def apply_tp_flux2_transformer( model: Flux2Transformer2DModel, tp_mesh: DeviceMesh, fuse_qkv: bool = False, flash_attn: bool = False, ) -> Flux2Transformer2DModel: tp_size = tp_mesh.size() double_plan = { "attn.to_q": ColwiseParallel(), "attn.to_k": ColwiseParallel(), "attn.to_v": ColwiseParallel(), "attn.to_out.0": RowwiseParallel(), "attn.add_q_proj": ColwiseParallel(), "attn.add_k_proj": ColwiseParallel(), "attn.add_v_proj": ColwiseParallel(), "attn.to_add_out": RowwiseParallel(), "ff.linear_in": ColwiseParallel(), "ff.linear_out": RowwiseParallel(), "ff_context.linear_in": ColwiseParallel(), "ff_context.linear_out": RowwiseParallel(), } for block in model.transformer_blocks: block.ff.linear_in.weight.data = _permute_swiglu_for_tp( block.ff.linear_in.weight.data, tp_size) block.ff_context.linear_in.weight.data = _permute_swiglu_for_tp( block.ff_context.linear_in.weight.data, tp_size) parallelize_module(block, tp_mesh, double_plan) if fuse_qkv and flash_attn: block.attn.set_processor(Flux2AttnProcessorFusedQKVFlashAttn(tp_size)) block.attn._fused_img_qkv_w = _build_fused_qkv_weight( block.attn.to_q.weight, block.attn.to_k.weight, block.attn.to_v.weight, ) block.attn._fused_txt_qkv_w = _build_fused_qkv_weight( block.attn.add_q_proj.weight, block.attn.add_k_proj.weight, block.attn.add_v_proj.weight, ) elif flash_attn: block.attn.set_processor(Flux2AttnProcessorFlashAttn(tp_size)) elif fuse_qkv: block.attn.set_processor(Flux2AttnProcessorFusedQKV(tp_size)) # Pre-build fused weights now (real DTensors present) so they are # plain local tensors by the time torch.compile traces the forward. # Lazy build inside _fused_proj fails under Dynamo fake-tensor mode # because _to_local() on a fake DTensor returns the full DTensor. block.attn._fused_img_qkv_w = _build_fused_qkv_weight( block.attn.to_q.weight, block.attn.to_k.weight, block.attn.to_v.weight, ) block.attn._fused_txt_qkv_w = _build_fused_qkv_weight( block.attn.add_q_proj.weight, block.attn.add_k_proj.weight, block.attn.add_v_proj.weight, ) else: block.attn.set_processor(Flux2AttnProcessorTP(tp_size)) single_plan = { "attn.to_qkv_mlp_proj": ColwiseParallel(), "attn.to_out": RowwiseParallel(), } for block in model.single_transformer_blocks: attn = block.attn inner_dim = attn.inner_dim mlp_hidden = attn.mlp_hidden_dim attn.to_qkv_mlp_proj.weight.data = _permute_qkv_mlp_for_tp( attn.to_qkv_mlp_proj.weight.data, tp_size, inner_dim, mlp_hidden) attn.to_out.weight.data = _permute_out_for_tp( attn.to_out.weight.data, tp_size, inner_dim, mlp_hidden) parallelize_module(block, tp_mesh, single_plan) if flash_attn: block.attn.set_processor(Flux2ParallelSelfAttnProcessorFlashAttn(tp_size)) else: block.attn.set_processor(Flux2ParallelSelfAttnProcessorTP(tp_size)) return model def apply_tp_text_encoder(model: Qwen3ForCausalLM, tp_mesh: DeviceMesh) -> Qwen3ForCausalLM: """Llama-style TP plan for Qwen3ForCausalLM; no weight permutations needed.""" layer_plan = { "self_attn.q_proj": ColwiseParallel(), "self_attn.k_proj": ColwiseParallel(), "self_attn.v_proj": ColwiseParallel(), "self_attn.o_proj": RowwiseParallel(), "mlp.gate_proj": ColwiseParallel(), "mlp.up_proj": ColwiseParallel(), "mlp.down_proj": RowwiseParallel(), } for layer in model.model.layers: parallelize_module(layer, tp_mesh, layer_plan) return model def _encode_prompt_tp( text_encoder: Qwen3ForCausalLM, tokenizer: Qwen2TokenizerFast, prompt: str, batch_size: int, device, ) -> torch.Tensor: """ Encode prompt with the TP Qwen3 text encoder on Neuron. All ranks tokenize independently (deterministic — avoids integer-dtype dist.broadcast which fails on Neuron). Returns [B, 512, 7680]. """ prompts = [prompt] * batch_size all_ids, all_masks = [], [] for p in prompts: text = tokenizer.apply_chat_template( [{"role": "user", "content": p}], tokenize=False, add_generation_prompt=True, enable_thinking=False, ) enc = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=512) all_ids.append(enc["input_ids"]) all_masks.append(enc["attention_mask"]) input_ids = torch.cat(all_ids, dim=0).to(device) attention_mask = torch.cat(all_masks, dim=0).to(device) with torch.no_grad(): output = text_encoder( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, use_cache=False, ) hidden = [] for k in (9, 18, 27): hs = output.hidden_states[k] if hasattr(hs, "to_local"): hs = hs.to_local() hidden.append(hs.to(torch.bfloat16)) out = torch.stack(hidden, dim=1) # [B, 3, 512, 2560] B, _, seq, hdim = out.shape return out.permute(0, 2, 1, 3).reshape(B, seq, 3 * hdim) # [B, 512, 7680] # --------------------------------------------------------------------------- # Model loading (meta-tensor-safe: from_config + safetensors) # --------------------------------------------------------------------------- def _snapshot(model_id: str) -> str: from huggingface_hub import snapshot_download return snapshot_download(model_id) def _load_safetensors(model: torch.nn.Module, subfolder_dir: str) -> None: """Stream safetensors one tensor at a time — avoids doubling peak CPU RAM.""" from safetensors import safe_open shards = sorted(f for f in os.listdir(subfolder_dir) if f.endswith(".safetensors")) if not shards: raise FileNotFoundError(f"No .safetensors in {subfolder_dir}") params = dict(model.named_parameters()) buffers = dict(model.named_buffers()) loaded: set = set() for fname in shards: with safe_open(os.path.join(subfolder_dir, fname), framework="pt", device="cpu") as f: for key in f.keys(): tensor = f.get_tensor(key) if key in params: params[key].data.copy_(tensor) loaded.add(key) elif key in buffers: buffers[key].data.copy_(tensor) loaded.add(key) del tensor missing = [k for k in params if k not in loaded] if missing: logger.warning(" missing keys (first 5): %s", missing[:5]) def load_text_encoder(model_id: str, random_weights: bool): if random_weights: return None, None snap = _snapshot(model_id) logger.info("Loading Qwen3 text encoder ...") model = Qwen3ForCausalLM.from_pretrained( os.path.join(snap, "text_encoder"), dtype=torch.bfloat16) model.eval() tokenizer = Qwen2TokenizerFast.from_pretrained(os.path.join(snap, "tokenizer")) return model, tokenizer def load_transformer(model_id: str, random_weights: bool) -> Flux2Transformer2DModel: config = Flux2Transformer2DModel.load_config(model_id, subfolder="transformer") _prev_dtype = torch.get_default_dtype() torch.set_default_dtype(torch.bfloat16) model = Flux2Transformer2DModel.from_config(config) torch.set_default_dtype(_prev_dtype) if not random_weights: logger.info("Loading transformer weights ...") snap = _snapshot(model_id) _load_safetensors(model, os.path.join(snap, "transformer")) return model def load_vae(model_id: str, random_weights: bool) -> AutoencoderKLFlux2: config = AutoencoderKLFlux2.load_config(model_id, subfolder="vae") _prev_dtype = torch.get_default_dtype() torch.set_default_dtype(torch.bfloat16) vae = AutoencoderKLFlux2.from_config(config) torch.set_default_dtype(_prev_dtype) if not random_weights: logger.info("Loading VAE weights ...") snap = _snapshot(model_id) _load_safetensors(vae, os.path.join(snap, "vae")) vae.eval() return vae # --------------------------------------------------------------------------- # Main run # --------------------------------------------------------------------------- def run(mode, model_id, prompt, height, width, num_steps, batch_size, random_weights, output_path, seed, fuse_qkv=False): assert mode in ("eager", "compile"), f"--mode must be 'eager' or 'compile', got {mode!r}" dist.init_process_group(backend="neuron") world_size = dist.get_world_size() rank = dist.get_rank() device = torch.neuron.current_device() logger.info(f"Rank {rank}/{world_size} on device {device} " f"mode={mode} ({height}x{width}, {num_steps} steps, random={random_weights})") tp_mesh = DeviceMesh("neuron", list(range(world_size))) joint_attention_dim = Flux2Transformer2DModel.load_config( model_id, subfolder="transformer")["joint_attention_dim"] text_seq_len = 512 # ------------------------------------------------------------------ # 1. Text encoder: all ranks load, TP, move to Neuron, encode, free. # # Memory ordering: text enc CPU copy freed before transformer loads # so the two large CPU allocations never overlap: # peak(text enc load) = tp × enc_size (e.g. 4×16 GB for 9B) # peak(xfmr load) = tp × xfmr_size (e.g. 4×17 GB for 9B) # ------------------------------------------------------------------ if not random_weights: t0 = time.time() text_encoder, tokenizer = load_text_encoder(model_id, random_weights) logger.info(f"Rank {rank}: text encoder loaded in {time.time()-t0:.1f}s " f"({sum(p.numel() for p in text_encoder.parameters())/1e9:.2f}B params)") text_encoder = apply_tp_text_encoder(text_encoder, tp_mesh) text_encoder = text_encoder.to(device) text_encoder.eval() if mode == "compile": set_model_name(f"qwen3_text_encoder_rank{rank}") # KNOWN ISSUE (transformers >= 4.52 / diffusers 0.37.0.dev, 2026-03-19): # torch.compile(fullgraph=True) fails with: # torch._dynamo.exc.Unsupported: Unsupported context manager # Root cause: transformers.utils.output_capturing.maybe_install_capturing_hooks() # uses a threading.Lock (_hook_installation_lock) that Dynamo cannot trace. # The function has an early-return guard: # if getattr(model, "_output_capturing_hooks_installed", False): return # Fix: pre-install the hooks so the guard fires before the lock is reached. from transformers.utils.output_capturing import install_all_output_capturing_hooks install_all_output_capturing_hooks(text_encoder) text_encoder = torch.compile(text_encoder, backend="neuron", fullgraph=True) logger.info(f"Rank {rank}: text encoder compiled (NEFF will build on first call)") gc.collect() t0 = time.time() prompt_embeds_dev = _encode_prompt_tp( text_encoder, tokenizer, prompt, batch_size, device) if rank == 0: logger.info(f"Prompt encoded in {time.time()-t0:.1f}s " f"shape={prompt_embeds_dev.shape}") del text_encoder, tokenizer gc.collect() else: # --random-weights: rank 0 generates random embeds, broadcasts to all prompt_embeds_dev = torch.zeros( batch_size, text_seq_len, joint_attention_dim, dtype=torch.bfloat16, device=device) if rank == 0: prompt_embeds_dev.copy_( torch.randn(batch_size, text_seq_len, joint_attention_dim, dtype=torch.bfloat16).to(device)) dist.broadcast(prompt_embeds_dev, src=0) # ------------------------------------------------------------------ # 2. Transformer: all ranks load, TP, move to Neuron. # Text encoder CPU copy is already freed; peak RAM = 4 × xfmr_size. # ------------------------------------------------------------------ t0 = time.time() transformer = load_transformer(model_id, random_weights) logger.info(f"Rank {rank}: transformer loaded in {time.time()-t0:.1f}s " f"({sum(p.numel() for p in transformer.parameters())/1e9:.2f}B params)") transformer = apply_tp_flux2_transformer(transformer, tp_mesh, fuse_qkv=fuse_qkv) transformer = transformer.to(device) transformer.eval() if mode == "compile": set_model_name(f"flux2_transformer_rank{rank}") transformer = torch.compile(transformer, backend="neuron", fullgraph=True) logger.info(f"Rank {rank}: transformer compiled (NEFF will build on first call)") gc.collect() # ------------------------------------------------------------------ # 3. VAE / scheduler / pipeline helpers: rank-0 only. # # Pipeline is created BEFORE torch.compile(vae) because # Flux2KleinPipeline.__init__ calls len(vae.config.block_out_channels) # which fails on an OptimizedModule wrapper. # ------------------------------------------------------------------ pipe = None if rank == 0: snap = _snapshot(model_id) vae = load_vae(model_id, random_weights) vae = vae.to(device) scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( os.path.join(snap, "scheduler")) pipe = Flux2KleinPipeline( scheduler=scheduler, vae=vae, text_encoder=None, tokenizer=None, transformer=transformer, is_distilled=True, ) if mode == "compile": set_model_name("flux2_vae_rank0") vae = torch.compile(vae, backend="neuron", fullgraph=True) logger.info("Rank 0: VAE compiled (NEFF will build on first call)") # ------------------------------------------------------------------ # 4. Prepare latents and position IDs (all ranks, same seed — no broadcast) # ------------------------------------------------------------------ generator = torch.Generator().manual_seed(seed) num_latent_ch = transformer.config.in_channels // 4 _vae_tmp = vae if rank == 0 else None if _vae_tmp is None: vae_scale = 8 lh = 2 * (height // (vae_scale * 2)) lw = 2 * (width // (vae_scale * 2)) seq_len = (lh // 2) * (lw // 2) latents_cpu = torch.randn( batch_size, seq_len, transformer.config.in_channels, dtype=torch.bfloat16, generator=generator) ids = torch.cartesian_prod( torch.arange(1), torch.arange(lh // 2), torch.arange(lw // 2), torch.arange(1)) latent_ids_cpu = ids.unsqueeze(0).expand(batch_size, -1, -1).contiguous().float() else: latents_cpu, latent_ids_cpu_int = pipe.prepare_latents( batch_size=batch_size, num_latents_channels=num_latent_ch, height=height, width=width, dtype=torch.bfloat16, device="cpu", generator=generator, ) latent_ids_cpu = latent_ids_cpu_int.float() t_ids = torch.cartesian_prod( torch.arange(1), torch.arange(1), torch.arange(1), torch.arange(text_seq_len)) text_ids_cpu = t_ids.unsqueeze(0).expand(batch_size, -1, -1).contiguous().float() latents_dev = latents_cpu.to(device) latent_ids_dev = latent_ids_cpu.to(device) text_ids_dev = text_ids_cpu.to(device) # ------------------------------------------------------------------ # 5. Scheduler timesteps # ------------------------------------------------------------------ if rank == 0: image_seq_len = latents_dev.shape[1] mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_steps) scheduler.set_timesteps(num_steps, mu=mu) timesteps = scheduler.timesteps else: scheduler = FlowMatchEulerDiscreteScheduler() if rank == 0: ts_tensor = timesteps.float() else: ts_tensor = torch.zeros(num_steps, dtype=torch.float32) ts_dev = ts_tensor.to(device) dist.broadcast(ts_dev, src=0) ts_tensor = ts_dev # ------------------------------------------------------------------ # 6. Denoising loop (all ranks via TP transformer) # eager: step 1 triggers lazy-XLA compile # compile: step 1 triggers Dynamo trace + NEFF compilation # ------------------------------------------------------------------ dist.barrier() compile_note = " (step 1 includes torch.compile NEFF build)" if mode == "compile" else "" logger.info(f"Rank {rank}: starting {num_steps}-step denoising loop{compile_note} ...") t_total = time.time() with torch.no_grad(): for step_idx in range(num_steps): t_step = time.time() t_val = ts_tensor[step_idx] timestep = t_val.expand(batch_size).to(torch.bfloat16).to(device) / 1000.0 noise_pred = transformer( hidden_states=latents_dev, encoder_hidden_states=prompt_embeds_dev, timestep=timestep, img_ids=latent_ids_dev, txt_ids=text_ids_dev, guidance=None, return_dict=False, )[0] if rank == 0: np_cpu = noise_pred.to("cpu") lat_cpu = latents_dev.to("cpu") t_cpu = t_val.cpu() lat_new = scheduler.step(np_cpu, t_cpu, lat_cpu, return_dict=False)[0] latents_dev.copy_(lat_new.to(device)) dist.broadcast(latents_dev, src=0) if rank == 0: elapsed = time.time() - t_step np_f = noise_pred.float() logger.info(f" step {step_idx+1}/{num_steps} " f"t={t_val.item():.1f} " f"mean={np_f.mean().item():.4f} " f"std={np_f.std().item():.4f} " f"elapsed={elapsed:.3f}s") if rank == 0: logger.info(f"Denoising complete: {time.time()-t_total:.2f}s total") # ------------------------------------------------------------------ # 7. VAE decode (rank-0 only) # compile mode: vae is an OptimizedModule; access config/buffers via # vae._orig_mod to bypass the Dynamo wrapper. # ------------------------------------------------------------------ if rank == 0: vae_note = " (step 1 includes VAE NEFF build)" if mode == "compile" else "" logger.info(f"Decoding latents with VAE{vae_note} ...") latents_spatial = Flux2KleinPipeline._unpack_latents_with_ids( latents_dev, latent_ids_dev) vae_mod = vae._orig_mod if mode == "compile" else vae bn_mean = vae_mod.bn.running_mean.view(1, -1, 1, 1).to(latents_spatial.dtype) bn_std = torch.sqrt( vae_mod.bn.running_var.view(1, -1, 1, 1) + vae_mod.config.batch_norm_eps ).to(latents_spatial.dtype) latents_spatial = latents_spatial * bn_std + bn_mean latents_spatial = Flux2KleinPipeline._unpatchify_latents(latents_spatial) with torch.no_grad(): image_tensor = vae.decode(latents_spatial, return_dict=False)[0] image_np = image_tensor.clamp(-1.0, 1.0).add(1.0).div(2.0) image_np = (image_np * 255).byte().permute(0, 2, 3, 1).cpu().numpy() try: from PIL import Image for i, arr in enumerate(image_np): fname = output_path if batch_size == 1 \ else output_path.replace(".", f"_{i}.", 1) Image.fromarray(arr).save(fname) logger.info(f"Saved: {fname} (shape {arr.shape})") except ImportError: logger.warning("Pillow not installed — skipping image save.") logger.info(f"output range: [{image_tensor.min():.3f}, {image_tensor.max():.3f}]") logger.info("PASS: end-to-end pipeline completed") dist.barrier() dist.destroy_process_group() # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def parse_args(): p = argparse.ArgumentParser( description="Flux2-klein pipeline (Qwen3 + TP transformer + VAE) on Neuron" ) p.add_argument("--mode", choices=["eager", "compile"], default="eager", help="eager: lazy-XLA path (default). compile: torch.compile Dynamo path.") p.add_argument("--model-id", default=DEFAULT_MODEL_ID) p.add_argument("--prompt", default="a photograph of a cat sitting on a Neuron chip, photorealistic") p.add_argument("--height", type=int, default=512) p.add_argument("--width", type=int, default=512) p.add_argument("--num-steps", type=int, default=4) p.add_argument("--batch-size", type=int, default=1) p.add_argument("--random-weights", action="store_true", default=True) p.add_argument("--no-random-weights", action="store_false", dest="random_weights") p.add_argument("--output", default="flux2_output.png") p.add_argument("--seed", type=int, default=42) p.add_argument("--fused-qkv", action="store_true", default=False, help="Use NKI fused QKV kernel for double-stream blocks (eager mode only). " "Replaces 3 separate ColwisePar matmuls per stream with one nki_qkv call.") p.add_argument( "--cache-dir", default=None, help=( "Persistent NEFF cache directory (sets TORCH_NEURONX_NEFF_CACHE_DIR). " "Applies to both eager and compile modes. " "NEFFs are saved on the first run and reloaded on subsequent runs, " "skipping neuronx-cc recompilation. " "Defaults to /tmp/neff_cache (lost on reboot). " "Example: --cache-dir /home/ubuntu/neff_cache" ), ) return p.parse_args() if __name__ == "__main__": args = parse_args() # Always set the NEFF cache dir regardless of mode — both eager (lazy-XLA) # and compile (Dynamo) paths use TORCH_NEURONX_NEFF_CACHE_DIR to persist # compiled NEFFs across runs. Default /tmp/neff_cache is lost on reboot. cache_dir = args.cache_dir or os.environ.get("TORCH_NEURONX_NEFF_CACHE_DIR", "/tmp/neff_cache") os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir os.makedirs(cache_dir, exist_ok=True) logger.info(f"NEFF cache dir: {cache_dir}") run( mode=args.mode, model_id=args.model_id, prompt=args.prompt, height=args.height, width=args.width, num_steps=args.num_steps, batch_size=args.batch_size, random_weights=args.random_weights, output_path=args.output, seed=args.seed, fuse_qkv=args.fused_qkv, )