"""Compare a metaseq OPT checkpoint with its Hugging Face ``OPTForCausalLM`` port.

The script loads a metaseq checkpoint and a Hugging Face OPT model, feeds both
the same prompts, prints each model's greedy next word and reports whether
their logits agree within a tolerance.
"""

import os

import torch
from metaseq import checkpoint_utils
from transformers import AutoTokenizer, GPT2Tokenizer, OPTForCausalLM

# Directory holding the metaseq checkpoint (``restored.pt``) and the GPT-2 BPE
# files. It can be overridden via the environment; the default is unchanged.
path = os.environ.get("OPT_METASEQ_PATH", "./model")
# Directory of the Hugging Face model to compare against (overridable likewise).
hf_path = os.environ.get("OPT_HF_PATH", "/home/patrick/facebook/opt-125m")

vocab_file = os.path.join(path, "gpt2-vocab.json")
merges_file = os.path.join(path, "gpt2-merges.txt")

# Build the GPT-2 BPE tokenizer from the raw vocab/merges files and save it in
# Hugging Face format into the same directory.
tokenizer = GPT2Tokenizer(vocab_file, merges_file)
tokenizer.save_pretrained(path)

# NOTE(review): the BPE file locations are overridden, presumably because the
# paths stored in the checkpoint's config are not valid here -- confirm.
checkpoint = checkpoint_utils.load_model_ensemble_and_task(
    [os.path.join(path, "restored.pt")],
    arg_overrides={
        "vocab_filename": vocab_file,
        "merges_filename": merges_file,
    },
)

# checkpoint[0] is the loaded ensemble (a sequence of models); use its first
# model, switched to eval mode.
model = checkpoint[0][0].eval()

hf_model = OPTForCausalLM.from_pretrained(hf_path)


def single_batch_forward_logits(prompts):
    """Run ``prompts`` through the metaseq model and return its logits.

    The text is tokenized with the module-level GPT-2 tokenizer, and token id 0
    is prepended to every row. NOTE(review): id 0 is presumably the BOS symbol
    the metaseq model expects, which the GPT-2 tokenizer does not add -- confirm.

    Args:
        prompts: A single prompt string, or a list of prompts that tokenize to
            the same length (the tokenizer call requests no padding).

    Returns:
        ``model(input_ids)[0]``, the logits tensor used for next-token
        prediction. It is computed without gradient tracking.
    """
    input_ids = tokenizer(prompts, return_tensors="pt").input_ids
    # One BOS column per row. new_zeros keeps input_ids' dtype and device and,
    # unlike a fixed [[0]] tensor, also works for more than one row.
    bos_column = input_ids.new_zeros((input_ids.shape[0], 1))
    input_ids = torch.cat([bos_column, input_ids], dim=-1)
    with torch.no_grad():
        logits = model(input_ids)[0]
    return logits


def forward_hf(prompts):
    """Run ``prompts`` through the Hugging Face OPT model and return its logits.

    Inputs are built exactly like the metaseq forward pass: GPT-2 tokenization,
    then token id 0 prepended to every row. This keeps both models' logits
    directly comparable.

    Args:
        prompts: A single prompt string, or a list of prompts that tokenize to
            the same length (the tokenizer call requests no padding).

    Returns:
        ``hf_model(input_ids)[0]``, the logits tensor used for next-token
        prediction. It is computed without gradient tracking.
    """
    input_ids = tokenizer(prompts, return_tensors="pt").input_ids
    # One BOS column per row. new_zeros keeps input_ids' dtype and device and,
    # unlike a fixed [[0]] tensor, also works for more than one row.
    bos_column = input_ids.new_zeros((input_ids.shape[0], 1))
    input_ids = torch.cat([bos_column, input_ids], dim=-1)
    with torch.no_grad():
        logits = hf_model(input_ids)[0]
    return logits


prompts = [
    "Today is a beautiful day and I want to",
    "In the city of",
    "Paris is the capital of France and",
    "Computers and mobile phones have taken",
]


def _predicted_word(logits):
    """Return the greedy (argmax) next token after the last position of row 0.

    The GPT-2 BPE space marker "Ġ" is stripped so the word prints cleanly.
    """
    next_token_id = int(torch.argmax(logits[0, -1], -1))
    next_token = tokenizer.convert_ids_to_tokens([next_token_id])[0]
    return next_token.replace("Ġ", "")


print("Next word generation")
# Compare the two models' logits for every prompt, not just the last one, so a
# mismatch on any prompt shows up in the final verdict.
matches = []
for prompt in prompts:
    print("-------------")
    print(f"Prompt: {prompt}...\n")
    logits_fsq = single_batch_forward_logits(prompt)
    print(f"Next word: {_predicted_word(logits_fsq)}")
    print("-------------")
    logits = forward_hf(prompt)
    print(f"Next word: {_predicted_word(logits)}")
    print("-------------")
    matches.append(torch.allclose(logits_fsq.cpu(), logits.cpu(), atol=1e-3))

print("Is equal:", all(matches))