|
|
|
|
|
import argparse
import logging
import os
import time

import torch
import torch.distributed as dist
import torch_neuronx
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    parallelize_module,
)
from transformers import AutoTokenizer, T5EncoderModel

from torchtitan.models.t5 import T5Encoder
|
|
# Root-logger configuration at import time; acceptable for a standalone launch script.
logging.basicConfig(level=logging.INFO)


logger = logging.getLogger(__name__)
|
|
|
|
|
def apply_tp_t5(encoder: torch.nn.Module, tp_mesh: DeviceMesh):
    """Apply tensor parallelism to a HuggingFace T5 encoder in place.

    Each transformer block's attention Q/K/V projections and feed-forward
    input projections are sharded column-wise across *tp_mesh*; the attention
    output projection and feed-forward output projection are sharded row-wise
    with replicated outputs, so residual adds and LayerNorms downstream see
    full (unsharded) activations.

    Args:
        encoder: A ``transformers.T5EncoderModel``-compatible module exposing
            ``encoder.encoder.block`` with the standard HF T5 block layout.
        tp_mesh: 1-D ``DeviceMesh`` spanning the tensor-parallel ranks.

    Returns:
        The same ``encoder`` instance, parallelized in place.
    """
    for block in encoder.encoder.block:
        # HF T5 block layout: block.layer[0] is self-attention
        # (SelfAttention.{q,k,v,o}); block.layer[1] is the feed-forward,
        # whose submodule is named "DenseReluDense" for BOTH variants:
        # classic T5 (wi/wo) and gated t5-v1_1 (wi_0/wi_1/wo).
        ff = block.layer[1].DenseReluDense

        layer_plan = {
            # Column-parallel QKV: each rank holds a slice of the heads.
            "layer.0.SelfAttention.q": ColwiseParallel(),
            "layer.0.SelfAttention.k": ColwiseParallel(),
            "layer.0.SelfAttention.v": ColwiseParallel(),
            # Row-parallel output projection; replicate so the residual
            # connection outside the plan operates on full tensors.
            "layer.0.SelfAttention.o": RowwiseParallel(output_layouts=Replicate()),
        }

        if hasattr(ff, "wi_0"):
            # Gated-activation variant (t5-v1_1 / DenseGatedActDense).
            layer_plan["layer.1.DenseReluDense.wi_0"] = ColwiseParallel()
            layer_plan["layer.1.DenseReluDense.wi_1"] = ColwiseParallel()
        else:
            # Classic T5 feed-forward (DenseActDense).
            layer_plan["layer.1.DenseReluDense.wi"] = ColwiseParallel()
        layer_plan["layer.1.DenseReluDense.wo"] = RowwiseParallel(
            output_layouts=Replicate()
        )

        parallelize_module(block, tp_mesh, layer_plan)
    return encoder
|
|
|
|
|
def main():
    """Run one tensor-parallel T5 encoder forward pass on Neuron and log latency.

    Initializes the process group, builds a 1-D tensor-parallel mesh over all
    ranks, loads and shards the encoder, compiles it for Neuron, then times a
    single forward pass after one warm-up call.
    """
    # NOTE(review): backend name "neuron" — torch-neuronx deployments commonly
    # register an XLA-based backend; confirm this string against the installed
    # torch_neuronx version.
    dist.init_process_group(backend="neuron")
    rank = dist.get_rank()
    device = torch.device(f"neuron:{rank}")
    # Single flat mesh: pure tensor parallelism across every rank.
    tp_mesh = DeviceMesh("neuron", list(range(dist.get_world_size())))

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="google/t5-v1_1-xxl")
    parser.add_argument("--seq-len", type=int, default=512)
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model)

    # Materialize full weights on CPU first; shards are moved per-rank below.
    with torch.device("cpu"):
        encoder = T5EncoderModel.from_pretrained(
            args.model, attn_implementation="eager"
        ).eval()

    encoder = apply_tp_t5(encoder, tp_mesh)

    # Move parameters to the local Neuron device. Sharded parameters are
    # DTensors whose local shard must be moved explicitly; anything left
    # unsharded moves as a plain tensor.
    for p in encoder.parameters():
        if isinstance(p, DTensor):
            p._local_tensor = p._local_tensor.to(device)
        else:
            p.data = p.data.to(device)

    encoder = torch.compile(encoder, backend="neuron", fullgraph=False)

    text = ["a photo of a cat"]
    txt_in = tokenizer(
        text, max_length=args.seq_len, padding="max_length", return_tensors="pt"
    )
    input_ids = txt_in.input_ids.to(device)

    with torch.no_grad():
        _ = encoder(input_ids)  # warm-up: triggers Neuron compilation
        t0 = time.time()
        out = encoder(input_ids).last_hidden_state
    logger.info(
        "Rank %d T5-XXL enc latency: %.3f ms shape: %s",
        rank, (time.time() - t0) * 1000, out.shape,
    )

    dist.destroy_process_group()
|
|
|
|
|
# Entry-point guard: keeps the module importable without side effects.
if __name__ == "__main__":


    main()