#!/usr/bin/env python3
# T5 decoder (no cache) on Neuron – constant shapes, full graph, no Apex
import os
os.environ["USE_FUSED_LAYER_NORM"] = "0" # MUST be before any transformers import
import argparse
import logging
import time
import torch
from transformers import T5Tokenizer, T5Model
import torch_neuronx  # importing registers the Neuron backend used by torch.compile below
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def main():
    parser = argparse.ArgumentParser(description="T5 decoder on Neuron (full graph, no cache)")
    parser.add_argument("--model", default="t5-small")
    parser.add_argument("--seq-len", type=int, default=128, help="Fixed seq length")
    args = parser.parse_args()

    torch.manual_seed(42)
    torch.set_default_dtype(torch.float32)

    tokenizer = T5Tokenizer.from_pretrained(args.model)

    # disable DynamicCache → static shapes, no deepcopy of the config
    model = T5Model.from_pretrained(
        args.model,
        torch_dtype=torch.float32,
        attn_implementation="eager",
        use_cache=False,  # <-- static shapes, no KV cache
    ).eval()

    # constant-shape inputs: pad encoder and decoder inputs to seq_len
    text = "hello"
    enc_tok = tokenizer(
        text, max_length=args.seq_len, padding="max_length", truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        enc_out = model.encoder(input_ids=enc_tok.input_ids).last_hidden_state.detach()
    dec_tok = tokenizer("<pad>", max_length=args.seq_len, padding="max_length", return_tensors="pt")
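
    # Shape check (batch size 1): enc_out is [1, seq_len, d_model], with
    # d_model = 512 for t5-small; dec_tok.input_ids is [1, seq_len].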

    # pre-run eagerly once to lock shapes before compilation
    with torch.no_grad():
        _ = model.decoder(input_ids=dec_tok.input_ids, encoder_hidden_states=enc_out).last_hidden_state

    # compile the decoder forward only (full graph)
    def decode_fn(inp, enc):
        return model.decoder(input_ids=inp, encoder_hidden_states=enc).last_hidden_state

    decode_fn = torch.compile(decode_fn, backend="neuron", fullgraph=True)
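
    # Alternative sketch (assumes the torch_neuronx.trace(func, example_inputs)
    # API): trace the eager function defined above, before the torch.compile
    # reassignment, ahead of time; shapes are likewise frozen at trace time:
    #   decoder_neuron = torch_neuronx.trace(decode_fn, (dec_tok.input_ids, enc_out))
    #   hidden = decoder_neuron(dec_tok.input_ids, enc_out)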

    # warmup: the first call triggers Neuron compilation
    start = time.time()
    with torch.no_grad():
        _ = decode_fn(dec_tok.input_ids, enc_out)
    logger.info("Warmup: %.3f s", time.time() - start)

    # benchmark
    start = time.time()
    with torch.no_grad():
        hidden = decode_fn(dec_tok.input_ids, enc_out)
    logger.info("Run: %.3f s", time.time() - start)
    logger.info("Hidden shape: %s", hidden.shape)  # [B, seq_len, d_model]

if __name__ == "__main__":
    main()
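
# Example invocation (script name hypothetical; assumes an inf2/trn1 instance
# with the AWS Neuron SDK and torch-neuronx installed):
#   python t5_decoder_no_cache.py --model t5-small --seq-len 128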