| |
| """ |
| Full end-to-end Flux2 klein pipeline on AWS Neuron. |
| |
| Supports two execution modes selected with ``--mode``: |
| |
| eager (default) — lazy-XLA path. Models are moved to the Neuron device |
| with ``.to(device)``; the XLA compiler traces and emits |
| a NEFF on the first forward call. No explicit compile |
| step needed; graph breaks are handled transparently. |
| |
| compile — torch.compile path (Dynamo backend="neuron"). TorchDynamo |
| captures the full FX graph and compiles a NEFF on the |
| first forward call. ``fullgraph=True`` disallows graph |
| breaks. All three models (text encoder, transformer, VAE) |
| are compiled. |
| |
| NEFF caching (compile mode only): |
| Compiled NEFFs are stored in two caches: |
| 1. In-process C++ cache (NeuronResourceManager) — keyed by StableHLO hash. |
| Within a single torchrun the NEFF is compiled once (cold step 1) and |
| all subsequent steps reuse it (warm). |
| 2. Persistent file cache (TORCH_NEURONX_NEFF_CACHE_DIR, default /tmp/neff_cache). |
| Pass ``--cache-dir /persistent/path`` to keep NEFFs across reboots / |
| torchrun invocations so the cold step is skipped entirely on re-runs. |
| |
| Components and placement: |
| Text encoder (Qwen3ForCausalLM) — Neuron TP [+ torch.compile] |
| Scheduler (FlowMatchEulerDiscreteScheduler) — CPU |
| Transformer (Flux2Transformer2DModel) — Neuron TP [+ torch.compile] |
| VAE decoder (AutoencoderKLFlux2) — Neuron rank-0 [+ torch.compile] |
| |
| Usage: |
| # eager mode (default) |
| torchrun --nproc_per_node=4 flux2-klein/pipeline.py --random-weights --num-steps 4 |
| |
| # compile mode — first run compiles and caches NEFFs |
| torchrun --nproc_per_node=4 flux2-klein/pipeline.py --mode compile \\ |
| --cache-dir /home/ubuntu/neff_cache --random-weights --num-steps 4 |
| |
| # compile mode — second run loads cached NEFFs, no recompilation |
| torchrun --nproc_per_node=4 flux2-klein/pipeline.py --mode compile \\ |
| --cache-dir /home/ubuntu/neff_cache --no-random-weights \\ |
| --model-id black-forest-labs/FLUX.2-klein-9B \\ |
| --prompt "a cat sitting on a Neuron chip" --height 512 --width 512 |
| """ |
|
|
| import argparse |
| import gc |
| import logging |
| import os |
| import time |
|
|
| import torch |
| import torch.distributed as dist |
| from torch.distributed.device_mesh import DeviceMesh |
| from torch.distributed.tensor.parallel import ( |
| ColwiseParallel, |
| RowwiseParallel, |
| parallelize_module, |
| ) |
| from transformers import Qwen3ForCausalLM, Qwen2TokenizerFast |
| from diffusers import AutoencoderKLFlux2, Flux2Transformer2DModel, FlowMatchEulerDiscreteScheduler |
| from diffusers.models.embeddings import apply_rotary_emb |
| from diffusers.models.attention_dispatch import dispatch_attention_fn |
| from diffusers.models.transformers.transformer_flux2 import ( |
| Flux2AttnProcessor, |
| Flux2ParallelSelfAttnProcessor, |
| _get_qkv_projections, |
| ) |
| from diffusers.pipelines.flux2.pipeline_flux2_klein import ( |
| Flux2KleinPipeline, |
| compute_empirical_mu, |
| ) |
|
|
| import torch_neuronx |
| from torch_neuronx.neuron_dynamo_backend import set_model_name |
|
|
# Process-wide logging: INFO level, every record prefixed with its timestamp.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
logger = logging.getLogger(__name__)


# Hugging Face repo id of the default Flux2-klein checkpoint (presumably the
# CLI default; the argument parser is not visible in this section).
DEFAULT_MODEL_ID = "black-forest-labs/FLUX.2-klein-4B"
|
|
|
|
| |
| |
| |
|
|
def _to_local(t: torch.Tensor) -> torch.Tensor:
    """Return this rank's local shard if *t* is a DTensor, else *t* itself.

    The check is duck-typed: any object exposing a ``to_local`` attribute is
    treated as distributed and unwrapped by calling it.
    """
    try:
        unwrap = t.to_local
    except AttributeError:
        return t
    return unwrap()
|
|
|
|
def _build_fused_qkv_weight(q_w, k_w, v_w) -> torch.Tensor:
    """
    Fuse three ColwiseParallel Q/K/V weight shards into one [H, 3*(H//tp)] matrix.

    Each argument is a DTensor (or plain tensor) whose local shard has shape
    [H//tp, H]. Every shard is transposed to [H, H//tp] and the three are placed
    side by side in Q, K, V order; the result is made contiguous.
    """
    columns = []
    for shard in (q_w, k_w, v_w):
        columns.append(_to_local(shard).T)
    fused = torch.cat(columns, dim=1)
    return fused.contiguous()
|
|
|
|
| |
| |
| |
|
|
class Flux2AttnProcessorFlashAttn(Flux2AttnProcessor):
    """
    Double-stream (text + image) attention processor for TP-sharded Flux2 blocks
    that computes attention with the NKI ``flux2_flash_attn`` kernel (BSHD in,
    BSHD out) instead of ``dispatch_attention_fn``.

    Head counts are divided by ``tp_size`` because each rank's ColwiseParallel
    Q/K/V projections only produce its own slice of heads.

    NOTE: ``attention_mask`` is accepted for signature compatibility but is not
    passed to the kernel.
    """

    def __init__(self, tp_size: int):
        super().__init__()
        self.tp_size = tp_size
        self._flash_attn = None  # kernel callable, imported on first __call__

    def _init_nki(self):
        """Import ``flux2_flash_attn`` from the directory containing this file, once."""
        if self._flash_attn is not None:
            return
        import os
        import sys
        kernel_dir = os.path.dirname(os.path.abspath(__file__))
        if kernel_dir not in sys.path:
            sys.path.insert(0, kernel_dir)
        from nki_flash_attn import flux2_flash_attn
        self._flash_attn = flux2_flash_attn

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, image_rotary_emb=None, **kwargs):
        self._init_nki()
        heads_shape = (attn.heads // self.tp_size, attn.head_dim)

        q, k, v, txt_q, txt_k, txt_v = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states)
        q = q.unflatten(-1, heads_shape)
        k = k.unflatten(-1, heads_shape)
        v = v.unflatten(-1, heads_shape)
        q = attn.norm_q(q)
        k = attn.norm_k(k)

        if attn.added_kv_proj_dim is not None:
            txt_q = txt_q.unflatten(-1, heads_shape)
            txt_k = txt_k.unflatten(-1, heads_shape)
            txt_v = txt_v.unflatten(-1, heads_shape)
            txt_q = attn.norm_added_q(txt_q)
            txt_k = attn.norm_added_k(txt_k)
            # Joint sequence layout: text tokens first, then image tokens.
            q = torch.cat([txt_q, q], dim=1)
            k = torch.cat([txt_k, k], dim=1)
            v = torch.cat([txt_v, v], dim=1)

        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
            k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)

        attn_out = self._flash_attn(q, k, v)
        attn_out = attn_out.flatten(2, 3).to(q.dtype)

        if encoder_hidden_states is None:
            img_out = attn_out
            txt_out = None
        else:
            n_txt = encoder_hidden_states.shape[1]
            txt_out, img_out = attn_out.split_with_sizes(
                [n_txt, attn_out.shape[1] - n_txt], dim=1)
            txt_out = attn.to_add_out(txt_out)

        img_out = attn.to_out[0](img_out)
        img_out = attn.to_out[1](img_out)
        if txt_out is None:
            return img_out
        return img_out, txt_out
|
|
|
|
class Flux2AttnProcessorTP(Flux2AttnProcessor):
    """
    Double-stream (text + image) attention processor for TP-sharded Flux2 blocks.

    Same flow as the stock diffusers processor, except head counts are divided by
    ``tp_size``: each rank's ColwiseParallel Q/K/V projections only emit its own
    slice of heads. Attention runs through ``dispatch_attention_fn``.
    """

    def __init__(self, tp_size: int):
        super().__init__()
        self.tp_size = tp_size

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, image_rotary_emb=None, **kwargs):
        heads_shape = (attn.heads // self.tp_size, attn.head_dim)

        q, k, v, txt_q, txt_k, txt_v = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states)
        q = q.unflatten(-1, heads_shape)
        k = k.unflatten(-1, heads_shape)
        v = v.unflatten(-1, heads_shape)
        q = attn.norm_q(q)
        k = attn.norm_k(k)

        if attn.added_kv_proj_dim is not None:
            txt_q = txt_q.unflatten(-1, heads_shape)
            txt_k = txt_k.unflatten(-1, heads_shape)
            txt_v = txt_v.unflatten(-1, heads_shape)
            txt_q = attn.norm_added_q(txt_q)
            txt_k = attn.norm_added_k(txt_k)
            # Joint sequence layout: text tokens first, then image tokens.
            q = torch.cat([txt_q, q], dim=1)
            k = torch.cat([txt_k, k], dim=1)
            v = torch.cat([txt_v, v], dim=1)

        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
            k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)

        attn_out = dispatch_attention_fn(
            q, k, v,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        attn_out = attn_out.flatten(2, 3).to(q.dtype)

        if encoder_hidden_states is None:
            img_out = attn_out
            txt_out = None
        else:
            n_txt = encoder_hidden_states.shape[1]
            txt_out, img_out = attn_out.split_with_sizes(
                [n_txt, attn_out.shape[1] - n_txt], dim=1)
            txt_out = attn.to_add_out(txt_out)

        img_out = attn.to_out[0](img_out)
        img_out = attn.to_out[1](img_out)
        if txt_out is None:
            return img_out
        return img_out, txt_out
|
|
|
|
class Flux2AttnProcessorFusedQKV(Flux2AttnProcessor):
    """
    Double-stream TP attention processor whose Q/K/V projections run as a single
    NKI fused-QKV kernel call per stream.

    The six ColwiseParallel matmuls (to_q/to_k/to_v for image tokens and
    add_q_proj/add_k_proj/add_v_proj for text tokens) become two kernel calls,
    one per stream, each producing Q, K and V together.

    The fused weight [H, 3*(H//tp)] concatenates the local TP weight shards. It is
    read from ``attn._fused_img_qkv_w`` / ``attn._fused_txt_qkv_w`` and, if absent,
    built on the first forward call and cached there.

    Eager mode only: the lazy weight cache relies on plain tensor attributes
    that torch.compile(fullgraph=True) would not trace through.
    """

    def __init__(self, tp_size: int):
        super().__init__()
        self.tp_size = tp_size
        self._nki_qkv = None  # fused_qkv_kernel, imported on first __call__

    def _init_nki(self):
        """Import ``fused_qkv_kernel`` from the directory containing this file, once."""
        if self._nki_qkv is None:
            import sys
            kernel_dir = os.path.dirname(os.path.abspath(__file__))
            if kernel_dir not in sys.path:
                sys.path.insert(0, kernel_dir)
            from nki_qkv import fused_qkv_kernel
            self._nki_qkv = fused_qkv_kernel

    def _fused_proj(self, attn, q_attr, k_attr, v_attr, cache_attr, x):
        """
        Project *x* ([B, S, H]) with the fused QKV weight cached on ``attn.<cache_attr>``.

        The weight is built from the three named linear layers on first use. The
        kernel output is reshaped to [B, S, -1] and split into chunks of
        ``inner_dim // tp`` along the last axis (Q, K, V for the expected
        3*(H//tp)-wide output).
        """
        fused_w = getattr(attn, cache_attr, None)
        if fused_w is None:
            q_lin = getattr(attn, q_attr)
            k_lin = getattr(attn, k_attr)
            v_lin = getattr(attn, v_attr)
            fused_w = _build_fused_qkv_weight(
                q_lin.weight, k_lin.weight, v_lin.weight).to(x.device)
            setattr(attn, cache_attr, fused_w)

        batch, seq, hidden = x.shape
        # The kernel is fed the activations transposed to [H, B*S].
        x_t = x.reshape(batch * seq, hidden).T.contiguous()
        fused_out = self._nki_qkv(x_t, fused_w)
        split_width = attn.inner_dim // self.tp_size
        return fused_out.reshape(batch, seq, -1).split(split_width, dim=-1)

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, image_rotary_emb=None, **kwargs):
        self._init_nki()
        heads_shape = (attn.heads // self.tp_size, attn.head_dim)

        # Image stream: one fused kernel call yields Q, K, V.
        q, k, v = self._fused_proj(
            attn, "to_q", "to_k", "to_v", "_fused_img_qkv_w", hidden_states)
        q = q.unflatten(-1, heads_shape)
        k = k.unflatten(-1, heads_shape)
        v = v.unflatten(-1, heads_shape)
        q = attn.norm_q(q)
        k = attn.norm_k(k)

        if attn.added_kv_proj_dim is not None:
            # Text stream: a second fused kernel call.
            txt_q, txt_k, txt_v = self._fused_proj(
                attn, "add_q_proj", "add_k_proj", "add_v_proj",
                "_fused_txt_qkv_w", encoder_hidden_states)
            txt_q = txt_q.unflatten(-1, heads_shape)
            txt_k = txt_k.unflatten(-1, heads_shape)
            txt_v = txt_v.unflatten(-1, heads_shape)
            txt_q = attn.norm_added_q(txt_q)
            txt_k = attn.norm_added_k(txt_k)
            # Joint sequence layout: text tokens first, then image tokens.
            q = torch.cat([txt_q, q], dim=1)
            k = torch.cat([txt_k, k], dim=1)
            v = torch.cat([txt_v, v], dim=1)

        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
            k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)

        attn_out = dispatch_attention_fn(
            q, k, v,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        attn_out = attn_out.flatten(2, 3).to(q.dtype)

        if encoder_hidden_states is None:
            img_out = attn_out
            txt_out = None
        else:
            n_txt = encoder_hidden_states.shape[1]
            txt_out, img_out = attn_out.split_with_sizes(
                [n_txt, attn_out.shape[1] - n_txt], dim=1,
            )
            txt_out = attn.to_add_out(txt_out)

        img_out = attn.to_out[0](img_out)
        img_out = attn.to_out[1](img_out)
        if txt_out is None:
            return img_out
        return img_out, txt_out
|
|
|
|
class Flux2AttnProcessorFusedQKVFlashAttn(Flux2AttnProcessor):
    """
    Double-stream TP attention processor that uses both NKI kernels: the fused
    QKV projection (one call per stream) and the flash attention kernel.

    Pairs the QKV path of ``Flux2AttnProcessorFusedQKV`` with the attention path
    of ``Flux2AttnProcessorFlashAttn``; selected by --fused-qkv --flash-attn.

    Fused weights are read from ``attn._fused_img_qkv_w`` /
    ``attn._fused_txt_qkv_w`` and built (then cached) on first use if absent.
    ``attention_mask`` is accepted but not passed to the flash kernel.
    """

    def __init__(self, tp_size: int):
        super().__init__()
        self.tp_size = tp_size
        self._nki_qkv = None     # fused_qkv_kernel, imported lazily
        self._flash_attn = None  # flux2_flash_attn, imported lazily

    def _init_nki(self):
        """Import whichever NKI kernels are still missing from this file's directory."""
        need_qkv = self._nki_qkv is None
        need_flash = self._flash_attn is None
        if not (need_qkv or need_flash):
            return
        import sys
        kernel_dir = os.path.dirname(os.path.abspath(__file__))
        if kernel_dir not in sys.path:
            sys.path.insert(0, kernel_dir)
        if need_qkv:
            from nki_qkv import fused_qkv_kernel
            self._nki_qkv = fused_qkv_kernel
        if need_flash:
            from nki_flash_attn import flux2_flash_attn
            self._flash_attn = flux2_flash_attn

    def _fused_proj(self, attn, q_attr, k_attr, v_attr, cache_attr, x):
        """
        Project *x* ([B, S, H]) with the fused QKV weight cached on ``attn.<cache_attr>``.

        The kernel output is reshaped to [B, S, -1] and split into chunks of
        ``inner_dim // tp`` along the last axis.
        """
        fused_w = getattr(attn, cache_attr, None)
        if fused_w is None:
            q_lin = getattr(attn, q_attr)
            k_lin = getattr(attn, k_attr)
            v_lin = getattr(attn, v_attr)
            fused_w = _build_fused_qkv_weight(
                q_lin.weight, k_lin.weight, v_lin.weight).to(x.device)
            setattr(attn, cache_attr, fused_w)
        split_width = attn.inner_dim // self.tp_size
        batch, seq, hidden = x.shape
        # The kernel is fed the activations transposed to [H, B*S].
        x_t = x.reshape(batch * seq, hidden).T.contiguous()
        fused_out = self._nki_qkv(x_t, fused_w)
        return fused_out.reshape(batch, seq, -1).split(split_width, dim=-1)

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, image_rotary_emb=None, **kwargs):
        self._init_nki()
        heads_shape = (attn.heads // self.tp_size, attn.head_dim)

        q, k, v = self._fused_proj(
            attn, "to_q", "to_k", "to_v", "_fused_img_qkv_w", hidden_states)
        q = q.unflatten(-1, heads_shape)
        k = k.unflatten(-1, heads_shape)
        v = v.unflatten(-1, heads_shape)
        q = attn.norm_q(q)
        k = attn.norm_k(k)

        if attn.added_kv_proj_dim is not None:
            txt_q, txt_k, txt_v = self._fused_proj(
                attn, "add_q_proj", "add_k_proj", "add_v_proj",
                "_fused_txt_qkv_w", encoder_hidden_states)
            txt_q = txt_q.unflatten(-1, heads_shape)
            txt_k = txt_k.unflatten(-1, heads_shape)
            txt_v = txt_v.unflatten(-1, heads_shape)
            txt_q = attn.norm_added_q(txt_q)
            txt_k = attn.norm_added_k(txt_k)
            # Joint sequence layout: text tokens first, then image tokens.
            q = torch.cat([txt_q, q], dim=1)
            k = torch.cat([txt_k, k], dim=1)
            v = torch.cat([txt_v, v], dim=1)

        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
            k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)

        attn_out = self._flash_attn(q, k, v)
        attn_out = attn_out.flatten(2, 3).to(q.dtype)

        if encoder_hidden_states is None:
            img_out = attn_out
            txt_out = None
        else:
            n_txt = encoder_hidden_states.shape[1]
            txt_out, img_out = attn_out.split_with_sizes(
                [n_txt, attn_out.shape[1] - n_txt], dim=1)
            txt_out = attn.to_add_out(txt_out)

        img_out = attn.to_out[0](img_out)
        img_out = attn.to_out[1](img_out)
        if txt_out is None:
            return img_out
        return img_out, txt_out
|
|
|
|
class Flux2ParallelSelfAttnProcessorTP(Flux2ParallelSelfAttnProcessor):
    """
    Single-stream (parallel attention + MLP) processor for TP-sharded Flux2 blocks.

    ``to_qkv_mlp_proj`` is ColwiseParallel, so each rank receives its local
    [q|k|v|mlp] slice: ``3 * inner_dim // tp`` QKV columns followed by
    ``mlp_hidden_dim * mlp_mult_factor // tp`` MLP columns. The attention output
    and activated MLP branch are concatenated and passed through the shared
    ``to_out`` projection.
    """

    def __init__(self, tp_size: int):
        super().__init__()
        self.tp_size = tp_size

    def __call__(self, attn, hidden_states, attention_mask=None,
                 image_rotary_emb=None, **kwargs):
        tp = self.tp_size
        heads_shape = (attn.heads // tp, attn.head_dim)
        qkv_width = 3 * (attn.inner_dim // tp)
        mlp_width = attn.mlp_hidden_dim * attn.mlp_mult_factor // tp

        projected = attn.to_qkv_mlp_proj(hidden_states)
        qkv, mlp_in = torch.split(projected, [qkv_width, mlp_width], dim=-1)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.unflatten(-1, heads_shape)
        k = k.unflatten(-1, heads_shape)
        v = v.unflatten(-1, heads_shape)
        q = attn.norm_q(q)
        k = attn.norm_k(k)

        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
            k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)

        attn_out = dispatch_attention_fn(
            q, k, v,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        attn_out = attn_out.flatten(2, 3).to(q.dtype)

        # Both branches share one (RowwiseParallel) output projection.
        return attn.to_out(torch.cat([attn_out, attn.mlp_act_fn(mlp_in)], dim=-1))
|
|
|
|
class Flux2ParallelSelfAttnProcessorFlashAttn(Flux2ParallelSelfAttnProcessor):
    """
    Single-stream TP processor that computes attention with the NKI
    ``flux2_flash_attn`` kernel instead of ``dispatch_attention_fn``.

    Projection slicing matches ``Flux2ParallelSelfAttnProcessorTP``.
    ``attention_mask`` is accepted but not passed to the kernel.
    """

    def __init__(self, tp_size: int):
        super().__init__()
        self.tp_size = tp_size
        self._flash_attn = None  # kernel callable, imported on first __call__

    def _init_nki(self):
        """Import ``flux2_flash_attn`` from the directory containing this file, once."""
        if self._flash_attn is not None:
            return
        import os
        import sys
        kernel_dir = os.path.dirname(os.path.abspath(__file__))
        if kernel_dir not in sys.path:
            sys.path.insert(0, kernel_dir)
        from nki_flash_attn import flux2_flash_attn
        self._flash_attn = flux2_flash_attn

    def __call__(self, attn, hidden_states, attention_mask=None,
                 image_rotary_emb=None, **kwargs):
        self._init_nki()
        tp = self.tp_size
        heads_shape = (attn.heads // tp, attn.head_dim)
        qkv_width = 3 * (attn.inner_dim // tp)
        mlp_width = attn.mlp_hidden_dim * attn.mlp_mult_factor // tp

        projected = attn.to_qkv_mlp_proj(hidden_states)
        qkv, mlp_in = torch.split(projected, [qkv_width, mlp_width], dim=-1)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.unflatten(-1, heads_shape)
        k = k.unflatten(-1, heads_shape)
        v = v.unflatten(-1, heads_shape)
        q = attn.norm_q(q)
        k = attn.norm_k(k)

        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
            k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)

        attn_out = self._flash_attn(q, k, v)
        attn_out = attn_out.flatten(2, 3).to(q.dtype)

        # Both branches share one (RowwiseParallel) output projection.
        return attn.to_out(torch.cat([attn_out, attn.mlp_act_fn(mlp_in)], dim=-1))
|
|
|
|
| |
| |
| |
|
|
def _permute_swiglu_for_tp(weight: torch.Tensor, tp_size: int) -> torch.Tensor:
    """
    Reorder the rows of a SwiGLU ``linear_in`` weight for ColwiseParallel sharding.

    The input stacks all gate rows followed by all linear rows ([gate | linear]).
    ColwiseParallel hands rank ``i`` the ``i``-th contiguous block of rows, so the
    rows are interleaved per rank as [gate_0 | lin_0 | gate_1 | lin_1 | ...];
    each rank then receives its matching [gate_i | linear_i] pair.

    Args:
        weight: weight whose first dimension holds ``2 * inner`` rows.
        tp_size: tensor-parallel degree.

    Returns:
        A new tensor with the same shape and permuted rows.

    Raises:
        ValueError: if ``tp_size`` < 1, or the row count is not a multiple of
            ``2 * tp_size`` (odd row count, or ``inner`` not divisible by
            ``tp_size``).
    """
    total = weight.shape[0]
    # Without this check, slicing in fixed-size chunks silently drops the
    # trailing rows and the returned weight has the wrong shape.
    if tp_size < 1 or total % (2 * tp_size) != 0:
        raise ValueError(
            f"SwiGLU weight with {total} rows cannot be split into gate/linear "
            f"halves evenly across tp_size={tp_size}")
    inner = total // 2
    chunk = inner // tp_size
    with torch.no_grad():
        gate = weight[:inner]
        linear = weight[inner:]
        parts = []
        for rank in range(tp_size):
            rows = slice(rank * chunk, (rank + 1) * chunk)
            parts.append(gate[rows])
            parts.append(linear[rows])
        return torch.cat(parts, dim=0)
|
|
|
|
def _permute_qkv_mlp_for_tp(
    weight: torch.Tensor,
    tp_size: int,
    inner_dim: int,
    mlp_hidden_dim: int,
) -> torch.Tensor:
    """
    Reorder the rows of a fused QKV + SwiGLU projection weight for ColwiseParallel.

    Input rows are laid out as [Q | K | V | gate | lin], with ``inner_dim`` rows
    for each of Q, K, V and ``mlp_hidden_dim`` rows for each of gate and lin. The
    output interleaves them per rank,
    [q_0|k_0|v_0|gate_0|lin_0 | q_1|k_1|v_1|gate_1|lin_1 | ...], so the contiguous
    row shard ColwiseParallel gives rank ``i`` is exactly [q_i|k_i|v_i|gate_i|lin_i].

    Args:
        weight: fused weight with ``3 * inner_dim + 2 * mlp_hidden_dim`` rows.
        tp_size: tensor-parallel degree.
        inner_dim: attention inner dimension (rows per Q/K/V section).
        mlp_hidden_dim: rows per gate/lin section.

    Returns:
        A new tensor with the same shape and permuted rows.

    Raises:
        ValueError: if ``tp_size`` < 1, if ``inner_dim`` or ``mlp_hidden_dim`` is
            not divisible by ``tp_size``, or if the row count does not match
            the expected layout.
    """
    # Validate up front: mismatched sizes previously led to silently dropped
    # or misassigned rows instead of an error.
    if tp_size < 1:
        raise ValueError(f"tp_size must be >= 1, got {tp_size}")
    if inner_dim % tp_size or mlp_hidden_dim % tp_size:
        raise ValueError(
            f"inner_dim={inner_dim} and mlp_hidden_dim={mlp_hidden_dim} must both "
            f"be divisible by tp_size={tp_size}")
    expected_rows = 3 * inner_dim + 2 * mlp_hidden_dim
    if weight.shape[0] != expected_rows:
        raise ValueError(
            f"expected {expected_rows} rows (3*{inner_dim} + 2*{mlp_hidden_dim}), "
            f"got {weight.shape[0]}")

    qkv_chunk = inner_dim // tp_size
    mlp_chunk = mlp_hidden_dim // tp_size
    with torch.no_grad():
        q, k, v, mlp_gate, mlp_lin = torch.split(
            weight, [inner_dim, inner_dim, inner_dim, mlp_hidden_dim, mlp_hidden_dim], dim=0)
        parts = []
        for rank in range(tp_size):
            qkv_rows = slice(rank * qkv_chunk, (rank + 1) * qkv_chunk)
            mlp_rows = slice(rank * mlp_chunk, (rank + 1) * mlp_chunk)
            parts += [
                q[qkv_rows],
                k[qkv_rows],
                v[qkv_rows],
                mlp_gate[mlp_rows],
                mlp_lin[mlp_rows],
            ]
        return torch.cat(parts, dim=0)
|
|
|
|
def _permute_out_for_tp(
    weight: torch.Tensor,
    tp_size: int,
    attn_dim: int,
    mlp_dim: int,
) -> torch.Tensor:
    """
    Reorder the columns of a single-stream ``to_out`` weight for RowwiseParallel.

    Input columns are [attn | mlp] (``attn_dim`` then ``mlp_dim`` columns). The
    output interleaves them per rank, [attn_0|mlp_0|attn_1|mlp_1|...], so the
    contiguous column shard RowwiseParallel gives rank ``i`` is W[:, attn_i|mlp_i].
    This matches each rank's locally concatenated [attention | MLP] activations.

    Args:
        weight: 2-D weight with ``attn_dim + mlp_dim`` columns.
        tp_size: tensor-parallel degree.
        attn_dim: number of attention-output columns.
        mlp_dim: number of MLP-output columns.

    Returns:
        A new tensor with the same shape and permuted columns.

    Raises:
        ValueError: if ``tp_size`` < 1, if either dimension is not divisible by
            ``tp_size``, or if the column count is not ``attn_dim + mlp_dim``.
    """
    # Validate up front: mismatched sizes previously led to silently dropped
    # or misassigned columns instead of an error.
    if tp_size < 1:
        raise ValueError(f"tp_size must be >= 1, got {tp_size}")
    if attn_dim % tp_size or mlp_dim % tp_size:
        raise ValueError(
            f"attn_dim={attn_dim} and mlp_dim={mlp_dim} must both be divisible "
            f"by tp_size={tp_size}")
    if weight.shape[1] != attn_dim + mlp_dim:
        raise ValueError(
            f"expected {attn_dim + mlp_dim} columns, got {weight.shape[1]}")

    attn_chunk = attn_dim // tp_size
    mlp_chunk = mlp_dim // tp_size
    with torch.no_grad():
        attn_part = weight[:, :attn_dim]
        mlp_part = weight[:, attn_dim:]
        parts = []
        for rank in range(tp_size):
            parts.append(attn_part[:, rank * attn_chunk : (rank + 1) * attn_chunk])
            parts.append(mlp_part[:, rank * mlp_chunk : (rank + 1) * mlp_chunk])
        return torch.cat(parts, dim=1)
|
|
|
|
def _attach_fused_qkv_weights(attn) -> None:
    """
    Pre-build the fused image and text QKV weights and cache them on *attn*.

    Sets ``attn._fused_img_qkv_w`` (from to_q/to_k/to_v) and
    ``attn._fused_txt_qkv_w`` (from add_q_proj/add_k_proj/add_v_proj), the
    attributes the fused-QKV processors read. Call after ``parallelize_module``
    so the fused weights are built from the TP-local shards.
    """
    attn._fused_img_qkv_w = _build_fused_qkv_weight(
        attn.to_q.weight,
        attn.to_k.weight,
        attn.to_v.weight,
    )
    attn._fused_txt_qkv_w = _build_fused_qkv_weight(
        attn.add_q_proj.weight,
        attn.add_k_proj.weight,
        attn.add_v_proj.weight,
    )


def apply_tp_flux2_transformer(
    model: Flux2Transformer2DModel,
    tp_mesh: DeviceMesh,
    fuse_qkv: bool = False,
    flash_attn: bool = False,
) -> Flux2Transformer2DModel:
    """
    Tensor-parallelize a Flux2 transformer across *tp_mesh* in place and return it.

    Double-stream blocks (``model.transformer_blocks``):
      * Q/K/V, added Q/K/V and feed-forward ``linear_in`` are ColwiseParallel;
        ``to_out.0``, ``to_add_out`` and ``linear_out`` are RowwiseParallel.
      * SwiGLU ``linear_in`` rows are permuted first so each rank's shard holds
        matching gate/linear rows.
      * The attention processor is chosen by the flags:

            fuse_qkv  flash_attn  processor
            True      True        Flux2AttnProcessorFusedQKVFlashAttn
            False     True        Flux2AttnProcessorFlashAttn
            True      False       Flux2AttnProcessorFusedQKV
            False     False       Flux2AttnProcessorTP

        Fused-QKV processors get their fused weights pre-built here.

    Single-stream blocks (``model.single_transformer_blocks``):
      * ``to_qkv_mlp_proj`` rows and ``to_out`` columns are permuted, then
        sharded ColwiseParallel / RowwiseParallel respectively.
      * The flash-attention processor is used iff *flash_attn*.

    Raises:
        ValueError: if any block's attention head count is not divisible by the
            mesh size. Checked before any weight is modified.
    """
    tp_size = tp_mesh.size()

    # Every processor computes `attn.heads // tp_size`; fail fast with a clear
    # message instead of a shape error deep inside the first forward call.
    for block in (*model.transformer_blocks, *model.single_transformer_blocks):
        if block.attn.heads % tp_size != 0:
            raise ValueError(
                f"attention heads ({block.attn.heads}) must be divisible by "
                f"the TP size ({tp_size})")

    double_plan = {
        "attn.to_q": ColwiseParallel(),
        "attn.to_k": ColwiseParallel(),
        "attn.to_v": ColwiseParallel(),
        "attn.to_out.0": RowwiseParallel(),
        "attn.add_q_proj": ColwiseParallel(),
        "attn.add_k_proj": ColwiseParallel(),
        "attn.add_v_proj": ColwiseParallel(),
        "attn.to_add_out": RowwiseParallel(),
        "ff.linear_in": ColwiseParallel(),
        "ff.linear_out": RowwiseParallel(),
        "ff_context.linear_in": ColwiseParallel(),
        "ff_context.linear_out": RowwiseParallel(),
    }
    for block in model.transformer_blocks:
        block.ff.linear_in.weight.data = _permute_swiglu_for_tp(
            block.ff.linear_in.weight.data, tp_size)
        block.ff_context.linear_in.weight.data = _permute_swiglu_for_tp(
            block.ff_context.linear_in.weight.data, tp_size)
        parallelize_module(block, tp_mesh, double_plan)
        if fuse_qkv and flash_attn:
            block.attn.set_processor(Flux2AttnProcessorFusedQKVFlashAttn(tp_size))
            _attach_fused_qkv_weights(block.attn)
        elif flash_attn:
            block.attn.set_processor(Flux2AttnProcessorFlashAttn(tp_size))
        elif fuse_qkv:
            block.attn.set_processor(Flux2AttnProcessorFusedQKV(tp_size))
            _attach_fused_qkv_weights(block.attn)
        else:
            block.attn.set_processor(Flux2AttnProcessorTP(tp_size))

    single_plan = {
        "attn.to_qkv_mlp_proj": ColwiseParallel(),
        "attn.to_out": RowwiseParallel(),
    }
    for block in model.single_transformer_blocks:
        attn = block.attn
        inner_dim = attn.inner_dim
        mlp_hidden = attn.mlp_hidden_dim
        attn.to_qkv_mlp_proj.weight.data = _permute_qkv_mlp_for_tp(
            attn.to_qkv_mlp_proj.weight.data, tp_size, inner_dim, mlp_hidden)
        attn.to_out.weight.data = _permute_out_for_tp(
            attn.to_out.weight.data, tp_size, inner_dim, mlp_hidden)
        parallelize_module(block, tp_mesh, single_plan)
        if flash_attn:
            block.attn.set_processor(Flux2ParallelSelfAttnProcessorFlashAttn(tp_size))
        else:
            block.attn.set_processor(Flux2ParallelSelfAttnProcessorTP(tp_size))
    return model
|
|
|
|
def apply_tp_text_encoder(model: Qwen3ForCausalLM, tp_mesh: DeviceMesh) -> Qwen3ForCausalLM:
    """
    Shard every Qwen3 decoder layer across *tp_mesh* with a Llama-style plan.

    Q/K/V and gate/up projections are ColwiseParallel; o_proj and down_proj are
    RowwiseParallel. Unlike the Flux2 transformer, no weight permutation is
    needed. Only ``model.model.layers`` are parallelized (embeddings and lm_head
    are left as-is). The model is modified in place and returned.
    """
    layout = (
        ("self_attn.q_proj", ColwiseParallel),
        ("self_attn.k_proj", ColwiseParallel),
        ("self_attn.v_proj", ColwiseParallel),
        ("self_attn.o_proj", RowwiseParallel),
        ("mlp.gate_proj", ColwiseParallel),
        ("mlp.up_proj", ColwiseParallel),
        ("mlp.down_proj", RowwiseParallel),
    )
    # One plan (and one set of style instances) shared by every layer.
    layer_plan = {name: style() for name, style in layout}
    for decoder_layer in model.model.layers:
        parallelize_module(decoder_layer, tp_mesh, layer_plan)
    return model
|
|
|
|
def _encode_prompt_tp(
    text_encoder: Qwen3ForCausalLM,
    tokenizer: Qwen2TokenizerFast,
    prompt: str,
    batch_size: int,
    device,
    max_sequence_length: int = 512,
    hidden_states_layers: tuple = (9, 18, 27),
) -> torch.Tensor:
    """
    Encode *prompt* with the TP-sharded Qwen3 text encoder on Neuron.

    Every rank tokenizes locally instead of receiving ids from rank 0.
    Tokenization is deterministic, and broadcasting integer-dtype tensors with
    ``dist.broadcast`` fails on Neuron. Because the whole batch shares one
    prompt, it is tokenized once and the ids/mask are repeated ``batch_size``
    times.

    Args:
        text_encoder: (optionally torch.compile-wrapped) Qwen3 model.
        tokenizer: tokenizer providing ``apply_chat_template``.
        prompt: user prompt text.
        batch_size: number of copies of the prompt to encode (>= 1).
        device: device the token tensors are moved to.
        max_sequence_length: pad/truncate length of the token sequence.
        hidden_states_layers: indices into ``output.hidden_states`` whose
            activations are concatenated along the feature axis.

    Returns:
        bfloat16 tensor of shape
        ``[batch_size, max_sequence_length, len(hidden_states_layers) * hidden]``,
        e.g. ``[B, 512, 7680]`` with the defaults and a hidden size of 2560.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    chat_text = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False, add_generation_prompt=True, enable_thinking=False,
    )
    enc = tokenizer(chat_text, return_tensors="pt", padding="max_length",
                    truncation=True, max_length=max_sequence_length)
    # Identical prompts -> identical tokens: repeat rather than re-tokenize.
    input_ids = enc["input_ids"].repeat(batch_size, 1).to(device)
    attention_mask = enc["attention_mask"].repeat(batch_size, 1).to(device)

    with torch.no_grad():
        output = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

    hidden = []
    for layer_idx in hidden_states_layers:
        hs = _to_local(output.hidden_states[layer_idx])
        hidden.append(hs.to(torch.bfloat16))

    # [B, n_layers, S, D] -> [B, S, n_layers, D] -> [B, S, n_layers * D]
    stacked = torch.stack(hidden, dim=1)
    batch, n_layers, seq, hdim = stacked.shape
    return stacked.permute(0, 2, 1, 3).reshape(batch, seq, n_layers * hdim)
|
|
|
|
| |
| |
| |
|
|
def _snapshot(model_id: str) -> str:
    """Return the local directory of the Hugging Face Hub snapshot for *model_id*."""
    from huggingface_hub import snapshot_download as _download_snapshot
    local_dir = _download_snapshot(model_id)
    return local_dir
|
|
|
|
def _load_safetensors(model: torch.nn.Module, subfolder_dir: str) -> None:
    """
    Copy weights from every ``*.safetensors`` shard in *subfolder_dir* into *model*.

    Tensors are read one at a time and copied in place into the existing
    parameters/buffers (matched by ``named_parameters``/``named_buffers`` name).
    This avoids materializing a second full state dict in CPU RAM.

    Checkpoint keys that match nothing in the model are skipped without being
    read and are reported. Parameters that received no value are also reported.
    Both usually indicate a key-naming mismatch.

    Raises:
        FileNotFoundError: if *subfolder_dir* contains no ``.safetensors`` file.
    """
    shards = sorted(f for f in os.listdir(subfolder_dir) if f.endswith(".safetensors"))
    if not shards:
        raise FileNotFoundError(f"No .safetensors in {subfolder_dir}")
    # Imported after the cheap directory check so a bad path fails fast.
    from safetensors import safe_open

    params = dict(model.named_parameters())
    buffers = dict(model.named_buffers())
    loaded: set = set()
    unexpected: list = []
    for fname in shards:
        with safe_open(os.path.join(subfolder_dir, fname), framework="pt", device="cpu") as f:
            for key in f.keys():
                target = params.get(key)
                if target is None:
                    target = buffers.get(key)
                if target is None:
                    # Previously these were read from disk and then silently ignored.
                    unexpected.append(key)
                    continue
                target.data.copy_(f.get_tensor(key))
                loaded.add(key)
    missing = [k for k in params if k not in loaded]
    if missing:
        logger.warning("  missing keys (first 5): %s", missing[:5])
    if unexpected:
        logger.warning("  unexpected keys (first 5): %s", unexpected[:5])
|
|
|
|
def load_text_encoder(model_id: str, random_weights: bool):
    """
    Load the Qwen3 text encoder (bfloat16, eval mode) and its tokenizer.

    Both are read from the ``text_encoder`` and ``tokenizer`` subfolders of the
    *model_id* snapshot. Returns ``(None, None)`` when *random_weights* is set,
    because the random-weights path needs no encoder.
    """
    if random_weights:
        return None, None
    snapshot_dir = _snapshot(model_id)
    logger.info("Loading Qwen3 text encoder ...")
    encoder_dir = os.path.join(snapshot_dir, "text_encoder")
    encoder = Qwen3ForCausalLM.from_pretrained(encoder_dir, dtype=torch.bfloat16)
    encoder.eval()
    tokenizer_dir = os.path.join(snapshot_dir, "tokenizer")
    return encoder, Qwen2TokenizerFast.from_pretrained(tokenizer_dir)
|
|
|
|
def load_transformer(model_id: str, random_weights: bool) -> Flux2Transformer2DModel:
    """
    Build the Flux2 transformer from *model_id*'s ``transformer`` config with
    bfloat16 parameters. Unless *random_weights* is set, stream in the
    checkpoint weights.

    The process-wide default dtype is switched to bfloat16 only while the model
    is constructed. It is restored even if construction raises.
    """
    config = Flux2Transformer2DModel.load_config(model_id, subfolder="transformer")
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)
    try:
        model = Flux2Transformer2DModel.from_config(config)
    finally:
        # Restore unconditionally: a failure in from_config must not leave every
        # later tensor in the process defaulting to bfloat16.
        torch.set_default_dtype(prev_dtype)
    if not random_weights:
        logger.info("Loading transformer weights ...")
        snap = _snapshot(model_id)
        _load_safetensors(model, os.path.join(snap, "transformer"))
    return model
|
|
|
|
def load_vae(model_id: str, random_weights: bool) -> AutoencoderKLFlux2:
    """
    Build the Flux2 VAE from *model_id*'s ``vae`` config with bfloat16
    parameters. Unless *random_weights* is set, load its checkpoint weights.
    Return it in eval mode.

    The process-wide default dtype is switched to bfloat16 only while the model
    is constructed. It is restored even if construction raises.
    """
    config = AutoencoderKLFlux2.load_config(model_id, subfolder="vae")
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)
    try:
        vae = AutoencoderKLFlux2.from_config(config)
    finally:
        # Restore unconditionally: a failure in from_config must not leave every
        # later tensor in the process defaulting to bfloat16.
        torch.set_default_dtype(prev_dtype)
    if not random_weights:
        logger.info("Loading VAE weights ...")
        snap = _snapshot(model_id)
        _load_safetensors(vae, os.path.join(snap, "vae"))
    vae.eval()
    return vae
|
|
|
|
| |
| |
| |
|
|
| def run(mode, model_id, prompt, height, width, num_steps, |
| batch_size, random_weights, output_path, seed, fuse_qkv=False): |
|
|
| assert mode in ("eager", "compile"), f"--mode must be 'eager' or 'compile', got {mode!r}" |
|
|
| dist.init_process_group(backend="neuron") |
| world_size = dist.get_world_size() |
| rank = dist.get_rank() |
| device = torch.neuron.current_device() |
|
|
| logger.info(f"Rank {rank}/{world_size} on device {device} " |
| f"mode={mode} ({height}x{width}, {num_steps} steps, random={random_weights})") |
|
|
| tp_mesh = DeviceMesh("neuron", list(range(world_size))) |
|
|
| joint_attention_dim = Flux2Transformer2DModel.load_config( |
| model_id, subfolder="transformer")["joint_attention_dim"] |
| text_seq_len = 512 |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| if not random_weights: |
| t0 = time.time() |
| text_encoder, tokenizer = load_text_encoder(model_id, random_weights) |
| logger.info(f"Rank {rank}: text encoder loaded in {time.time()-t0:.1f}s " |
| f"({sum(p.numel() for p in text_encoder.parameters())/1e9:.2f}B params)") |
|
|
| text_encoder = apply_tp_text_encoder(text_encoder, tp_mesh) |
| text_encoder = text_encoder.to(device) |
| text_encoder.eval() |
| if mode == "compile": |
| set_model_name(f"qwen3_text_encoder_rank{rank}") |
| |
| |
| |
| |
| |
| |
| |
| |
| from transformers.utils.output_capturing import install_all_output_capturing_hooks |
| install_all_output_capturing_hooks(text_encoder) |
| text_encoder = torch.compile(text_encoder, backend="neuron", fullgraph=True) |
| logger.info(f"Rank {rank}: text encoder compiled (NEFF will build on first call)") |
| gc.collect() |
|
|
| t0 = time.time() |
| prompt_embeds_dev = _encode_prompt_tp( |
| text_encoder, tokenizer, prompt, batch_size, device) |
| if rank == 0: |
| logger.info(f"Prompt encoded in {time.time()-t0:.1f}s " |
| f"shape={prompt_embeds_dev.shape}") |
|
|
| del text_encoder, tokenizer |
| gc.collect() |
| else: |
| |
| prompt_embeds_dev = torch.zeros( |
| batch_size, text_seq_len, joint_attention_dim, |
| dtype=torch.bfloat16, device=device) |
| if rank == 0: |
| prompt_embeds_dev.copy_( |
| torch.randn(batch_size, text_seq_len, joint_attention_dim, |
| dtype=torch.bfloat16).to(device)) |
| dist.broadcast(prompt_embeds_dev, src=0) |
|
|
| |
| |
| |
| |
| t0 = time.time() |
| transformer = load_transformer(model_id, random_weights) |
| logger.info(f"Rank {rank}: transformer loaded in {time.time()-t0:.1f}s " |
| f"({sum(p.numel() for p in transformer.parameters())/1e9:.2f}B params)") |
|
|
| transformer = apply_tp_flux2_transformer(transformer, tp_mesh, fuse_qkv=fuse_qkv) |
| transformer = transformer.to(device) |
| transformer.eval() |
| if mode == "compile": |
| set_model_name(f"flux2_transformer_rank{rank}") |
| transformer = torch.compile(transformer, backend="neuron", fullgraph=True) |
| logger.info(f"Rank {rank}: transformer compiled (NEFF will build on first call)") |
| gc.collect() |
|
|
| |
| |
| |
| |
| |
| |
| |
| pipe = None |
| if rank == 0: |
| snap = _snapshot(model_id) |
| vae = load_vae(model_id, random_weights) |
| vae = vae.to(device) |
| scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( |
| os.path.join(snap, "scheduler")) |
| pipe = Flux2KleinPipeline( |
| scheduler=scheduler, |
| vae=vae, |
| text_encoder=None, |
| tokenizer=None, |
| transformer=transformer, |
| is_distilled=True, |
| ) |
| if mode == "compile": |
| set_model_name("flux2_vae_rank0") |
| vae = torch.compile(vae, backend="neuron", fullgraph=True) |
| logger.info("Rank 0: VAE compiled (NEFF will build on first call)") |
|
|
| |
| |
| |
| generator = torch.Generator().manual_seed(seed) |
|
|
| num_latent_ch = transformer.config.in_channels // 4 |
| _vae_tmp = vae if rank == 0 else None |
|
|
| if _vae_tmp is None: |
| vae_scale = 8 |
| lh = 2 * (height // (vae_scale * 2)) |
| lw = 2 * (width // (vae_scale * 2)) |
| seq_len = (lh // 2) * (lw // 2) |
| latents_cpu = torch.randn( |
| batch_size, seq_len, transformer.config.in_channels, |
| dtype=torch.bfloat16, generator=generator) |
| ids = torch.cartesian_prod( |
| torch.arange(1), torch.arange(lh // 2), |
| torch.arange(lw // 2), torch.arange(1)) |
| latent_ids_cpu = ids.unsqueeze(0).expand(batch_size, -1, -1).contiguous().float() |
| else: |
| latents_cpu, latent_ids_cpu_int = pipe.prepare_latents( |
| batch_size=batch_size, |
| num_latents_channels=num_latent_ch, |
| height=height, |
| width=width, |
| dtype=torch.bfloat16, |
| device="cpu", |
| generator=generator, |
| ) |
| latent_ids_cpu = latent_ids_cpu_int.float() |
|
|
| t_ids = torch.cartesian_prod( |
| torch.arange(1), torch.arange(1), torch.arange(1), torch.arange(text_seq_len)) |
| text_ids_cpu = t_ids.unsqueeze(0).expand(batch_size, -1, -1).contiguous().float() |
|
|
| latents_dev = latents_cpu.to(device) |
| latent_ids_dev = latent_ids_cpu.to(device) |
| text_ids_dev = text_ids_cpu.to(device) |
|
|
| |
| |
| |
| if rank == 0: |
| image_seq_len = latents_dev.shape[1] |
| mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_steps) |
| scheduler.set_timesteps(num_steps, mu=mu) |
| timesteps = scheduler.timesteps |
| else: |
| scheduler = FlowMatchEulerDiscreteScheduler() |
|
|
| if rank == 0: |
| ts_tensor = timesteps.float() |
| else: |
| ts_tensor = torch.zeros(num_steps, dtype=torch.float32) |
| ts_dev = ts_tensor.to(device) |
| dist.broadcast(ts_dev, src=0) |
| ts_tensor = ts_dev |
|
|
| |
| |
| |
| |
| |
| dist.barrier() |
| compile_note = " (step 1 includes torch.compile NEFF build)" if mode == "compile" else "" |
| logger.info(f"Rank {rank}: starting {num_steps}-step denoising loop{compile_note} ...") |
| t_total = time.time() |
|
|
| with torch.no_grad(): |
| for step_idx in range(num_steps): |
| t_step = time.time() |
| t_val = ts_tensor[step_idx] |
|
|
| timestep = t_val.expand(batch_size).to(torch.bfloat16).to(device) / 1000.0 |
|
|
| noise_pred = transformer( |
| hidden_states=latents_dev, |
| encoder_hidden_states=prompt_embeds_dev, |
| timestep=timestep, |
| img_ids=latent_ids_dev, |
| txt_ids=text_ids_dev, |
| guidance=None, |
| return_dict=False, |
| )[0] |
|
|
| if rank == 0: |
| np_cpu = noise_pred.to("cpu") |
| lat_cpu = latents_dev.to("cpu") |
| t_cpu = t_val.cpu() |
| lat_new = scheduler.step(np_cpu, t_cpu, lat_cpu, return_dict=False)[0] |
| latents_dev.copy_(lat_new.to(device)) |
|
|
| dist.broadcast(latents_dev, src=0) |
|
|
| if rank == 0: |
| elapsed = time.time() - t_step |
| np_f = noise_pred.float() |
| logger.info(f" step {step_idx+1}/{num_steps} " |
| f"t={t_val.item():.1f} " |
| f"mean={np_f.mean().item():.4f} " |
| f"std={np_f.std().item():.4f} " |
| f"elapsed={elapsed:.3f}s") |
|
|
| if rank == 0: |
| logger.info(f"Denoising complete: {time.time()-t_total:.2f}s total") |
|
|
| |
| |
| |
| |
| |
| if rank == 0: |
| vae_note = " (step 1 includes VAE NEFF build)" if mode == "compile" else "" |
| logger.info(f"Decoding latents with VAE{vae_note} ...") |
|
|
| latents_spatial = Flux2KleinPipeline._unpack_latents_with_ids( |
| latents_dev, latent_ids_dev) |
|
|
| vae_mod = vae._orig_mod if mode == "compile" else vae |
| bn_mean = vae_mod.bn.running_mean.view(1, -1, 1, 1).to(latents_spatial.dtype) |
| bn_std = torch.sqrt( |
| vae_mod.bn.running_var.view(1, -1, 1, 1) + vae_mod.config.batch_norm_eps |
| ).to(latents_spatial.dtype) |
| latents_spatial = latents_spatial * bn_std + bn_mean |
| latents_spatial = Flux2KleinPipeline._unpatchify_latents(latents_spatial) |
|
|
| with torch.no_grad(): |
| image_tensor = vae.decode(latents_spatial, return_dict=False)[0] |
|
|
| image_np = image_tensor.clamp(-1.0, 1.0).add(1.0).div(2.0) |
| image_np = (image_np * 255).byte().permute(0, 2, 3, 1).cpu().numpy() |
| try: |
| from PIL import Image |
| for i, arr in enumerate(image_np): |
| fname = output_path if batch_size == 1 \ |
| else output_path.replace(".", f"_{i}.", 1) |
| Image.fromarray(arr).save(fname) |
| logger.info(f"Saved: {fname} (shape {arr.shape})") |
| except ImportError: |
| logger.warning("Pillow not installed — skipping image save.") |
|
|
| logger.info(f"output range: [{image_tensor.min():.3f}, {image_tensor.max():.3f}]") |
| logger.info("PASS: end-to-end pipeline completed") |
|
|
| dist.barrier() |
| dist.destroy_process_group() |
|
|
|
|
| |
| |
| |
|
|
def _positive_int(value):
    """Convert a command-line string to an ``int`` that is at least 1.

    Intended as an argparse ``type=`` callable for size/count options, so that
    values such as ``--num-steps 0`` or ``--batch-size -1`` are rejected while
    arguments are parsed (with a readable usage error) instead of failing later
    inside tensor allocation or the scheduler.

    Args:
        value: the raw option string supplied by argparse.

    Returns:
        The parsed integer.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not an integer literal or
            is smaller than 1 (argparse reports this as a usage error).
    """
    try:
        parsed = int(value)
    except (TypeError, ValueError):
        raise argparse.ArgumentTypeError(
            f"expected an integer, got {value!r}") from None
    if parsed < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {parsed}")
    return parsed


def parse_args():
    """Parse the command line for the Flux2-klein Neuron pipeline.

    Returns:
        argparse.Namespace with one attribute per option defined below; the
        script's ``__main__`` block forwards them to ``run()``.

    Notes:
        * ``--random-weights`` already defaults to True, so passing it is a
          no-op kept for symmetry; pass ``--no-random-weights`` to turn it off.
        * ``--height``, ``--width``, ``--num-steps`` and ``--batch-size`` must
          be positive integers; anything else is rejected with a usage error.
    """
    p = argparse.ArgumentParser(
        description="Flux2-klein pipeline (Qwen3 + TP transformer + VAE) on Neuron"
    )
    p.add_argument("--mode", choices=["eager", "compile"], default="eager",
                   help="eager: lazy-XLA path (default). compile: torch.compile Dynamo path.")
    p.add_argument("--model-id", default=DEFAULT_MODEL_ID)
    p.add_argument("--prompt", default="a photograph of a cat sitting on a Neuron chip, photorealistic")
    # Size/count options are validated up front: zero or negative values would
    # otherwise only fail (or silently do nothing) deep inside run().
    p.add_argument("--height", type=_positive_int, default=512,
                   help="Output image height in pixels (positive integer).")
    p.add_argument("--width", type=_positive_int, default=512,
                   help="Output image width in pixels (positive integer).")
    p.add_argument("--num-steps", type=_positive_int, default=4,
                   help="Number of denoising steps (positive integer).")
    p.add_argument("--batch-size", type=_positive_int, default=1,
                   help="Number of images generated in one batch (positive integer).")
    # store_true with default=True makes --random-weights a no-op; it exists so
    # the pair --random-weights / --no-random-weights reads symmetrically.
    p.add_argument("--random-weights", action="store_true", default=True)
    p.add_argument("--no-random-weights", action="store_false", dest="random_weights")
    p.add_argument("--output", default="flux2_output.png")
    p.add_argument("--seed", type=int, default=42,
                   help="Seed for the torch.Generator used to draw the initial latents.")
    p.add_argument("--fused-qkv", action="store_true", default=False,
                   help="Use NKI fused QKV kernel for double-stream blocks (eager mode only). "
                        "Replaces 3 separate ColwisePar matmuls per stream with one nki_qkv call.")
    p.add_argument(
        "--cache-dir",
        default=None,
        help=(
            "Persistent NEFF cache directory (sets TORCH_NEURONX_NEFF_CACHE_DIR). "
            "Applies to both eager and compile modes. "
            "NEFFs are saved on the first run and reloaded on subsequent runs, "
            "skipping neuronx-cc recompilation. "
            "Defaults to /tmp/neff_cache (lost on reboot). "
            "Example: --cache-dir /home/ubuntu/neff_cache"
        ),
    )
    return p.parse_args()
|
|
|
|
if __name__ == "__main__":
    args = parse_args()

    # Choose the NEFF cache location: an explicit --cache-dir wins, then an
    # already-exported TORCH_NEURONX_NEFF_CACHE_DIR, and finally
    # /tmp/neff_cache. The chosen value is written back to the environment so
    # later reads of the variable in this process (and its children) see it.
    cache_dir = args.cache_dir
    if not cache_dir:
        cache_dir = os.environ.get("TORCH_NEURONX_NEFF_CACHE_DIR", "/tmp/neff_cache")
    os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir
    os.makedirs(cache_dir, exist_ok=True)
    logger.info(f"NEFF cache dir: {cache_dir}")

    # Translate CLI options into run()'s keyword arguments. Two names differ
    # from their flags: --output -> output_path and --fused-qkv -> fuse_qkv.
    run_kwargs = dict(
        mode=args.mode,
        model_id=args.model_id,
        prompt=args.prompt,
        height=args.height,
        width=args.width,
        num_steps=args.num_steps,
        batch_size=args.batch_size,
        random_weights=args.random_weights,
        output_path=args.output,
        seed=args.seed,
        fuse_qkv=args.fused_qkv,
    )
    run(**run_kwargs)
|
|