#!/usr/bin/env python3
# T5 encoder on Neuron – no Apex, full graph, constant shapes
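#
# Usage sketch (assumes an AWS Inf2/Trn1 instance with the Neuron SDK and
# torch-neuronx installed; the filename is illustrative):
#   python t5_encoder_neuron.py --model t5-small --seq-len 128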
import os
os.environ["USE_FUSED_LAYER_NORM"] = "0" # <── disable Apex
import argparse
import logging
import time
import torch
from transformers import T5Tokenizer, T5Model # use T5Model (no LM head)
import torch_neuronx  # noqa: F401 -- import side effect makes the Neuron backend available
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def main():
    parser = argparse.ArgumentParser(description="T5 encoder on Neuron (full graph)")
    parser.add_argument("--model", default="t5-small")
    parser.add_argument("--seq-len", type=int, default=128, help="Fixed sequence length")
    args = parser.parse_args()

    torch.manual_seed(42)
    torch.set_default_dtype(torch.float32)

    tokenizer = T5Tokenizer.from_pretrained(args.model)
    model = T5Model.from_pretrained(
        args.model, torch_dtype=torch.float32, attn_implementation="eager"
    ).eval()

    # Fixed-shape input: pad/truncate to --seq-len so the compiled graph
    # sees a constant shape on every call.
    text = "translate English to French: The cat is on the mat."
    inputs = tokenizer(
        text,
        max_length=args.seq_len,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Pre-run once in eager mode to lock shapes before compilation.
    with torch.no_grad():
        _ = model.encoder(**inputs).last_hidden_state

    # Compile only the encoder forward pass as a single full graph.
    encode_fn = lambda **kw: model.encoder(**kw).last_hidden_state
    encode_fn = torch.compile(encode_fn, backend="neuron", fullgraph=True)

    # Warmup: the first call triggers Neuron compilation, so it is slow.
    start = time.time()
    with torch.no_grad():
        _ = encode_fn(**inputs)
    logger.info("Warmup: %.3f s", time.time() - start)

    # Benchmark: subsequent calls reuse the compiled graph.
    start = time.time()
    with torch.no_grad():
        hidden = encode_fn(**inputs)
    logger.info("Run: %.3f s", time.time() - start)
    logger.info("Hidden shape: %s", hidden.shape)  # [batch, seq_len, d_model]
if __name__ == "__main__":
    main()
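
# Hedged alternative: if this torch-neuronx version does not expose a "neuron"
# torch.compile backend, the same fixed-shape encoder can be compiled ahead of
# time with torch_neuronx.trace. A minimal sketch (positional example inputs,
# reusing `model` and `inputs` from above):
#
#   def encode(input_ids, attention_mask):
#       return model.encoder(
#           input_ids=input_ids, attention_mask=attention_mask
#       ).last_hidden_state
#
#   traced = torch_neuronx.trace(
#       encode, (inputs["input_ids"], inputs["attention_mask"])
#   )
#   hidden = traced(inputs["input_ids"], inputs["attention_mask"])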