# torchrun --nproc_per_node=4 test_t5_text_encoder.py
import argparse
import logging
import time

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)
from transformers import AutoTokenizer, T5EncoderModel

import torch_neuronx  # assumed to register the "neuron" backend and device type

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def apply_tp_t5(encoder: torch.nn.Module, tp_mesh: DeviceMesh):
    # The token embedding stays replicated (no plan entry needed); only the
    # dense layers inside each encoder block are sharded.
    for layer in encoder.encoder.block:
        layer_plan = {
            # Self-attention sublayer (layer.0): shard q/k/v column-wise,
            # then the output projection row-wise, gathering back to a
            # replicated activation.
            "layer.0.SelfAttention.q": ColwiseParallel(),
            "layer.0.SelfAttention.k": ColwiseParallel(),
            "layer.0.SelfAttention.v": ColwiseParallel(),
            "layer.0.SelfAttention.o": RowwiseParallel(output_layouts=Replicate()),
            # Feed-forward sublayer (layer.1): T5 v1.1 uses gated GELU, so
            # there are two up-projections (wi_0, wi_1) and one down-projection.
            "layer.1.DenseReluDense.wi_0": ColwiseParallel(),
            "layer.1.DenseReluDense.wi_1": ColwiseParallel(),
            "layer.1.DenseReluDense.wo": RowwiseParallel(output_layouts=Replicate()),
        }
        attn = layer.layer[0].SelfAttention
        # Only the first block carries the relative position bias; shard its
        # embedding over heads so the bias matches the head-sharded scores.
        if attn.has_relative_attention_bias:
            layer_plan["layer.0.SelfAttention.relative_attention_bias"] = ColwiseParallel()
        parallelize_module(layer, tp_mesh, layer_plan)
        # HF T5 reshapes attention activations with n_heads/inner_dim, so the
        # per-rank values must shrink to match the column-sharded q/k/v
        # (n_heads must be divisible by the TP degree; T5-XXL has 64 heads).
        attn.n_heads = attn.n_heads // tp_mesh.size()
        attn.inner_dim = attn.inner_dim // tp_mesh.size()
    return encoder


def main():
    dist.init_process_group(backend="neuron")
    rank = dist.get_rank()
    device = torch.device(f"neuron:{rank}")
    tp_mesh = DeviceMesh("neuron", list(range(dist.get_world_size())))

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="google/t5-v1_1-xxl")
    parser.add_argument("--seq-len", type=int, default=512)
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model)

    # Create the model on CPU with real (materialized) tensors, then shard.
    with torch.device("cpu"):
        encoder = T5EncoderModel.from_pretrained(
            args.model, attn_implementation="eager"
        ).eval()

    encoder = apply_tp_t5(encoder, tp_mesh)

    # Move local shards (and any remaining replicated params) to the Neuron
    # device; writing to DTensor._local_tensor is a workaround to avoid
    # re-wrapping each parameter.
    for p in encoder.parameters():
        if isinstance(p, DTensor):
            p._local_tensor = p._local_tensor.to(device)
        else:
            p.data = p.data.to(device)

    encoder = torch.compile(encoder, backend="neuron", fullgraph=False)

    text = ["a photo of a cat"]
    txt_in = tokenizer(
        text,
        max_length=args.seq_len,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = txt_in.input_ids.to(device)

    with torch.no_grad():
        _ = encoder(input_ids)  # warm-up: triggers compilation
        t0 = time.time()
        out = encoder(input_ids).last_hidden_state
    logger.info(
        "Rank %d T5-XXL enc latency: %.3f ms shape: %s",
        rank, (time.time() - t0) * 1000, out.shape,  # [1, seq_len, 4096]
    )

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
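

# Optional correctness check (a minimal sketch, not part of the original
# script): compares the tensor-parallel output on one rank against a plain
# single-process CPU reference. Assumes the checkpoint fits in host memory;
# `encoder_tp`, `input_ids`, and `model_name` are illustrative names.
def check_against_cpu_reference(encoder_tp, input_ids, model_name):
    ref = T5EncoderModel.from_pretrained(model_name).eval()
    with torch.no_grad():
        expected = ref(input_ids.cpu()).last_hidden_state
        actual = encoder_tp(input_ids).last_hidden_state.cpu()
    # Loose tolerance: sharded matmuls accumulate in a different order than
    # their dense equivalents, so bit-exact agreement is not expected.
    torch.testing.assert_close(actual, expected, rtol=2e-2, atol=2e-2)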