|
|
|
|
|
|
|
|
import argparse |
|
|
import logging |
|
|
import time |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import torch_neuronx |
|
|
|
|
|
# Configure root logging at INFO so the timing lines emitted below are visible
# when this file is run as a script.
logging.basicConfig(level=logging.INFO)


# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
@torch.no_grad()
def greedy_generate(model_forward, tokenizer, input_ids, max_new_tokens):
    """Fixed-shape greedy decoding loop for a compiled forward pass.

    Every step runs a full forward over a sliding window of exactly ``seq_len``
    tokens, so the compiled graph only ever sees one input shape (required for
    a static Neuron compilation with ``use_cache=False``).

    Args:
        model_forward: Callable ``(input_ids, position_ids) -> outputs`` where
            ``outputs[0]`` is a logits tensor of shape ``(B, seq_len, vocab)``.
        tokenizer: Tokenizer whose ``eos_token_id`` (if set) stops generation
            early once every sequence has emitted it.
        input_ids: Long tensor of shape ``(B, seq_len)``.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        Long tensor of shape ``(B, seq_len)``: the final window, i.e. the last
        ``seq_len`` tokens of prompt + generation.
    """
    B, seq_len = input_ids.shape
    device = input_ids.device
    # Positions are always 0..seq_len-1: the window is re-encoded from scratch
    # each step, so positions are relative to the window, not the full stream.
    position_ids = torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(0).expand(B, -1)
    # getattr keeps this working with tokenizer-like objects lacking the attr.
    eos_id = getattr(tokenizer, "eos_token_id", None)

    for _ in range(max_new_tokens):
        logits = model_forward(input_ids, position_ids)[0]
        # Greedy pick at the last window position (assumes the prompt is
        # LEFT-padded so this position holds a real token, not PAD).
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        # Append and slide: keep the shape (B, seq_len) constant for the
        # compiled graph by dropping the oldest token.
        input_ids = torch.cat([input_ids, next_id], dim=1)[:, -seq_len:]
        # Stop early once every sequence has produced EOS; previously the
        # tokenizer argument was unused and the loop always ran to the limit.
        if eos_id is not None and bool((next_id == eos_id).all()):
            break

    return input_ids
|
|
|
|
|
|
|
|
def main():
    """Compile Phi-3-mini's forward pass for Neuron and run a manual greedy loop."""
    parser = argparse.ArgumentParser(description="Phi-3-mini forward-compile + manual greedy on Neuron")
    parser.add_argument("--model", default="microsoft/Phi-3-mini-4k-instruct")
    parser.add_argument("--seq-len", type=int, default=128, help="Fixed context length")
    parser.add_argument("--new-tokens", type=int, default=20, help="Tokens to generate")
    args = parser.parse_args()

    torch.manual_seed(42)
    torch.set_default_dtype(torch.float32)

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    # LEFT-pad the fixed-length window. With the default right padding, the
    # window ends in PAD tokens, so greedy_generate's logits[:, -1, :] reads
    # logits at a PAD position and decoding emits garbage (see the pasted run
    # transcript at the bottom of this file).
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        dtype=torch.float32,  # `torch_dtype` is deprecated (warning visible in the transcript)
        attn_implementation="eager",  # keep attention traceable for fullgraph compile
        use_cache=False,  # static-shape forward each step; no KV cache
        trust_remote_code=True,  # match the tokenizer load above
    ).eval()

    prompt = "The future of AI is"
    inputs = tokenizer(prompt, max_length=args.seq_len, padding="max_length", truncation=True, return_tensors="pt")
    input_ids = inputs.input_ids
    B, seq_len = input_ids.shape

    # Eager warmup with the exact positional calling convention used later,
    # then swap in the Neuron-compiled forward.
    with torch.no_grad():
        position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).expand(B, -1)
        _ = model(input_ids, position_ids)
    model.forward = torch.compile(model.forward, backend="neuron", fullgraph=True)

    # First compiled call triggers Neuron compilation; time it separately.
    start = time.time()
    with torch.no_grad():
        _ = model(input_ids, position_ids)
    logger.info("Warmup (forward): %.3f s", time.time() - start)

    start = time.time()
    final_ids = greedy_generate(model.forward, tokenizer, input_ids, args.new_tokens)
    logger.info("Generate (manual loop): %.3f s", time.time() - start)

    # skip_special_tokens drops the left PAD tokens from the decoded window.
    text = tokenizer.decode(final_ids[0], skip_special_tokens=True)
    logger.info("Output: %s", text)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return (None -> exit status 0) as the process exit code.
    raise SystemExit(main())
|
|
|
|
|
""" |
|
|
/usr/local/lib/python3.10/site-packages/torch_mlir/dialects/stablehlo/__init__.py:24: UserWarning: Could not import StableHLO C++ extension: libStablehloUnifiedPythonCAPI.so.22.0git: cannot open shared object file: No such file or directory |
|
|
warnings.warn(f"Could not import StableHLO C++ extension: {e}") |
|
|
`torch_dtype` is deprecated! Use `dtype` instead! |
|
|
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.90it/s]
|
|
INFO:__main__:Warmup (forward): 19.975 s |
|
|
INFO:__main__:Generate (manual loop): 271.678 s |
|
|
INFO:__main__:Output: The future of AI is |
|
|
: 1iewer |
|
|
I'melissa' |
|
|
""" |