import os
import time
import argparse
import logging

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel, RowwiseParallel, parallelize_module
)
from diffusers import FluxTransformer2DModel
import torch_neuronx  # registers the Neuron integration with PyTorch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
def apply_tp_flux(transformer: torch.nn.Module, tp_mesh: DeviceMesh):
    """Shard the FLUX joint transformer blocks across the TP mesh.

    Modules not named in the plan (x_embedder, norms, norm_out) are left
    untouched, so they stay replicated on every rank; FLUX's
    single_transformer_blocks are also left replicated for brevity.
    Module names follow diffusers' FluxTransformerBlock: separate
    to_q/to_k/to_v projections and ff/ff_context feed-forward stacks.
    """
    for block in transformer.transformer_blocks:
        plan = {
            # image-stream attention: shard heads column-wise, then reduce
            # the output projection row-wise back to a replicated tensor
            "attn.to_q": ColwiseParallel(),
            "attn.to_k": ColwiseParallel(),
            "attn.to_v": ColwiseParallel(),
            "attn.to_out.0": RowwiseParallel(output_layouts=Replicate()),
            # text-stream (context) attention projections
            "attn.add_q_proj": ColwiseParallel(),
            "attn.add_k_proj": ColwiseParallel(),
            "attn.add_v_proj": ColwiseParallel(),
            "attn.to_add_out": RowwiseParallel(output_layouts=Replicate()),
            # feed-forward stacks for both streams
            "ff.net.0.proj": ColwiseParallel(),
            "ff.net.2": RowwiseParallel(output_layouts=Replicate()),
            "ff_context.net.0.proj": ColwiseParallel(),
            "ff_context.net.2": RowwiseParallel(output_layouts=Replicate()),
        }
        parallelize_module(block, tp_mesh, plan)
        # After column-wise sharding each rank sees heads/tp_degree heads,
        # so shrink the per-module head count to match the local tensors.
        block.attn.heads //= tp_mesh.size()
    return transformer
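
# Optional pre-flight check (hypothetical helper, not part of diffusers or
# torch.distributed): ColwiseParallel splits attention heads and FFN columns
# across ranks, so both must divide evenly by the TP degree. Call after
# loading the model and before apply_tp_flux.
def assert_tp_divisible(transformer: torch.nn.Module, tp_degree: int):
    heads = transformer.config.num_attention_heads
    inner_dim = heads * transformer.config.attention_head_dim
    assert heads % tp_degree == 0, \
        f"{heads} attention heads not divisible by TP degree {tp_degree}"
    assert inner_dim % tp_degree == 0, \
        f"inner dim {inner_dim} not divisible by TP degree {tp_degree}"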
|
|
def main():
    dist.init_process_group(backend="neuron")
    rank = dist.get_rank()
    # Device selection uses the per-node rank so multi-node launches work.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    device = torch.device(f"neuron:{local_rank}")
    tp_mesh = DeviceMesh("neuron", list(range(dist.get_world_size())))

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="black-forest-labs/FLUX.1-dev")
    parser.add_argument("--seq-len", type=int, default=4096)
    args = parser.parse_args()
|
|
    # Materialize the full model on CPU first; weights move to the Neuron
    # devices only after the TP plan has sharded them.
    with torch.device("cpu"):
        transformer = FluxTransformer2DModel.from_pretrained(
            args.model, subfolder="transformer", torch_dtype=torch.bfloat16
        ).eval()
|
|
    transformer = apply_tp_flux(transformer, tp_mesh)

    # Move weights to the local device. DTensor parameters hold only this
    # rank's shard in _local_tensor; everything else is fully replicated.
    for p in transformer.parameters():
        if isinstance(p, DTensor):
            p._local_tensor = p._local_tensor.to(device, dtype=torch.bfloat16)
        else:
            p.data = p.data.to(device, dtype=torch.bfloat16)
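
    # Quick sanity log (sketch): each rank should hold roughly 1/tp_degree
    # of the elements of every sharded weight, plus the replicated rest.
    local_numel = sum(
        (p._local_tensor if isinstance(p, DTensor) else p).numel()
        for p in transformer.parameters()
    )
    logger.info("Rank %d holds %.1fM local parameter elements",
                rank, local_numel / 1e6)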
|
|
    # fullgraph=False tolerates graph breaks introduced by DTensor ops.
    transformer = torch.compile(transformer, backend="neuron", fullgraph=False)
|
|
    # Dummy inputs in the shapes FLUX expects: packed image latents carry
    # config.in_channels features per token and text tokens carry T5
    # features (4096); FLUX.1-dev additionally takes pooled CLIP
    # projections (768), RoPE position ids, and a guidance scale.
    batch, text_len = 1, 512
    hidden = torch.randn(batch, args.seq_len, transformer.config.in_channels,
                         dtype=torch.bfloat16, device=device)
    encoder_hidden = torch.randn(batch, text_len, 4096,
                                 dtype=torch.bfloat16, device=device)
    inputs = dict(
        hidden_states=hidden,
        encoder_hidden_states=encoder_hidden,
        pooled_projections=torch.randn(batch, 768, dtype=torch.bfloat16,
                                       device=device),
        img_ids=torch.zeros(args.seq_len, 3, dtype=torch.float32, device=device),
        txt_ids=torch.zeros(text_len, 3, dtype=torch.float32, device=device),
        timestep=torch.tensor([0.5], dtype=torch.bfloat16, device=device),
        guidance=torch.tensor([3.5], dtype=torch.bfloat16, device=device),
    )

    with torch.no_grad():
        _ = transformer(**inputs)  # warm-up pass triggers compilation
        t0 = time.time()
        out = transformer(**inputs)
        logger.info("Rank %d Flux-TFM latency: %.3f ms shape: %s",
                    rank, (time.time() - t0) * 1000, out.sample.shape)
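
        # A steadier estimate (sketch): average several timed runs after
        # warm-up. Assumes each call blocks until results are ready; if the
        # runtime executes asynchronously, add synchronization here.
        iters = 10
        t0 = time.time()
        for _ in range(iters):
            out = transformer(**inputs)
        logger.info("Rank %d avg latency over %d iters: %.3f ms",
                    rank, iters, (time.time() - t0) / iters * 1000)

    dist.destroy_process_group()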
|
|
if __name__ == "__main__":
    main()
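
# Example launch, assuming 8 Neuron cores on the host and that this file is
# saved as flux_tp_bench.py (the filename is illustrative):
#
#   torchrun --nproc_per_node=8 flux_tp_bench.py --seq-len 4096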