#!/usr/bin/env python3
"""Phi-3-mini: compile ``model.forward`` only, manual greedy loop on Neuron.

The model is run without a KV cache on a fixed ``(batch, seq_len)`` input so
the compiled graph sees static shapes.  Generation keeps the shape constant by
using a rolling window: each new token is appended on the right and the
left-most column (a pad token while padding remains) is dropped.
"""
import argparse
import logging
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch_neuronx  # noqa: F401  # guarantees Neuron backend

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _position_ids_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Return position ids that start at 0 on the first attended token of each row.

    Masked (padding) columns are clamped to position 0; they are ignored by
    attention, so the exact value there does not matter.
    """
    return (attention_mask.long().cumsum(dim=-1) - 1).clamp(min=0)


@torch.no_grad()
def greedy_generate(
    model_forward,
    tokenizer,
    input_ids: torch.Tensor,
    max_new_tokens: int,
    attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Greedy-decode ``max_new_tokens`` tokens with a fixed-shape rolling window.

    Args:
        model_forward: Callable accepting the keyword arguments ``input_ids``,
            ``attention_mask`` and ``position_ids`` and returning an object
            whose item ``0`` is the logits tensor indexed as
            ``[batch, position, vocab]``.
        tokenizer: Unused; kept so existing callers keep working.
        input_ids: 2-D tensor of token ids, ``(batch, seq_len)``.  When padded,
            padding must be on the LEFT so the last column is a real token.
        max_new_tokens: Number of decoding steps to run.
        attention_mask: Optional mask with the same shape as ``input_ids``
            (non-zero = attend).  Defaults to attending to every position.

    Returns:
        Tensor of the same shape as ``input_ids`` holding the last ``seq_len``
        tokens after generation.  Once all left padding has been consumed,
        the oldest prompt tokens start to fall out of the window.

    Raises:
        ValueError: If ``attention_mask`` does not match ``input_ids`` in
            shape, or if the last column is masked (right-padded input).
    """
    batch, seq_len = input_ids.shape
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    if attention_mask.shape != input_ids.shape:
        raise ValueError(
            f"attention_mask shape {tuple(attention_mask.shape)} does not match "
            f"input_ids shape {tuple(input_ids.shape)}"
        )
    # The next token is read from the last position, so that position must be
    # a real token for every row.
    if seq_len and not bool(attention_mask[:, -1].ne(0).all()):
        raise ValueError("input must be left-padded: last position is masked out")

    new_mask_column = attention_mask.new_ones((batch, 1))
    for _ in range(max_new_tokens):
        # Recomputed every step: the mask shifts as padding leaves the window.
        position_ids = _position_ids_from_mask(attention_mask)
        # Keyword arguments are required here: the second positional parameter
        # of a Hugging Face causal-LM forward is attention_mask, not
        # position_ids.
        logits = model_forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )[0]
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        # Rolling window: append on the right, drop the left-most column so
        # the shapes seen by the compiled graph never change.
        input_ids = torch.cat([input_ids, next_id], dim=1)[:, -seq_len:]
        attention_mask = torch.cat([attention_mask, new_mask_column], dim=1)[:, -seq_len:]
    return input_ids


def main():
    """Load the model, compile its forward for Neuron, and time a greedy run."""
    parser = argparse.ArgumentParser(description="Phi-3-mini forward-compile + manual greedy on Neuron")
    parser.add_argument("--model", default="microsoft/Phi-3-mini-4k-instruct")
    parser.add_argument("--seq-len", type=int, default=128, help="Fixed context length")
    parser.add_argument("--new-tokens", type=int, default=20, help="Tokens to generate")
    args = parser.parse_args()

    torch.manual_seed(42)
    torch.set_default_dtype(torch.float32)

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    # Left padding keeps the real prompt at the end of the window, which is
    # where greedy_generate reads the next-token logits from.
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.float32,
        attn_implementation="eager",
        use_cache=False,  # static shapes
    ).eval()

    # fixed-shape prompt
    prompt = "The future of AI is"
    inputs = tokenizer(
        prompt,
        max_length=args.seq_len,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    position_ids = _position_ids_from_mask(attention_mask)

    # shape lock & compile forward only (full graph); the call pattern matches
    # the one used inside greedy_generate.
    with torch.no_grad():
        _ = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
    model.forward = torch.compile(model.forward, backend="neuron", fullgraph=True)

    # warmup
    start = time.perf_counter()
    with torch.no_grad():
        _ = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
    logger.info("Warmup (forward): %.3f s", time.perf_counter() - start)

    # manual greedy generation
    start = time.perf_counter()
    final_ids = greedy_generate(
        model.forward,
        tokenizer,
        input_ids,
        args.new_tokens,
        attention_mask=attention_mask,
    )
    logger.info("Generate (manual loop): %.3f s", time.perf_counter() - start)

    text = tokenizer.decode(final_ids[0], skip_special_tokens=True)
    logger.info("Output: %s", text)


if __name__ == "__main__":
    main()


# Console output recorded from a run of the earlier revision of this script,
# which passed position_ids in the attention_mask slot and did not use the
# tokenizer's attention mask. Kept for reference only.
"""
/usr/local/lib/python3.10/site-packages/torch_mlir/dialects/stablehlo/__init__.py:24: UserWarning: Could not import StableHLO C++ extension: libStablehloUnifiedPythonCAPI.so.22.0git: cannot open shared object file: No such file or directory
warnings.warn(f"Could not import StableHLO C++ extension: {e}")
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.90it/s]
INFO:__main__:Warmup (forward): 19.975 s
INFO:__main__:Generate (manual loop): 271.678 s
INFO:__main__:Output: The future of AI is : 1iewer I'melissa'
"""