# Scrape residue from a repository listing ("Ubuntu" branch, "tests" directory, commit 5ee43e9)
# torchrun --nproc_per_node=4 test_t5_text_encoder.py
import argparse
import logging
import os
import time

import torch
import torch.distributed as dist
import torch_neuronx
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    parallelize_module,
)
from torchtitan.models.t5 import T5Encoder  # or transformers T5EncoderModel
from transformers import AutoTokenizer, T5EncoderModel
# Module-level logger; INFO level so each rank's latency line is emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def apply_tp_t5(encoder: torch.nn.Module, tp_mesh: DeviceMesh) -> torch.nn.Module:
    """Shard every T5 encoder block across *tp_mesh* with tensor parallelism.

    Q/K/V projections are sharded column-wise; the attention output
    projection and the FFN output projection are sharded row-wise and
    all-reduce back to a replicated activation, so each block's output
    is replicated and the blocks compose without extra communication.

    Args:
        encoder: a HuggingFace ``T5EncoderModel`` (weights materialized).
        tp_mesh: 1-D DeviceMesh over the tensor-parallel ranks.

    Returns:
        The same ``encoder``, parallelized in place.
    """
    # embed_tokens stays replicated. parallelize_module only accepts
    # ParallelStyle values, so modules left alone are simply omitted from
    # the plan — the previous top-level call passed ``None`` plan values,
    # which is not a valid style.
    for block in encoder.encoder.block:
        layer_plan = {
            # Self-attention: shard heads column-wise, all-reduce on the
            # output projection.
            "layer.0.SelfAttention.q": ColwiseParallel(),
            "layer.0.SelfAttention.k": ColwiseParallel(),
            "layer.0.SelfAttention.v": ColwiseParallel(),
            "layer.0.SelfAttention.o": RowwiseParallel(output_layouts=Replicate()),
            # Feed-forward: HF T5 v1.1 (the script's default checkpoint) uses
            # gated GELU with two input projections (wi_0, wi_1) and one
            # output projection (wo) under layer.1.DenseReluDense. The old
            # plan targeted "layer.0.dense"/"layer.1.dense", which do not
            # exist in HF's T5 module tree.
            "layer.1.DenseReluDense.wi_0": ColwiseParallel(),
            "layer.1.DenseReluDense.wi_1": ColwiseParallel(),
            "layer.1.DenseReluDense.wo": RowwiseParallel(output_layouts=Replicate()),
        }
        # NOTE(review): block 0 also owns relative_attention_bias, whose
        # per-head bias must stay consistent with the column-sharded heads —
        # TODO confirm whether it needs an explicit plan entry here.
        parallelize_module(block, tp_mesh, layer_plan)
    return encoder
def main():
    """Run one tensor-parallel T5 encoder forward pass and log latency.

    Intended to be launched with torchrun (see the comment at the top of
    the file); each rank initializes the "neuron" process group, shards
    the encoder across all ranks, and times a single forward pass.
    """
    dist.init_process_group(backend="neuron")
    rank = dist.get_rank()
    device = torch.device(f"neuron:{rank}")
    tp_mesh = DeviceMesh("neuron", list(range(dist.get_world_size())))
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="google/t5-v1_1-xxl")
    parser.add_argument("--seq-len", type=int, default=512)
    args = parser.parse_args()
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    # Materialize real weights on CPU; TP sharding then replaces the
    # planned parameters with DTensors holding only the local shard.
    with torch.device("cpu"):
        encoder = T5EncoderModel.from_pretrained(
            args.model, attn_implementation="eager"
        ).eval()
    encoder = apply_tp_t5(encoder, tp_mesh)
    # Move each parameter's local data to the Neuron device. DTensor was
    # previously referenced here without being imported (NameError); it is
    # now imported at the top of the file.
    for p in encoder.parameters():
        if isinstance(p, DTensor):
            p._local_tensor = p._local_tensor.to(device)
        else:
            p.data = p.data.to(device)
    encoder = torch.compile(encoder, backend="neuron", fullgraph=False)
    text = ["a photo of a cat"]
    txt_in = tokenizer(
        text, max_length=args.seq_len, padding="max_length", return_tensors="pt"
    )
    input_ids = txt_in.input_ids.to(device)
    with torch.no_grad():
        _ = encoder(input_ids)  # warm-up / trigger compilation
        t0 = time.time()
        out = encoder(input_ids).last_hidden_state
        # NOTE(review): no explicit device synchronization around the timed
        # call — if encoder() returns asynchronously the reported latency is
        # understated; confirm against the Neuron runtime's semantics.
        logger.info(
            "Rank %d T5-XXL enc latency: %.3f ms shape: %s",
            rank, (time.time() - t0) * 1000, out.shape,  # [1, seq_len, 4096]
        )
    dist.destroy_process_group()


if __name__ == "__main__":
    main()