# torchrun --nproc_per_node=4 test_t5_text_encoder.py
import time, argparse, logging, torch, torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel, RowwiseParallel, parallelize_module
)
from transformers import T5EncoderModel, AutoTokenizer
import torch_neuronx  # registers the Neuron device/backend with torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

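# Tensor parallelism for the T5 encoder: each attention/FFN linear is split
# across the TP mesh, column-wise on the way in and row-wise on the way out,
# so every sub-layer costs a single all-reduce and activations stay
# replicated between blocks.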
def apply_tp_t5(encoder: torch.nn.Module, tp_mesh: DeviceMesh):
    # embed_tokens and the layer norms are left as plain tensors, so every
    # rank holds a full copy (effectively replicated); only the dense
    # projections below are sharded

    # shard every linear in each block: q/k/v and wi_* column-wise, o/wo row-wise
    for layer in encoder.encoder.block:
        layer_plan = {
            "layer.0.SelfAttention.q": ColwiseParallel(),
            "layer.0.SelfAttention.k": ColwiseParallel(),
            "layer.0.SelfAttention.v": ColwiseParallel(),
            "layer.0.SelfAttention.o": RowwiseParallel(output_layouts=Replicate()),
            # t5-v1_1 uses a gated-gelu FFN: two up-projections (wi_0, wi_1), one down (wo)
            "layer.1.DenseReluDense.wi_0": ColwiseParallel(),
            "layer.1.DenseReluDense.wi_1": ColwiseParallel(),
            "layer.1.DenseReluDense.wo":   RowwiseParallel(output_layouts=Replicate()),
        }
        parallelize_module(layer, tp_mesh, layer_plan)
        # ColwiseParallel emits local shards, so the attention math must use the
        # per-rank head count (assumes n_heads is divisible by the TP degree)
        attn = layer.layer[0].SelfAttention
        attn.n_heads = attn.n_heads // tp_mesh.size()
        attn.inner_dim = attn.inner_dim // tp_mesh.size()
        if attn.has_relative_attention_bias:  # head-shard the position-bias table
            w = attn.relative_attention_bias.weight  # [num_buckets, n_heads]
            shard = w.chunk(tp_mesh.size(), dim=1)[tp_mesh.get_local_rank()]
            attn.relative_attention_bias.weight = torch.nn.Parameter(shard.contiguous())
    return encoder

def main():
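    # device/backend names below assume a torch_neuronx build that registers
    # a "neuron" process-group backend and "neuron:<rank>" devices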
    dist.init_process_group(backend="neuron")
    rank = dist.get_rank()
    device = torch.device(f"neuron:{rank}")
    tp_mesh = DeviceMesh("neuron", list(range(dist.get_world_size())))

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="google/t5-v1_1-xxl")
    parser.add_argument("--seq-len", type=int, default=512)
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    # create model on CPU, real tensors
    with torch.device("cpu"):
        encoder = T5EncoderModel.from_pretrained(args.model, attn_implementation="eager").eval()

    encoder = apply_tp_t5(encoder, tp_mesh)
    # move local shards to Neuron (_local_tensor is private DTensor API, but
    # swapping it in place keeps the DTensor wrapper and sharding intact)
    for p in encoder.parameters():
        if isinstance(p, DTensor):
            p._local_tensor = p._local_tensor.to(device)
        else:
            p.data = p.data.to(device)

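    # per-rank torch.compile; assumes torch_neuronx registers a "neuron"
    # compile backend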
    encoder = torch.compile(encoder, backend="neuron", fullgraph=False)

    text = ["a photo of a cat"]
    txt_in = tokenizer(text, max_length=args.seq_len, padding="max_length",
                       truncation=True, return_tensors="pt")
    input_ids = txt_in.input_ids.to(device)
    attention_mask = txt_in.attention_mask.to(device)  # mask out the pad tokens

    with torch.no_grad():
        _ = encoder(input_ids=input_ids, attention_mask=attention_mask)  # warm-up: triggers compilation
        t0 = time.time()
        out = encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        logger.info("Rank %d T5-XXL enc latency: %.3f ms  shape: %s",
                    rank, (time.time() - t0) * 1000, out.shape)  # [1, seq_len, 4096]

if __name__ == "__main__":
    main()