#!/usr/bin/env python3
"""Phi (Phi-2 by default) on Neuron: torch.compile the forward, then greedy-decode manually.

The forward pass is compiled for one fixed context length and re-used for every
decoding step. Generation keeps the context at that length with a rolling
window instead of a KV cache (the model is loaded with ``use_cache=False``).
"""
import argparse
import logging
import time

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _position_ids_from_mask(attention_mask):
    """Return position ids that count only attended (mask == 1) tokens.

    Uses the cumsum convention for left-padded batches: the first real token
    gets position 0, and pad positions get the dummy value 1 (they are masked
    out anyway). With an all-ones mask this equals ``arange(seq_len)``.
    """
    position_ids = attention_mask.long().cumsum(dim=-1) - 1
    return position_ids.masked_fill(attention_mask == 0, 1)


@torch.no_grad()
def greedy_generate(model_forward, tokenizer, input_ids, max_new_tokens, attention_mask=None):
    """Greedy-decode ``max_new_tokens`` tokens by calling the compiled forward repeatedly.

    The context width stays fixed at ``input_ids.shape[1]`` so every call sees
    identical tensor shapes. Each new token is appended on the right, and the
    oldest token is dropped on the left (rolling window, no KV cache).

    Args:
        model_forward: Callable accepting ``input_ids``, ``attention_mask`` and
            ``position_ids`` keyword arguments. Its result's element ``[0]``
            must be logits of shape (B, seq_len, vocab).
        tokenizer: Unused; kept so existing call sites keep working.
        input_ids: (B, seq_len) token ids. If the prompt is padded it must be
            left-padded so that position -1 holds the last real token.
        max_new_tokens: Number of greedy steps to run (<= 0 runs none).
        attention_mask: Optional (B, seq_len) mask, 1 for real tokens and 0 for
            padding. ``None`` treats every position as a real token.

    Returns:
        The final (B, seq_len) window of token ids.
    """
    seq_len = input_ids.shape[1]
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    for _ in range(max_new_tokens):
        position_ids = _position_ids_from_mask(attention_mask)
        # Keyword arguments are essential: the second *positional* parameter of
        # a Hugging Face causal-LM forward() is attention_mask, not position_ids.
        logits = model_forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )[0]  # [0] works for both ModelOutput and plain tuples
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        # Rolling window: slide tokens and mask together so padding (on the
        # left) is discarded before any real prompt token.
        input_ids = torch.cat([input_ids, next_id], dim=1)[:, -seq_len:]
        new_mask = torch.ones_like(next_id, dtype=attention_mask.dtype)
        attention_mask = torch.cat([attention_mask, new_mask], dim=1)[:, -seq_len:]
    return input_ids


def main():
    """Parse CLI args, load the model, compile its forward for Neuron, and generate."""
    # Heavy / side-effect imports live here so this module (and greedy_generate)
    # can be imported without transformers or the Neuron SDK installed.
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch_neuronx  # noqa: F401  # side-effect import; presumably registers the "neuron" backend — confirm for your SDK version

    parser = argparse.ArgumentParser(description="Phi forward-compile + manual greedy on Neuron")
    parser.add_argument("--model", default="microsoft/phi-2")
    parser.add_argument("--seq-len", type=int, default=128, help="Fixed context length")
    parser.add_argument("--new-tokens", type=int, default=20, help="Tokens to generate")
    args = parser.parse_args()

    torch.manual_seed(42)
    torch.set_default_dtype(torch.float32)

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    # Phi has no pad_token by default
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Left-pad so the prompt's last real token sits at position -1 (where the
    # next-token logits are read), and keep the most recent tokens if the prompt
    # is longer than the fixed context.
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.float32,
        attn_implementation="eager",
        use_cache=False,  # static shapes
    ).eval()

    prompt = "The future of AI is"
    inputs = tokenizer(
        prompt,
        max_length=args.seq_len,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    position_ids = _position_ids_from_mask(attention_mask)

    # Compile forward only (full graph). torch.compile is lazy: tracing and
    # backend compilation happen on the first call below.
    model.forward = torch.compile(model.forward, backend="neuron", fullgraph=True)

    # warmup
    start = time.perf_counter()
    with torch.no_grad():
        _ = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
    logger.info("Warmup (forward): %.3f s", time.perf_counter() - start)

    # manual greedy generation
    start = time.perf_counter()
    final_ids = greedy_generate(
        model.forward, tokenizer, input_ids, args.new_tokens, attention_mask=attention_mask
    )
    logger.info("Generate (manual loop): %.3f s", time.perf_counter() - start)

    text = tokenizer.decode(final_ids[0], skip_special_tokens=True)
    logger.info("Output: %s", text)


if __name__ == "__main__":
    main()