custom-nki-kernels / src /pipeline.py
Jingya's picture
Jingya HF Staff
Upload src/pipeline.py with huggingface_hub
d38fb47 verified
#!/usr/bin/env python3
"""
Full end-to-end Flux2 klein pipeline on AWS Neuron.
Supports two execution modes selected with ``--mode``:
eager (default) — lazy-XLA path. Models are moved to the Neuron device
with ``.to(device)``; the XLA compiler traces and emits
a NEFF on the first forward call. No explicit compile
step needed; graph breaks are handled transparently.
compile — torch.compile path (Dynamo backend="neuron"). TorchDynamo
captures the full FX graph and compiles a NEFF on the
first forward call. ``fullgraph=True`` disallows graph
breaks. All three models (text encoder, transformer, VAE)
are compiled.
NEFF caching (compile mode only):
Compiled NEFFs are stored in two caches:
1. In-process C++ cache (NeuronResourceManager) — keyed by StableHLO hash.
Within a single torchrun the NEFF is compiled once (cold step 1) and
all subsequent steps reuse it (warm).
2. Persistent file cache (TORCH_NEURONX_NEFF_CACHE_DIR, default /tmp/neff_cache).
Pass ``--cache-dir /persistent/path`` to keep NEFFs across reboots /
torchrun invocations so the cold step is skipped entirely on re-runs.
Components and placement:
Text encoder (Qwen3ForCausalLM) — Neuron TP [+ torch.compile]
Scheduler (FlowMatchEulerDiscreteScheduler) — CPU
Transformer (Flux2Transformer2DModel) — Neuron TP [+ torch.compile]
VAE decoder (AutoencoderKLFlux2) — Neuron rank-0 [+ torch.compile]
Usage:
# eager mode (default)
torchrun --nproc_per_node=4 flux2-klein/pipeline.py --random-weights --num-steps 4
# compile mode — first run compiles and caches NEFFs
torchrun --nproc_per_node=4 flux2-klein/pipeline.py --mode compile \\
--cache-dir /home/ubuntu/neff_cache --random-weights --num-steps 4
# compile mode — second run loads cached NEFFs, no recompilation
torchrun --nproc_per_node=4 flux2-klein/pipeline.py --mode compile \\
--cache-dir /home/ubuntu/neff_cache --no-random-weights \\
--model-id black-forest-labs/FLUX.2-klein-9B \\
--prompt "a cat sitting on a Neuron chip" --height 512 --width 512
"""
import argparse
import gc
import logging
import os
import time
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)
from transformers import Qwen3ForCausalLM, Qwen2TokenizerFast
from diffusers import AutoencoderKLFlux2, Flux2Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.models.embeddings import apply_rotary_emb
from diffusers.models.attention_dispatch import dispatch_attention_fn
from diffusers.models.transformers.transformer_flux2 import (
Flux2AttnProcessor,
Flux2ParallelSelfAttnProcessor,
_get_qkv_projections,
)
from diffusers.pipelines.flux2.pipeline_flux2_klein import (
Flux2KleinPipeline,
compute_empirical_mu,
)
import torch_neuronx # noqa: F401 — registers neuron backend
from torch_neuronx.neuron_dynamo_backend import set_model_name
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
logger = logging.getLogger(__name__)
DEFAULT_MODEL_ID = "black-forest-labs/FLUX.2-klein-4B"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _to_local(t: torch.Tensor) -> torch.Tensor:
"""Extract local shard from a DTensor, or return tensor unchanged."""
return t.to_local() if hasattr(t, "to_local") else t
def _build_fused_qkv_weight(q_w, k_w, v_w) -> torch.Tensor:
    """
    Fuse the Q, K and V projection weights into a single contiguous matrix.

    Each weight is first reduced to its local tensor via ``_to_local`` and
    transposed; the three transposed matrices are then concatenated along
    dim 1 in Q, K, V order.  For local shards of shape [H//tp, H] the result
    is [H, 3*(H//tp)].
    """
    transposed = [_to_local(w).T for w in (q_w, k_w, v_w)]
    fused = torch.cat(transposed, dim=1)
    return fused.contiguous()
# ---------------------------------------------------------------------------
# TP-aware attention processors
# ---------------------------------------------------------------------------
class Flux2AttnProcessorFlashAttn(Flux2AttnProcessor):
    """
    Double-stream attention processor that calls the NKI flash attention kernel.

    Projections are reshaped with ``attn.heads // tp_size`` heads, and the
    attention itself is computed by ``flux2_flash_attn`` (imported lazily from
    the ``nki_flash_attn`` module next to this file) instead of
    ``dispatch_attention_fn``.

    ``attention_mask`` is accepted for interface compatibility but is not
    passed to the kernel.
    """

    def __init__(self, tp_size: int):
        super().__init__()
        self.tp_size = tp_size
        self._flash_attn = None

    def _init_nki(self):
        """Import ``flux2_flash_attn`` on first use and cache it on the instance."""
        if self._flash_attn is not None:
            return
        import sys
        here = os.path.dirname(os.path.abspath(__file__))
        # Make the sibling kernel module importable regardless of the CWD.
        if here not in sys.path:
            sys.path.insert(0, here)
        from nki_flash_attn import flux2_flash_attn
        self._flash_attn = flux2_flash_attn

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, image_rotary_emb=None, **kwargs):
        self._init_nki()
        head_shape = (attn.heads // self.tp_size, attn.head_dim)

        q, k, v, txt_q, txt_k, txt_v = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states)

        # Image stream: split the flat projection into (heads, head_dim).
        q = q.unflatten(-1, head_shape)
        k = k.unflatten(-1, head_shape)
        v = v.unflatten(-1, head_shape)
        q = attn.norm_q(q)
        k = attn.norm_k(k)

        if attn.added_kv_proj_dim is not None:
            # Text stream: same reshaping, then prepend it along the sequence dim.
            txt_q = txt_q.unflatten(-1, head_shape)
            txt_k = txt_k.unflatten(-1, head_shape)
            txt_v = txt_v.unflatten(-1, head_shape)
            txt_q = attn.norm_added_q(txt_q)
            txt_k = attn.norm_added_k(txt_k)
            q = torch.cat([txt_q, q], dim=1)
            k = torch.cat([txt_k, k], dim=1)
            v = torch.cat([txt_v, v], dim=1)

        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
            k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)

        attn_out = self._flash_attn(q, k, v)  # (B, S, H_local, D) BSHD
        attn_out = attn_out.flatten(2, 3).to(q.dtype)

        if encoder_hidden_states is None:
            attn_out = attn.to_out[0](attn_out)
            return attn.to_out[1](attn_out)

        # Undo the concatenation: text tokens come first, image tokens after.
        txt_len = encoder_hidden_states.shape[1]
        txt_out, img_out = attn_out.split_with_sizes(
            [txt_len, attn_out.shape[1] - txt_len], dim=1)
        txt_out = attn.to_add_out(txt_out)
        img_out = attn.to_out[0](img_out)
        img_out = attn.to_out[1](img_out)
        return img_out, txt_out
class Flux2AttnProcessorTP(Flux2AttnProcessor):
    """
    Double-stream attention processor for tensor-parallel execution.

    Identical in structure to the flash-attention variant in this file, except
    that projections are reshaped with ``attn.heads // tp_size`` heads and the
    attention is computed by ``dispatch_attention_fn``.
    """

    def __init__(self, tp_size: int):
        super().__init__()
        self.tp_size = tp_size

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, image_rotary_emb=None, **kwargs):
        head_shape = (attn.heads // self.tp_size, attn.head_dim)

        q, k, v, txt_q, txt_k, txt_v = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states)

        # Image stream: split the flat projection into (heads, head_dim).
        q = q.unflatten(-1, head_shape)
        k = k.unflatten(-1, head_shape)
        v = v.unflatten(-1, head_shape)
        q = attn.norm_q(q)
        k = attn.norm_k(k)

        if attn.added_kv_proj_dim is not None:
            # Text stream: same reshaping, then prepend it along the sequence dim.
            txt_q = txt_q.unflatten(-1, head_shape)
            txt_k = txt_k.unflatten(-1, head_shape)
            txt_v = txt_v.unflatten(-1, head_shape)
            txt_q = attn.norm_added_q(txt_q)
            txt_k = attn.norm_added_k(txt_k)
            q = torch.cat([txt_q, q], dim=1)
            k = torch.cat([txt_k, k], dim=1)
            v = torch.cat([txt_v, v], dim=1)

        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
            k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)

        attn_out = dispatch_attention_fn(
            q, k, v,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        attn_out = attn_out.flatten(2, 3).to(q.dtype)

        if encoder_hidden_states is None:
            attn_out = attn.to_out[0](attn_out)
            return attn.to_out[1](attn_out)

        # Undo the concatenation: text tokens come first, image tokens after.
        txt_len = encoder_hidden_states.shape[1]
        txt_out, img_out = attn_out.split_with_sizes(
            [txt_len, attn_out.shape[1] - txt_len], dim=1)
        txt_out = attn.to_add_out(txt_out)
        img_out = attn.to_out[0](img_out)
        img_out = attn.to_out[1](img_out)
        return img_out, txt_out
class Flux2AttnProcessorFusedQKV(Flux2AttnProcessor):
    """
    Double-stream attention processor that computes Q, K and V with one NKI
    fused-QKV kernel call per stream.

    The separate to_q/to_k/to_v (image) and add_q_proj/add_k_proj/add_v_proj
    (text) projections are replaced by two calls to ``fused_qkv_kernel``
    (imported lazily from the ``nki_qkv`` module next to this file).

    The fused weight for each stream is looked up on the attention module
    under ``_fused_img_qkv_w`` / ``_fused_txt_qkv_w``.  If the attribute is
    absent (or None) it is built from the three projection weights on the
    first call and stored there; an attribute that is already set is used
    as-is.  Attention itself still goes through ``dispatch_attention_fn``.
    """

    def __init__(self, tp_size: int):
        super().__init__()
        self.tp_size = tp_size
        self._nki_qkv = None

    def _init_nki(self):
        """Import ``fused_qkv_kernel`` on first use and cache it on the instance."""
        if self._nki_qkv is None:
            import sys
            here = os.path.dirname(os.path.abspath(__file__))
            # Make the sibling kernel module importable regardless of the CWD.
            if here not in sys.path:
                sys.path.insert(0, here)
            from nki_qkv import fused_qkv_kernel
            self._nki_qkv = fused_qkv_kernel

    def _fused_proj(self, attn, q_attr, k_attr, v_attr, cache_attr, x):
        """Fetch (building once if needed) the fused QKV weight and apply it to ``x``.

        Returns the kernel output split into chunks of
        ``attn.inner_dim // tp_size`` along the last dim (q, k, v).
        """
        weight = getattr(attn, cache_attr, None)
        if weight is None:
            sources = [getattr(attn, name).weight for name in (q_attr, k_attr, v_attr)]
            weight = _build_fused_qkv_weight(*sources).to(x.device)
            setattr(attn, cache_attr, weight)

        per_proj = attn.inner_dim // self.tp_size
        batch, seq, hidden = x.shape
        # The kernel takes the activations pre-transposed: [H, B*S].
        x_t = x.reshape(batch * seq, hidden).T.contiguous()
        fused = self._nki_qkv(x_t, weight)  # [B*S, 3*local_dim]
        return fused.reshape(batch, seq, -1).split(per_proj, dim=-1)

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, image_rotary_emb=None, **kwargs):
        self._init_nki()
        head_shape = (attn.heads // self.tp_size, attn.head_dim)

        # Image stream.
        q, k, v = self._fused_proj(
            attn, "to_q", "to_k", "to_v", "_fused_img_qkv_w", hidden_states)
        q = q.unflatten(-1, head_shape)
        k = k.unflatten(-1, head_shape)
        v = v.unflatten(-1, head_shape)
        q = attn.norm_q(q)
        k = attn.norm_k(k)

        if attn.added_kv_proj_dim is not None:
            # Text stream, prepended along the sequence dim.
            txt_q, txt_k, txt_v = self._fused_proj(
                attn, "add_q_proj", "add_k_proj", "add_v_proj",
                "_fused_txt_qkv_w", encoder_hidden_states)
            txt_q = txt_q.unflatten(-1, head_shape)
            txt_k = txt_k.unflatten(-1, head_shape)
            txt_v = txt_v.unflatten(-1, head_shape)
            txt_q = attn.norm_added_q(txt_q)
            txt_k = attn.norm_added_k(txt_k)
            q = torch.cat([txt_q, q], dim=1)
            k = torch.cat([txt_k, k], dim=1)
            v = torch.cat([txt_v, v], dim=1)

        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
            k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)

        attn_out = dispatch_attention_fn(
            q, k, v,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        attn_out = attn_out.flatten(2, 3).to(q.dtype)

        if encoder_hidden_states is None:
            attn_out = attn.to_out[0](attn_out)
            return attn.to_out[1](attn_out)

        # Undo the concatenation: text tokens come first, image tokens after.
        txt_len = encoder_hidden_states.shape[1]
        txt_out, img_out = attn_out.split_with_sizes(
            [txt_len, attn_out.shape[1] - txt_len], dim=1)
        txt_out = attn.to_add_out(txt_out)
        img_out = attn.to_out[0](img_out)
        img_out = attn.to_out[1](img_out)
        return img_out, txt_out
class Flux2AttnProcessorFusedQKVFlashAttn(Flux2AttnProcessor):
    """
    Double-stream attention processor using both NKI kernels: the fused QKV
    projection (``fused_qkv_kernel``) and flash attention (``flux2_flash_attn``).

    Fused weights are read from ``attn._fused_img_qkv_w`` /
    ``attn._fused_txt_qkv_w`` and built on first use when missing.
    ``attention_mask`` is accepted for interface compatibility but is not
    passed to the flash attention kernel.
    """

    def __init__(self, tp_size: int):
        super().__init__()
        self.tp_size = tp_size
        self._nki_qkv = None
        self._flash_attn = None

    def _init_nki(self):
        """Import whichever of the two NKI kernels has not been loaded yet."""
        if self._nki_qkv is not None and self._flash_attn is not None:
            return
        import sys
        here = os.path.dirname(os.path.abspath(__file__))
        # Make the sibling kernel modules importable regardless of the CWD.
        if here not in sys.path:
            sys.path.insert(0, here)
        if self._nki_qkv is None:
            from nki_qkv import fused_qkv_kernel
            self._nki_qkv = fused_qkv_kernel
        if self._flash_attn is None:
            from nki_flash_attn import flux2_flash_attn
            self._flash_attn = flux2_flash_attn

    def _fused_proj(self, attn, q_attr, k_attr, v_attr, cache_attr, x):
        """Fetch (building once if needed) the fused QKV weight and apply it to ``x``.

        Returns the kernel output split into chunks of
        ``attn.inner_dim // tp_size`` along the last dim (q, k, v).
        """
        weight = getattr(attn, cache_attr, None)
        if weight is None:
            sources = [getattr(attn, name).weight for name in (q_attr, k_attr, v_attr)]
            weight = _build_fused_qkv_weight(*sources).to(x.device)
            setattr(attn, cache_attr, weight)

        per_proj = attn.inner_dim // self.tp_size
        batch, seq, hidden = x.shape
        # The kernel is fed the flattened activations transposed to [H, B*S].
        x_t = x.reshape(batch * seq, hidden).T.contiguous()
        fused = self._nki_qkv(x_t, weight)
        return fused.reshape(batch, seq, -1).split(per_proj, dim=-1)

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, image_rotary_emb=None, **kwargs):
        self._init_nki()
        head_shape = (attn.heads // self.tp_size, attn.head_dim)

        # Image stream.
        q, k, v = self._fused_proj(
            attn, "to_q", "to_k", "to_v", "_fused_img_qkv_w", hidden_states)
        q = q.unflatten(-1, head_shape)
        k = k.unflatten(-1, head_shape)
        v = v.unflatten(-1, head_shape)
        q = attn.norm_q(q)
        k = attn.norm_k(k)

        if attn.added_kv_proj_dim is not None:
            # Text stream, prepended along the sequence dim.
            txt_q, txt_k, txt_v = self._fused_proj(
                attn, "add_q_proj", "add_k_proj", "add_v_proj",
                "_fused_txt_qkv_w", encoder_hidden_states)
            txt_q = txt_q.unflatten(-1, head_shape)
            txt_k = txt_k.unflatten(-1, head_shape)
            txt_v = txt_v.unflatten(-1, head_shape)
            txt_q = attn.norm_added_q(txt_q)
            txt_k = attn.norm_added_k(txt_k)
            q = torch.cat([txt_q, q], dim=1)
            k = torch.cat([txt_k, k], dim=1)
            v = torch.cat([txt_v, v], dim=1)

        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
            k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)

        attn_out = self._flash_attn(q, k, v)
        attn_out = attn_out.flatten(2, 3).to(q.dtype)

        if encoder_hidden_states is None:
            attn_out = attn.to_out[0](attn_out)
            return attn.to_out[1](attn_out)

        # Undo the concatenation: text tokens come first, image tokens after.
        txt_len = encoder_hidden_states.shape[1]
        txt_out, img_out = attn_out.split_with_sizes(
            [txt_len, attn_out.shape[1] - txt_len], dim=1)
        txt_out = attn.to_add_out(txt_out)
        img_out = attn.to_out[0](img_out)
        img_out = attn.to_out[1](img_out)
        return img_out, txt_out
class Flux2ParallelSelfAttnProcessorTP(Flux2ParallelSelfAttnProcessor):
    """
    Single-stream (fused QKV + MLP) attention processor for tensor-parallel
    execution.

    The output of ``attn.to_qkv_mlp_proj`` is split using sizes divided by
    ``tp_size``: ``3 * inner_dim // tp_size`` for QKV and
    ``mlp_hidden_dim * mlp_mult_factor // tp_size`` for the MLP branch.
    Attention is computed by ``dispatch_attention_fn``.
    """

    def __init__(self, tp_size: int):
        super().__init__()
        self.tp_size = tp_size

    def __call__(self, attn, hidden_states, attention_mask=None,
                 image_rotary_emb=None, **kwargs):
        tp = self.tp_size
        head_shape = (attn.heads // tp, attn.head_dim)
        qkv_width = 3 * (attn.inner_dim // tp)
        mlp_width = attn.mlp_hidden_dim * attn.mlp_mult_factor // tp

        projected = attn.to_qkv_mlp_proj(hidden_states)
        qkv, mlp_branch = torch.split(projected, [qkv_width, mlp_width], dim=-1)

        q, k, v = qkv.chunk(3, dim=-1)
        q = q.unflatten(-1, head_shape)
        k = k.unflatten(-1, head_shape)
        v = v.unflatten(-1, head_shape)
        q = attn.norm_q(q)
        k = attn.norm_k(k)

        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
            k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)

        attn_out = dispatch_attention_fn(
            q, k, v,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        attn_out = attn_out.flatten(2, 3).to(q.dtype)

        # Attention and activated MLP features share one output projection.
        mlp_branch = attn.mlp_act_fn(mlp_branch)
        merged = torch.cat([attn_out, mlp_branch], dim=-1)
        return attn.to_out(merged)
class Flux2ParallelSelfAttnProcessorFlashAttn(Flux2ParallelSelfAttnProcessor):
    """
    Single-stream (fused QKV + MLP) attention processor that calls the NKI
    flash attention kernel.

    Split sizes of the ``attn.to_qkv_mlp_proj`` output are divided by
    ``tp_size``.  ``attention_mask`` is accepted for interface compatibility
    but is not passed to the kernel.
    """

    def __init__(self, tp_size: int):
        super().__init__()
        self.tp_size = tp_size
        self._flash_attn = None

    def _init_nki(self):
        """Import ``flux2_flash_attn`` on first use and cache it on the instance."""
        if self._flash_attn is not None:
            return
        import sys
        here = os.path.dirname(os.path.abspath(__file__))
        # Make the sibling kernel module importable regardless of the CWD.
        if here not in sys.path:
            sys.path.insert(0, here)
        from nki_flash_attn import flux2_flash_attn
        self._flash_attn = flux2_flash_attn

    def __call__(self, attn, hidden_states, attention_mask=None,
                 image_rotary_emb=None, **kwargs):
        self._init_nki()
        tp = self.tp_size
        head_shape = (attn.heads // tp, attn.head_dim)
        qkv_width = 3 * (attn.inner_dim // tp)
        mlp_width = attn.mlp_hidden_dim * attn.mlp_mult_factor // tp

        projected = attn.to_qkv_mlp_proj(hidden_states)
        qkv, mlp_branch = torch.split(projected, [qkv_width, mlp_width], dim=-1)

        q, k, v = qkv.chunk(3, dim=-1)
        q = q.unflatten(-1, head_shape)
        k = k.unflatten(-1, head_shape)
        v = v.unflatten(-1, head_shape)
        q = attn.norm_q(q)
        k = attn.norm_k(k)

        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
            k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)

        attn_out = self._flash_attn(q, k, v)  # (B, S, H_local, D) BSHD
        attn_out = attn_out.flatten(2, 3).to(q.dtype)

        # Attention and activated MLP features share one output projection.
        mlp_branch = attn.mlp_act_fn(mlp_branch)
        merged = torch.cat([attn_out, mlp_branch], dim=-1)
        return attn.to_out(merged)
# ---------------------------------------------------------------------------
# Tensor parallelism
# ---------------------------------------------------------------------------
def _permute_swiglu_for_tp(weight: torch.Tensor, tp_size: int) -> torch.Tensor:
"""Permute SwiGLU linear_in rows so ColwisePar gives each rank [gate_i|linear_i]."""
with torch.no_grad():
total = weight.shape[0]
inner = total // 2
chunk = inner // tp_size
gate = weight[:inner]
linear = weight[inner:]
parts = []
for i in range(tp_size):
parts.append(gate[i * chunk : (i + 1) * chunk])
parts.append(linear[i * chunk : (i + 1) * chunk])
return torch.cat(parts, dim=0)
def _permute_qkv_mlp_for_tp(
weight: torch.Tensor,
tp_size: int,
inner_dim: int,
mlp_hidden_dim: int,
) -> torch.Tensor:
"""Permute fused QKV+SwiGLU rows so ColwisePar gives rank i [q_i|k_i|v_i|gate_i|lin_i]."""
with torch.no_grad():
q = weight[:inner_dim]
k = weight[inner_dim : 2 * inner_dim]
v = weight[2 * inner_dim : 3 * inner_dim]
mlp_gate = weight[3 * inner_dim : 3 * inner_dim + mlp_hidden_dim]
mlp_lin = weight[3 * inner_dim + mlp_hidden_dim :]
qkv_chunk = inner_dim // tp_size
mlp_chunk = mlp_hidden_dim // tp_size
parts = []
for i in range(tp_size):
parts += [
q[i * qkv_chunk : (i + 1) * qkv_chunk],
k[i * qkv_chunk : (i + 1) * qkv_chunk],
v[i * qkv_chunk : (i + 1) * qkv_chunk],
mlp_gate[i * mlp_chunk : (i + 1) * mlp_chunk],
mlp_lin [i * mlp_chunk : (i + 1) * mlp_chunk],
]
return torch.cat(parts, dim=0)
def _permute_out_for_tp(
weight: torch.Tensor,
tp_size: int,
attn_dim: int,
mlp_dim: int,
) -> torch.Tensor:
"""Permute to_out columns so RowwisePar rank i gets W[:, attn_i|mlp_i]."""
with torch.no_grad():
attn_part = weight[:, :attn_dim]
mlp_part = weight[:, attn_dim:]
attn_chunk = attn_dim // tp_size
mlp_chunk = mlp_dim // tp_size
parts = []
for i in range(tp_size):
parts.append(attn_part[:, i * attn_chunk : (i + 1) * attn_chunk])
parts.append(mlp_part[:, i * mlp_chunk : (i + 1) * mlp_chunk])
return torch.cat(parts, dim=1)
def _prebuild_fused_qkv_weights(attn) -> None:
    """Build and attach the fused image/text QKV weights to ``attn``."""
    # Pre-build fused weights now (real DTensors present) so they are
    # plain local tensors by the time torch.compile traces the forward.
    # Lazy build inside _fused_proj fails under Dynamo fake-tensor mode
    # because _to_local() on a fake DTensor returns the full DTensor.
    attn._fused_img_qkv_w = _build_fused_qkv_weight(
        attn.to_q.weight,
        attn.to_k.weight,
        attn.to_v.weight,
    )
    attn._fused_txt_qkv_w = _build_fused_qkv_weight(
        attn.add_q_proj.weight,
        attn.add_k_proj.weight,
        attn.add_v_proj.weight,
    )


def apply_tp_flux2_transformer(
    model: Flux2Transformer2DModel,
    tp_mesh: DeviceMesh,
    fuse_qkv: bool = False,
    flash_attn: bool = False,
) -> Flux2Transformer2DModel:
    """
    Shard a Flux2 transformer across ``tp_mesh`` and install TP-aware processors.

    For every double-stream block the SwiGLU ``linear_in`` weights are
    permuted, the block is parallelized with a Colwise/Rowwise plan, and an
    attention processor is chosen from ``fuse_qkv`` / ``flash_attn``.  When
    ``fuse_qkv`` is set the fused QKV weights are built immediately after
    parallelization.  For every single-stream block the fused
    ``to_qkv_mlp_proj`` and ``to_out`` weights are permuted before
    parallelization, and the processor is chosen from ``flash_attn``.

    Args:
        model: Transformer to modify in place.
        tp_mesh: Device mesh; its size is the tensor-parallel degree.
        fuse_qkv: Use the NKI fused-QKV processors on double-stream blocks.
        flash_attn: Use the NKI flash-attention processors.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: If a block's ``attn.heads`` is not divisible by the mesh
            size.  The processors compute ``attn.heads // tp_size``, so an
            uneven split would otherwise surface later as a shape error.
    """
    tp_size = tp_mesh.size()

    def _require_even_heads(attn, kind):
        if attn.heads % tp_size != 0:
            raise ValueError(
                f"{kind} block has {attn.heads} attention heads, which is not "
                f"divisible by tp_size={tp_size}")

    # Pick the double-stream processor class once; it is the same for all blocks.
    if fuse_qkv and flash_attn:
        double_processor_cls = Flux2AttnProcessorFusedQKVFlashAttn
    elif flash_attn:
        double_processor_cls = Flux2AttnProcessorFlashAttn
    elif fuse_qkv:
        double_processor_cls = Flux2AttnProcessorFusedQKV
    else:
        double_processor_cls = Flux2AttnProcessorTP

    double_plan = {
        "attn.to_q": ColwiseParallel(),
        "attn.to_k": ColwiseParallel(),
        "attn.to_v": ColwiseParallel(),
        "attn.to_out.0": RowwiseParallel(),
        "attn.add_q_proj": ColwiseParallel(),
        "attn.add_k_proj": ColwiseParallel(),
        "attn.add_v_proj": ColwiseParallel(),
        "attn.to_add_out": RowwiseParallel(),
        "ff.linear_in": ColwiseParallel(),
        "ff.linear_out": RowwiseParallel(),
        "ff_context.linear_in": ColwiseParallel(),
        "ff_context.linear_out": RowwiseParallel(),
    }
    for block in model.transformer_blocks:
        _require_even_heads(block.attn, "double-stream")
        # Permute before sharding so each rank's shard holds matching
        # gate/linear rows.
        block.ff.linear_in.weight.data = _permute_swiglu_for_tp(
            block.ff.linear_in.weight.data, tp_size)
        block.ff_context.linear_in.weight.data = _permute_swiglu_for_tp(
            block.ff_context.linear_in.weight.data, tp_size)
        parallelize_module(block, tp_mesh, double_plan)
        block.attn.set_processor(double_processor_cls(tp_size))
        if fuse_qkv:
            _prebuild_fused_qkv_weights(block.attn)

    single_processor_cls = (
        Flux2ParallelSelfAttnProcessorFlashAttn if flash_attn
        else Flux2ParallelSelfAttnProcessorTP
    )
    single_plan = {
        "attn.to_qkv_mlp_proj": ColwiseParallel(),
        "attn.to_out": RowwiseParallel(),
    }
    for block in model.single_transformer_blocks:
        attn = block.attn
        _require_even_heads(attn, "single-stream")
        inner_dim = attn.inner_dim
        mlp_hidden = attn.mlp_hidden_dim
        attn.to_qkv_mlp_proj.weight.data = _permute_qkv_mlp_for_tp(
            attn.to_qkv_mlp_proj.weight.data, tp_size, inner_dim, mlp_hidden)
        attn.to_out.weight.data = _permute_out_for_tp(
            attn.to_out.weight.data, tp_size, inner_dim, mlp_hidden)
        parallelize_module(block, tp_mesh, single_plan)
        block.attn.set_processor(single_processor_cls(tp_size))
    return model
def apply_tp_text_encoder(model: Qwen3ForCausalLM, tp_mesh: DeviceMesh) -> Qwen3ForCausalLM:
    """
    Parallelize every decoder layer of a Qwen3ForCausalLM over ``tp_mesh``.

    Llama-style plan: q/k/v and gate/up projections are column-parallel,
    o_proj and down_proj are row-parallel.  No weight permutation is applied.
    The model is modified in place and returned.
    """
    plan_entries = [
        ("self_attn.q_proj", ColwiseParallel()),
        ("self_attn.k_proj", ColwiseParallel()),
        ("self_attn.v_proj", ColwiseParallel()),
        ("self_attn.o_proj", RowwiseParallel()),
        ("mlp.gate_proj", ColwiseParallel()),
        ("mlp.up_proj", ColwiseParallel()),
        ("mlp.down_proj", RowwiseParallel()),
    ]
    # One plan object is shared by all layers.
    layer_plan = dict(plan_entries)
    for decoder_layer in model.model.layers:
        parallelize_module(decoder_layer, tp_mesh, layer_plan)
    return model
def _encode_prompt_tp(
    text_encoder: Qwen3ForCausalLM,
    tokenizer: Qwen2TokenizerFast,
    prompt: str,
    batch_size: int,
    device,
    max_length: int = 512,
    hidden_layers: tuple = (9, 18, 27),
) -> torch.Tensor:
    """
    Encode prompt with the TP Qwen3 text encoder on Neuron.

    All ranks tokenize independently (deterministic — avoids integer-dtype
    dist.broadcast which fails on Neuron).  The prompt is tokenized once and
    the resulting ids/mask are repeated ``batch_size`` times, since every
    batch entry uses the same prompt.

    Args:
        text_encoder: Encoder called with ``output_hidden_states=True``.
        tokenizer: Tokenizer providing ``apply_chat_template``.
        prompt: User prompt text.
        batch_size: Number of identical rows to encode (>= 1).
        device: Device the token tensors are moved to.
        max_length: Padded/truncated token length.
        hidden_layers: Indices into ``output.hidden_states`` to concatenate.

    Returns:
        bfloat16 tensor of shape
        [batch_size, seq, len(hidden_layers) * hidden_dim], where the selected
        layers are concatenated along the last dim in the given order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    text = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False, add_generation_prompt=True, enable_thinking=False,
    )
    enc = tokenizer(text, return_tensors="pt", padding="max_length",
                    truncation=True, max_length=max_length)
    # Same prompt for every row: tokenize once, then tile along the batch dim.
    input_ids = enc["input_ids"].repeat(batch_size, 1).to(device)
    attention_mask = enc["attention_mask"].repeat(batch_size, 1).to(device)

    with torch.no_grad():
        output = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )
        hidden = [
            _to_local(output.hidden_states[k]).to(torch.bfloat16)
            for k in hidden_layers
        ]
        out = torch.stack(hidden, dim=1)  # [B, n_layers, seq, hidden_dim]
        B, n_layers, seq, hdim = out.shape
        # Move the layer axis next to the hidden axis, then merge the two.
        return out.permute(0, 2, 1, 3).reshape(B, seq, n_layers * hdim)
# ---------------------------------------------------------------------------
# Model loading (meta-tensor-safe: from_config + safetensors)
# ---------------------------------------------------------------------------
def _snapshot(model_id: str) -> str:
    """Return the path produced by ``huggingface_hub.snapshot_download(model_id)``."""
    import huggingface_hub
    snapshot_dir = huggingface_hub.snapshot_download(model_id)
    return snapshot_dir
def _load_safetensors(model: torch.nn.Module, subfolder_dir: str) -> None:
"""Stream safetensors one tensor at a time — avoids doubling peak CPU RAM."""
from safetensors import safe_open
shards = sorted(f for f in os.listdir(subfolder_dir) if f.endswith(".safetensors"))
if not shards:
raise FileNotFoundError(f"No .safetensors in {subfolder_dir}")
params = dict(model.named_parameters())
buffers = dict(model.named_buffers())
loaded: set = set()
for fname in shards:
with safe_open(os.path.join(subfolder_dir, fname), framework="pt", device="cpu") as f:
for key in f.keys():
tensor = f.get_tensor(key)
if key in params:
params[key].data.copy_(tensor)
loaded.add(key)
elif key in buffers:
buffers[key].data.copy_(tensor)
loaded.add(key)
del tensor
missing = [k for k in params if k not in loaded]
if missing:
logger.warning(" missing keys (first 5): %s", missing[:5])
def load_text_encoder(model_id: str, random_weights: bool):
    """
    Load the Qwen3 text encoder and its tokenizer from a model snapshot.

    Returns ``(None, None)`` when ``random_weights`` is true.  Otherwise
    returns ``(model, tokenizer)``: the encoder is loaded in bfloat16 from the
    snapshot's ``text_encoder`` subfolder and put in eval mode, and the
    tokenizer is loaded from the ``tokenizer`` subfolder.
    """
    if not random_weights:
        snapshot_dir = _snapshot(model_id)
        logger.info("Loading Qwen3 text encoder ...")
        encoder_dir = os.path.join(snapshot_dir, "text_encoder")
        model = Qwen3ForCausalLM.from_pretrained(encoder_dir, dtype=torch.bfloat16)
        model.eval()
        tokenizer_dir = os.path.join(snapshot_dir, "tokenizer")
        tokenizer = Qwen2TokenizerFast.from_pretrained(tokenizer_dir)
        return model, tokenizer
    return None, None
def load_transformer(model_id: str, random_weights: bool) -> Flux2Transformer2DModel:
    """
    Build the Flux2 transformer from its config and optionally load weights.

    The model is constructed with bfloat16 as the default dtype.  The previous
    default dtype is restored even if construction raises, so a failure here
    cannot leave the process-wide default changed.

    Args:
        model_id: Model identifier passed to ``load_config`` and ``_snapshot``.
        random_weights: When true, keep the freshly initialised weights;
            otherwise load the safetensors from the snapshot's
            ``transformer`` subfolder.

    Returns:
        The constructed model.
    """
    config = Flux2Transformer2DModel.load_config(model_id, subfolder="transformer")
    _prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)
    try:
        model = Flux2Transformer2DModel.from_config(config)
    finally:
        # The default dtype is global state: always put it back.
        torch.set_default_dtype(_prev_dtype)
    if not random_weights:
        logger.info("Loading transformer weights ...")
        snap = _snapshot(model_id)
        _load_safetensors(model, os.path.join(snap, "transformer"))
    return model
def load_vae(model_id: str, random_weights: bool) -> AutoencoderKLFlux2:
    """
    Build the Flux2 VAE from its config and optionally load weights.

    The model is constructed with bfloat16 as the default dtype.  The previous
    default dtype is restored even if construction raises, so a failure here
    cannot leave the process-wide default changed.

    Args:
        model_id: Model identifier passed to ``load_config`` and ``_snapshot``.
        random_weights: When true, keep the freshly initialised weights;
            otherwise load the safetensors from the snapshot's ``vae``
            subfolder.

    Returns:
        The constructed VAE in eval mode.
    """
    config = AutoencoderKLFlux2.load_config(model_id, subfolder="vae")
    _prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)
    try:
        vae = AutoencoderKLFlux2.from_config(config)
    finally:
        # The default dtype is global state: always put it back.
        torch.set_default_dtype(_prev_dtype)
    if not random_weights:
        logger.info("Loading VAE weights ...")
        snap = _snapshot(model_id)
        _load_safetensors(vae, os.path.join(snap, "vae"))
    vae.eval()
    return vae
# ---------------------------------------------------------------------------
# Main run
# ---------------------------------------------------------------------------
def run(mode, model_id, prompt, height, width, num_steps,
batch_size, random_weights, output_path, seed, fuse_qkv=False):
assert mode in ("eager", "compile"), f"--mode must be 'eager' or 'compile', got {mode!r}"
dist.init_process_group(backend="neuron")
world_size = dist.get_world_size()
rank = dist.get_rank()
device = torch.neuron.current_device()
logger.info(f"Rank {rank}/{world_size} on device {device} "
f"mode={mode} ({height}x{width}, {num_steps} steps, random={random_weights})")
tp_mesh = DeviceMesh("neuron", list(range(world_size)))
joint_attention_dim = Flux2Transformer2DModel.load_config(
model_id, subfolder="transformer")["joint_attention_dim"]
text_seq_len = 512
# ------------------------------------------------------------------
# 1. Text encoder: all ranks load, TP, move to Neuron, encode, free.
#
# Memory ordering: text enc CPU copy freed before transformer loads
# so the two large CPU allocations never overlap:
# peak(text enc load) = tp × enc_size (e.g. 4×16 GB for 9B)
# peak(xfmr load) = tp × xfmr_size (e.g. 4×17 GB for 9B)
# ------------------------------------------------------------------
if not random_weights:
t0 = time.time()
text_encoder, tokenizer = load_text_encoder(model_id, random_weights)
logger.info(f"Rank {rank}: text encoder loaded in {time.time()-t0:.1f}s "
f"({sum(p.numel() for p in text_encoder.parameters())/1e9:.2f}B params)")
text_encoder = apply_tp_text_encoder(text_encoder, tp_mesh)
text_encoder = text_encoder.to(device)
text_encoder.eval()
if mode == "compile":
set_model_name(f"qwen3_text_encoder_rank{rank}")
# KNOWN ISSUE (transformers >= 4.52 / diffusers 0.37.0.dev, 2026-03-19):
# torch.compile(fullgraph=True) fails with:
# torch._dynamo.exc.Unsupported: Unsupported context manager
# Root cause: transformers.utils.output_capturing.maybe_install_capturing_hooks()
# uses a threading.Lock (_hook_installation_lock) that Dynamo cannot trace.
# The function has an early-return guard:
# if getattr(model, "_output_capturing_hooks_installed", False): return
# Fix: pre-install the hooks so the guard fires before the lock is reached.
from transformers.utils.output_capturing import install_all_output_capturing_hooks
install_all_output_capturing_hooks(text_encoder)
text_encoder = torch.compile(text_encoder, backend="neuron", fullgraph=True)
logger.info(f"Rank {rank}: text encoder compiled (NEFF will build on first call)")
gc.collect()
t0 = time.time()
prompt_embeds_dev = _encode_prompt_tp(
text_encoder, tokenizer, prompt, batch_size, device)
if rank == 0:
logger.info(f"Prompt encoded in {time.time()-t0:.1f}s "
f"shape={prompt_embeds_dev.shape}")
del text_encoder, tokenizer
gc.collect()
else:
# --random-weights: rank 0 generates random embeds, broadcasts to all
prompt_embeds_dev = torch.zeros(
batch_size, text_seq_len, joint_attention_dim,
dtype=torch.bfloat16, device=device)
if rank == 0:
prompt_embeds_dev.copy_(
torch.randn(batch_size, text_seq_len, joint_attention_dim,
dtype=torch.bfloat16).to(device))
dist.broadcast(prompt_embeds_dev, src=0)
# ------------------------------------------------------------------
# 2. Transformer: all ranks load, TP, move to Neuron.
# Text encoder CPU copy is already freed; peak RAM = 4 × xfmr_size.
# ------------------------------------------------------------------
t0 = time.time()
transformer = load_transformer(model_id, random_weights)
logger.info(f"Rank {rank}: transformer loaded in {time.time()-t0:.1f}s "
f"({sum(p.numel() for p in transformer.parameters())/1e9:.2f}B params)")
transformer = apply_tp_flux2_transformer(transformer, tp_mesh, fuse_qkv=fuse_qkv)
transformer = transformer.to(device)
transformer.eval()
if mode == "compile":
set_model_name(f"flux2_transformer_rank{rank}")
transformer = torch.compile(transformer, backend="neuron", fullgraph=True)
logger.info(f"Rank {rank}: transformer compiled (NEFF will build on first call)")
gc.collect()
# ------------------------------------------------------------------
# 3. VAE / scheduler / pipeline helpers: rank-0 only.
#
# Pipeline is created BEFORE torch.compile(vae) because
# Flux2KleinPipeline.__init__ calls len(vae.config.block_out_channels)
# which fails on an OptimizedModule wrapper.
# ------------------------------------------------------------------
pipe = None
if rank == 0:
snap = _snapshot(model_id)
vae = load_vae(model_id, random_weights)
vae = vae.to(device)
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
os.path.join(snap, "scheduler"))
pipe = Flux2KleinPipeline(
scheduler=scheduler,
vae=vae,
text_encoder=None,
tokenizer=None,
transformer=transformer,
is_distilled=True,
)
if mode == "compile":
set_model_name("flux2_vae_rank0")
vae = torch.compile(vae, backend="neuron", fullgraph=True)
logger.info("Rank 0: VAE compiled (NEFF will build on first call)")
# ------------------------------------------------------------------
# 4. Prepare latents and position IDs (all ranks, same seed — no broadcast)
# ------------------------------------------------------------------
generator = torch.Generator().manual_seed(seed)
num_latent_ch = transformer.config.in_channels // 4
_vae_tmp = vae if rank == 0 else None
if _vae_tmp is None:
vae_scale = 8
lh = 2 * (height // (vae_scale * 2))
lw = 2 * (width // (vae_scale * 2))
seq_len = (lh // 2) * (lw // 2)
latents_cpu = torch.randn(
batch_size, seq_len, transformer.config.in_channels,
dtype=torch.bfloat16, generator=generator)
ids = torch.cartesian_prod(
torch.arange(1), torch.arange(lh // 2),
torch.arange(lw // 2), torch.arange(1))
latent_ids_cpu = ids.unsqueeze(0).expand(batch_size, -1, -1).contiguous().float()
else:
latents_cpu, latent_ids_cpu_int = pipe.prepare_latents(
batch_size=batch_size,
num_latents_channels=num_latent_ch,
height=height,
width=width,
dtype=torch.bfloat16,
device="cpu",
generator=generator,
)
latent_ids_cpu = latent_ids_cpu_int.float()
t_ids = torch.cartesian_prod(
torch.arange(1), torch.arange(1), torch.arange(1), torch.arange(text_seq_len))
text_ids_cpu = t_ids.unsqueeze(0).expand(batch_size, -1, -1).contiguous().float()
latents_dev = latents_cpu.to(device)
latent_ids_dev = latent_ids_cpu.to(device)
text_ids_dev = text_ids_cpu.to(device)
# ------------------------------------------------------------------
# 5. Scheduler timesteps
# ------------------------------------------------------------------
if rank == 0:
image_seq_len = latents_dev.shape[1]
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_steps)
scheduler.set_timesteps(num_steps, mu=mu)
timesteps = scheduler.timesteps
else:
scheduler = FlowMatchEulerDiscreteScheduler()
if rank == 0:
ts_tensor = timesteps.float()
else:
ts_tensor = torch.zeros(num_steps, dtype=torch.float32)
ts_dev = ts_tensor.to(device)
dist.broadcast(ts_dev, src=0)
ts_tensor = ts_dev
# ------------------------------------------------------------------
# 6. Denoising loop (all ranks via TP transformer)
# eager: step 1 triggers lazy-XLA compile
# compile: step 1 triggers Dynamo trace + NEFF compilation
# ------------------------------------------------------------------
dist.barrier()
compile_note = " (step 1 includes torch.compile NEFF build)" if mode == "compile" else ""
logger.info(f"Rank {rank}: starting {num_steps}-step denoising loop{compile_note} ...")
t_total = time.time()
with torch.no_grad():
for step_idx in range(num_steps):
t_step = time.time()
t_val = ts_tensor[step_idx]
timestep = t_val.expand(batch_size).to(torch.bfloat16).to(device) / 1000.0
noise_pred = transformer(
hidden_states=latents_dev,
encoder_hidden_states=prompt_embeds_dev,
timestep=timestep,
img_ids=latent_ids_dev,
txt_ids=text_ids_dev,
guidance=None,
return_dict=False,
)[0]
if rank == 0:
np_cpu = noise_pred.to("cpu")
lat_cpu = latents_dev.to("cpu")
t_cpu = t_val.cpu()
lat_new = scheduler.step(np_cpu, t_cpu, lat_cpu, return_dict=False)[0]
latents_dev.copy_(lat_new.to(device))
dist.broadcast(latents_dev, src=0)
if rank == 0:
elapsed = time.time() - t_step
np_f = noise_pred.float()
logger.info(f" step {step_idx+1}/{num_steps} "
f"t={t_val.item():.1f} "
f"mean={np_f.mean().item():.4f} "
f"std={np_f.std().item():.4f} "
f"elapsed={elapsed:.3f}s")
if rank == 0:
logger.info(f"Denoising complete: {time.time()-t_total:.2f}s total")
# ------------------------------------------------------------------
# 7. VAE decode (rank-0 only)
# compile mode: vae is an OptimizedModule; access config/buffers via
# vae._orig_mod to bypass the Dynamo wrapper.
# ------------------------------------------------------------------
if rank == 0:
vae_note = " (step 1 includes VAE NEFF build)" if mode == "compile" else ""
logger.info(f"Decoding latents with VAE{vae_note} ...")
latents_spatial = Flux2KleinPipeline._unpack_latents_with_ids(
latents_dev, latent_ids_dev)
vae_mod = vae._orig_mod if mode == "compile" else vae
bn_mean = vae_mod.bn.running_mean.view(1, -1, 1, 1).to(latents_spatial.dtype)
bn_std = torch.sqrt(
vae_mod.bn.running_var.view(1, -1, 1, 1) + vae_mod.config.batch_norm_eps
).to(latents_spatial.dtype)
latents_spatial = latents_spatial * bn_std + bn_mean
latents_spatial = Flux2KleinPipeline._unpatchify_latents(latents_spatial)
with torch.no_grad():
image_tensor = vae.decode(latents_spatial, return_dict=False)[0]
image_np = image_tensor.clamp(-1.0, 1.0).add(1.0).div(2.0)
image_np = (image_np * 255).byte().permute(0, 2, 3, 1).cpu().numpy()
try:
from PIL import Image
for i, arr in enumerate(image_np):
fname = output_path if batch_size == 1 \
else output_path.replace(".", f"_{i}.", 1)
Image.fromarray(arr).save(fname)
logger.info(f"Saved: {fname} (shape {arr.shape})")
except ImportError:
logger.warning("Pillow not installed — skipping image save.")
logger.info(f"output range: [{image_tensor.min():.3f}, {image_tensor.max():.3f}]")
logger.info("PASS: end-to-end pipeline completed")
dist.barrier()
dist.destroy_process_group()
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args(argv=None):
    """Parse command-line options for the Flux2-klein Neuron pipeline.

    Args:
        argv: Optional sequence of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what a
            plain ``parse_args()`` call has always done.

    Returns:
        argparse.Namespace with the attributes ``mode``, ``model_id``,
        ``prompt``, ``height``, ``width``, ``num_steps``, ``batch_size``,
        ``random_weights``, ``output``, ``seed``, ``fused_qkv`` and
        ``cache_dir``.

    Raises:
        SystemExit: raised by argparse on unknown options, invalid values,
            or ``-h/--help``.
    """
    p = argparse.ArgumentParser(
        description="Flux2-klein pipeline (Qwen3 + TP transformer + VAE) on Neuron"
    )
    p.add_argument("--mode", choices=["eager", "compile"], default="eager",
                   help="eager: lazy-XLA path (default). compile: torch.compile Dynamo path.")
    p.add_argument("--model-id", default=DEFAULT_MODEL_ID)
    p.add_argument("--prompt", default="a photograph of a cat sitting on a Neuron chip, photorealistic")
    p.add_argument("--height", type=int, default=512)
    p.add_argument("--width", type=int, default=512)
    p.add_argument("--num-steps", type=int, default=4)
    p.add_argument("--batch-size", type=int, default=1)
    # random_weights is True unless --no-random-weights is given; the positive
    # flag is therefore a no-op kept so command lines can state it explicitly.
    p.add_argument("--random-weights", action="store_true", default=True,
                   help="Enabled by default; accepted for explicitness.")
    p.add_argument("--no-random-weights", action="store_false", dest="random_weights",
                   help="Turn off --random-weights (which is on by default).")
    p.add_argument("--output", default="flux2_output.png")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--fused-qkv", action="store_true", default=False,
                   help="Use NKI fused QKV kernel for double-stream blocks (eager mode only). "
                        "Replaces 3 separate ColwisePar matmuls per stream with one nki_qkv call.")
    p.add_argument(
        "--cache-dir",
        default=None,
        help=(
            "Persistent NEFF cache directory (sets TORCH_NEURONX_NEFF_CACHE_DIR). "
            "Applies to both eager and compile modes. "
            "NEFFs are saved on the first run and reloaded on subsequent runs, "
            "skipping neuronx-cc recompilation. "
            # The entry point falls back to the environment variable before
            # the hard-coded path, so the help states that order.
            "Defaults to $TORCH_NEURONX_NEFF_CACHE_DIR if set, otherwise "
            "/tmp/neff_cache (lost on reboot). "
            "Example: --cache-dir /home/ubuntu/neff_cache"
        ),
    )
    # argv=None makes argparse fall back to sys.argv[1:].
    return p.parse_args(argv)
if __name__ == "__main__":
    args = parse_args()

    # Resolve the NEFF cache directory for every mode: an explicit --cache-dir
    # wins, then an already-exported TORCH_NEURONX_NEFF_CACHE_DIR, and finally
    # /tmp/neff_cache (which does not survive a reboot). The result is exported
    # so the eager (lazy-XLA) and compile (Dynamo) paths both see it.
    cache_dir = args.cache_dir
    if not cache_dir:
        cache_dir = os.environ.get("TORCH_NEURONX_NEFF_CACHE_DIR", "/tmp/neff_cache")
    os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir
    os.makedirs(cache_dir, exist_ok=True)
    logger.info(f"NEFF cache dir: {cache_dir}")

    # Map the parsed CLI namespace onto run()'s keyword parameters.
    run_kwargs = {
        "mode": args.mode,
        "model_id": args.model_id,
        "prompt": args.prompt,
        "height": args.height,
        "width": args.width,
        "num_steps": args.num_steps,
        "batch_size": args.batch_size,
        "random_weights": args.random_weights,
        "output_path": args.output,
        "seed": args.seed,
        "fuse_qkv": args.fused_qkv,
    }
    run(**run_kwargs)