# torch-neuron-test-samples/torch_compile/flux/test_flux_transformer.py
# torchrun --nproc_per_node=8 test_flux_transformer.py
import argparse
import logging
import time

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel, RowwiseParallel, parallelize_module
)
from diffusers import FluxTransformer2DModel
import torch_neuronx  # registers the "neuron" device and backends used below
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
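
# Megatron-style TP: shard each attention/MLP input projection column-wise and
# each output projection row-wise, so every block needs only one all-reduce
# (from RowwiseParallel(output_layouts=Replicate())) to rebuild full activations.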
def apply_tp_flux(transformer: torch.nn.Module, tp_mesh: DeviceMesh):
    # Shard the double-stream blocks. Module names follow diffusers'
    # FluxTransformerBlock; the single-stream blocks, the context-stream
    # projections (attn.add_{q,k,v}_proj, ff_context) and all norms stay
    # replicated to keep this test simple. Anything not named in a plan
    # is untouched, so the embedders and final norm are replicated too.
    for block in transformer.transformer_blocks:
        plan = {
            # attention: column-shard q/k/v, row-shard the output projection
            "attn.to_q": ColwiseParallel(),
            "attn.to_k": ColwiseParallel(),
            "attn.to_v": ColwiseParallel(),
            "attn.to_out.0": RowwiseParallel(output_layouts=Replicate()),
            # feed-forward: column-shard up-projection, row-shard down-projection
            "ff.net.0.proj": ColwiseParallel(),
            "ff.net.2": RowwiseParallel(output_layouts=Replicate()),
        }
        parallelize_module(block, tp_mesh, plan)
    return transformer
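
# One process per NeuronCore: torchrun sets RANK/WORLD_SIZE, each rank binds to
# neuron:<rank>, and the 1-D mesh puts every rank in a single TP group.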
def main():
    dist.init_process_group(backend="neuron")
    rank = dist.get_rank()
    device = torch.device(f"neuron:{rank}")
    tp_mesh = DeviceMesh("neuron", list(range(dist.get_world_size())))

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="black-forest-labs/FLUX.1-dev")
    parser.add_argument("--seq-len", type=int, default=4096)
    parser.add_argument("--txt-len", type=int, default=512)
    args = parser.parse_args()
    # create on CPU with real (materialized) weights
    with torch.device("cpu"):
        transformer = FluxTransformer2DModel.from_pretrained(
            args.model, subfolder="transformer", torch_dtype=torch.bfloat16
        ).eval()
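
    # parallelize_module (inside apply_tp_flux) rewrites the planned Linear
    # weights as DTensors sharded over tp_mesh, so each rank ends up holding
    # 1/world_size of every sharded matrix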
    transformer = apply_tp_flux(transformer, tp_mesh)

    # move each rank's local shard (or full replica) to its NeuronCore
    for p in transformer.parameters():
        if isinstance(p, DTensor):
            # pokes at DTensor internals; acceptable in a test, not in production
            p._local_tensor = p._local_tensor.to(device, dtype=torch.bfloat16)
        else:
            p.data = p.data.to(device, dtype=torch.bfloat16)
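
    # fullgraph=False lets torch.compile fall back to eager at graph breaks
    # instead of raising, keeping the test alive where backend coverage is thin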
    transformer = torch.compile(transformer, backend="neuron", fullgraph=False)
    # dummy inputs; shapes come from the model config (Flux takes packed
    # latents of cfg.in_channels, not the inner dim) rather than hard-coded dims
    batch, cfg = 1, transformer.config
    kw = dict(dtype=torch.bfloat16, device=device)
    hidden = torch.randn(batch, args.seq_len, cfg.in_channels, **kw)
    encoder_hidden = torch.randn(batch, args.txt_len, cfg.joint_attention_dim, **kw)
    pooled = torch.randn(batch, cfg.pooled_projection_dim, **kw)
    img_ids = torch.zeros(args.seq_len, 3, **kw)
    txt_ids = torch.zeros(args.txt_len, 3, **kw)
    timestep = torch.tensor([0.5], **kw)  # forward scales timestep by 1000
    guidance = torch.full((batch,), 3.5, **kw) if cfg.guidance_embeds else None
    call = dict(hidden_states=hidden, encoder_hidden_states=encoder_hidden,
                pooled_projections=pooled, timestep=timestep,
                img_ids=img_ids, txt_ids=txt_ids, guidance=guidance)
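    # the first forward call triggers torch.compile tracing/compilation on
    # every rank; only the second call measures steady-state latency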
    with torch.no_grad():
        _ = transformer(**call)  # warm-up
        # time.time() assumes the call returns synchronously; insert a device
        # sync here if the neuron backend executes asynchronously
        t0 = time.time()
        out = transformer(**call)
    logger.info("Rank %d Flux-TFM latency: %.3f ms shape: %s",
                rank, (time.time() - t0) * 1000, out.sample.shape)
    dist.destroy_process_group()
if __name__ == "__main__":
    main()