Delete inference.py
inference.py +0 -93
inference.py
DELETED
@@ -1,93 +0,0 @@

```python
import numpy as np
import torch
import torch.nn.functional as F
import tiktoken
from dataclasses import dataclass

from model import GPT


class GPT2Inference:
    """Generate text sequences using a trained GPT-2 model."""

    def __init__(self, model, token_encoder, device):
        self.model = model
        self.token_encoder = token_encoder
        self.device = device
        self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'

    def generate_sequences(self, prompt, num_seq=5, max_tokens=50):
        self.model.eval()
        tokens = self.token_encoder.encode(prompt)
        tokens = torch.tensor(tokens, dtype=torch.long)  # (n,)  n: current sequence length
        tokens = tokens.unsqueeze(0).repeat(num_seq, 1)  # (1, n) --> (num_seq, n)
        gen_tokens = tokens.to(self.device)
        # use a dedicated rng generator so sampling does not disturb the global rng state used for training
        sample_rng = torch.Generator(device=self.device).manual_seed(42)

        # generate new tokens one at a time until the sequence length reaches 'max_tokens'
        while gen_tokens.shape[-1] <= max_tokens:
            with torch.no_grad():
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(gen_tokens)  # (num_seq, n, vocab_size)
                logits = logits[:, -1, :]  # (num_seq, vocab_size)
                probs = F.softmax(logits, dim=-1)  # (num_seq, vocab_size)
                # keep only the top-50 probabilities
                topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)  # (num_seq, 50), (num_seq, 50)
                # sample a token from the top-50 probabilities
                ix = torch.multinomial(topk_probs, num_samples=1, generator=sample_rng)  # (num_seq, 1)
                next_tok = torch.gather(topk_indices, -1, ix)  # (num_seq, 1)
                gen_tokens = torch.cat([gen_tokens, next_tok], dim=1)
        # decode the generated tokens and print the generated text
        for i in range(num_seq):
            tokens = gen_tokens[i, :max_tokens].tolist()
            gen_text = self.token_encoder.decode(tokens)
            print(f"> sample {i}: {gen_text}")


def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt', type=str, default="Hello, I am a language model,")
    parser.add_argument('--num_seq', type=int, default=5)
    parser.add_argument('--max_tokens', type=int, default=50)
    args = parser.parse_args()
    return args


@dataclass
class GPTConfig:
    context_length: int = 1024  # max context / sequence length
    vocab_size: int = 50257  # number of tokens: 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
    num_layers: int = 12
    embd_size: int = 768  # embedding dim
    num_heads: int = 12


def inference(args=None):
    if args is None:
        args = parse_args()

    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = 'mps'  # for Apple Silicon GPUs
    print(f'using device: {device}')

    model_path = './logs/model_95364.pt'
    checkpoint = torch.load(model_path, weights_only=False)
    print(f"loaded model from: {model_path}")
    # print(checkpoint['model'].keys())

    model = GPT(config=checkpoint['config'])
    model.load_state_dict(checkpoint['model'])
    model = model.to(device)
    token_encoder = tiktoken.get_encoding('gpt2')
    generator = GPT2Inference(model, token_encoder, device)

    generator.generate_sequences(args.prompt, args.num_seq, args.max_tokens)


if __name__ == '__main__':
    inference()
```
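For reference, `inference()` falls back to `parse_args()` only when called with no arguments, so the deleted script could also be driven programmatically by passing any object exposing the three CLI fields. A minimal sketch of that usage, with illustrative prompt values, assuming the checkpoint `./logs/model_95364.pt` and the `model` module defining `GPT` are still present:

```python
# Hypothetical usage sketch, not part of the original file: drive inference()
# programmatically instead of through the command line.
# Assumes ./logs/model_95364.pt and model.py (defining GPT) sit next to inference.py.
from argparse import Namespace

from inference import inference

# The Namespace mirrors the fields defined in parse_args(): prompt, num_seq, max_tokens.
inference(Namespace(prompt="Hello, I am a language model,", num_seq=3, max_tokens=40))
```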