|
|
|
|
|
|
|
|
import argparse |
|
|
import logging |
|
|
import time |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import torch_neuronx |
|
|
|
|
|
# Module logger used for timing/progress messages; the root logger is set up
# at INFO so those messages are emitted when run as a script.
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")
|
|
|
|
|
|
|
|
@torch.no_grad()
def greedy_generate(model_forward, tokenizer, input_ids, max_new_tokens):
    """Greedily extend ``input_ids`` with a fixed-shape forward callable.

    Every call to ``model_forward`` receives a window with the same
    ``(B, seq_len)`` shape as ``input_ids``. A compiled static-shape forward is
    therefore reused without recompilation.

    Each row is treated as a prompt followed by a run of trailing pad tokens
    (``tokenizer.pad_token_id``), i.e. right padding. New tokens are written
    into those pad slots one at a time. Next-token logits are read at the last
    *real* token of each row, not at the final (padding) position. This
    assumes causal self-attention, so trailing pads cannot influence the logits
    of earlier positions. Once a row has no pad slots left, its window slides
    left by one token per step and the oldest token is dropped. Rows with no
    trailing pads (unpadded or left-padded) slide from the first step.

    Args:
        model_forward: Callable invoked as
            ``model_forward(input_ids=..., position_ids=...)``. Its first output
            element is indexed as ``logits[batch, position, vocab]``. Keywords
            are used because in Hugging Face causal-LM ``forward`` methods the
            second positional parameter is ``attention_mask``.
        tokenizer: Only its ``pad_token_id`` is read. It may be ``None`` or have
            no pad id, in which case every position counts as real context.
        input_ids: ``(B, seq_len)`` integer tensor. It is not modified.
        max_new_tokens: Number of greedy steps to run.

    Returns:
        A ``(B, seq_len)`` tensor holding the final context window. Pad slots
        that were never filled remain pad tokens.

    Note:
        A prompt whose last real token equals the pad id (e.g. ``pad == eos``)
        has that token counted as padding.
    """
    batch, seq_len = input_ids.shape
    device = input_ids.device
    position_ids = (
        torch.arange(seq_len, dtype=torch.long, device=device)
        .unsqueeze(0)
        .expand(batch, -1)
    )

    # Private copy: the window is updated in place below.
    window = input_ids.clone()

    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        lengths = torch.full((batch,), seq_len, dtype=torch.long, device=device)
    else:
        # 1-based index of the last non-pad token per row (0 if all padding).
        slot = torch.arange(1, seq_len + 1, device=device)
        lengths = ((window != pad_id).long() * slot).amax(dim=1)
    # An all-pad row has no real token; predict from its first slot so the
    # logits index below stays in range.
    lengths = lengths.clamp(min=1)

    rows = torch.arange(batch, device=device)
    for _ in range(max_new_tokens):
        logits = model_forward(input_ids=window, position_ids=position_ids)[0]
        next_id = logits[rows, lengths - 1].argmax(dim=-1)
        # index_put / cat need matching dtypes with the token window.
        next_id = next_id.to(device=device, dtype=window.dtype)

        full = lengths >= seq_len
        if full.any():
            # No free slot left: slide these rows left by one token.
            window[full] = torch.cat(
                [window[full, 1:], next_id[full].unsqueeze(1)], dim=1
            )
        grow = ~full
        if grow.any():
            # Fill the first pad slot after the current content.
            window[rows[grow], lengths[grow]] = next_id[grow]
            lengths = lengths + grow.long()

    return window
|
|
|
|
|
|
|
|
def main():
    """Compile a causal LM's forward for Neuron and run a manual greedy decode.

    Steps:
    1. Load ``--model`` in float32 with eager attention and ``use_cache=False``.
    2. Pad a fixed prompt to ``--seq-len`` tokens.
    3. Run one eager forward, then wrap ``model.forward`` with
       ``torch.compile(backend="neuron", fullgraph=True)``.
    4. Time a warmup call through the compiled forward.
    5. Time ``greedy_generate`` for ``--new-tokens`` steps and log the decoded
       output.
    """
    parser = argparse.ArgumentParser(description="Phi forward-compile + manual greedy on Neuron")
    parser.add_argument("--model", default="microsoft/phi-2")
    parser.add_argument("--seq-len", type=int, default=128, help="Fixed context length")
    parser.add_argument("--new-tokens", type=int, default=20, help="Tokens to generate")
    args = parser.parse_args()

    torch.manual_seed(42)
    torch.set_default_dtype(torch.float32)

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Right padding puts the prompt at the start of the fixed-length window,
    # so the position ids 0..seq_len-1 built below line up with its tokens.
    tokenizer.padding_side = "right"

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.float32,
        attn_implementation="eager",
        use_cache=False,
    ).eval()

    prompt = "The future of AI is"
    inputs = tokenizer(prompt, max_length=args.seq_len, padding="max_length", truncation=True, return_tensors="pt")
    input_ids = inputs.input_ids
    B, seq_len = input_ids.shape
    position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).expand(B, -1)

    # position_ids is passed by keyword: in HF causal-LM forward() the second
    # positional parameter is attention_mask, not position_ids.
    with torch.no_grad():
        _ = model(input_ids=input_ids, position_ids=position_ids)
        model.forward = torch.compile(model.forward, backend="neuron", fullgraph=True)

    # First call through the compiled forward; this is where compilation runs.
    start = time.perf_counter()
    with torch.no_grad():
        _ = model(input_ids=input_ids, position_ids=position_ids)
    logger.info("Warmup (forward): %.3f s", time.perf_counter() - start)

    start = time.perf_counter()
    final_ids = greedy_generate(model.forward, tokenizer, input_ids, args.new_tokens)
    logger.info("Generate (manual loop): %.3f s", time.perf_counter() - start)

    text = tokenizer.decode(final_ids[0], skip_special_tokens=True)
    logger.info("Output: %s", text)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()