abhishek4607 committed on
Commit
4624956
·
verified ·
1 Parent(s): aff58f2

Delete inference.py

Browse files
Files changed (1) hide show
  1. inference.py +0 -93
inference.py DELETED
@@ -1,93 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn.functional as F
4
- import tiktoken
5
- from dataclasses import dataclass
6
-
7
- from model import GPT
8
-
9
-
10
- class GPT2Inference:
11
- """ To generate text sequences using a trained GPT2 model """
12
-
13
- def __init__(self, model, token_encoder, device):
14
- self.model = model
15
- self.token_encoder = token_encoder
16
- self.device = device
17
- self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
18
-
19
- def generate_sequences(self, prompt, num_seq=5, max_tokens=50):
20
- self.model.eval()
21
- tokens = self.token_encoder.encode(prompt)
22
- tokens = torch.tensor(tokens, dtype=torch.long) # (n,) n : current sequence length
23
- tokens = tokens.unsqueeze(0).repeat(num_seq, 1) # (1,n) --> (num_seq, n)
24
- gen_tokens = tokens.to(self.device)
25
- # create a different rng generator so as not to impact the global rng state used for training
26
- sample_rng = torch.Generator(device=self.device).manual_seed(42)
27
-
28
- # generate new tokens one token at a time until the sequence length becomes 'max_tokens'
29
- while gen_tokens.shape[-1] <= max_tokens:
30
- with torch.no_grad():
31
- with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
32
- logits, loss = self.model(gen_tokens) # (num_seq, n, vocab_size)
33
- logits = logits[:, -1, :] # (num_seq, vocab_size)
34
- probs = F.softmax(logits, dim=-1) # (num_seq, vocab_size)
35
- # take top-k 50 probs
36
- topk_probs, topk_indices = torch.topk(probs, 50, dim=-1) # (num_seq, 50), (num_seq, 50)
37
- # sample a token from top-50 probabilities
38
- ix = torch.multinomial(topk_probs, num_samples=1, generator=sample_rng) # (num_seq, 1)
39
- next_tok = torch.gather(topk_indices, -1, ix) # (num_seq, 1)
40
- gen_tokens = torch.cat([gen_tokens, next_tok], dim=1)
41
- # decode generated tokens and print generated text
42
- for i in range(num_seq):
43
- tokens = gen_tokens[i, :max_tokens].tolist()
44
- gen_text = self.token_encoder.decode(tokens)
45
- print(f"> sample {i}: {gen_text}")
46
-
47
-
48
def parse_args():
    """Parse command-line options for text generation.

    Returns:
        argparse.Namespace with attributes ``prompt``, ``num_seq`` and
        ``max_tokens``.
    """
    import argparse

    parser = argparse.ArgumentParser()
    # (flag, type, default) table keeps the option definitions in one place
    options = (
        ('--prompt', str, "Hello, I am a language model,"),
        ('--num_seq', int, 5),
        ('--max_tokens', int, 50),
    )
    for flag, flag_type, default in options:
        parser.add_argument(flag, type=flag_type, default=default)
    return parser.parse_args()
56
-
57
-
58
@dataclass
class GPTConfig:
    """Architecture hyperparameters for the GPT-2 (124M-scale) model.

    NOTE(review): this class is not referenced anywhere in this script --
    inference() rebuilds the model from checkpoint['config'] instead.
    Presumably kept so the checkpoint's pickled config object can resolve
    its class on load; verify before removing.
    """
    context_length: int = 1024 # max context / sequence length
    vocab_size: int = 50257 # number of tokens: 50000 BPE merges + 256 bytes tokens + 1 <endoftext> token
    num_layers: int = 12 # number of transformer blocks
    embd_size: int = 768 # embedding dim
    num_heads: int = 12 # attention heads per block
65
-
66
-
67
def inference(args=None, model_path='./logs/model_95364.pt'):
    """Load a GPT checkpoint and print sampled continuations of a prompt.

    Args:
        args: argparse.Namespace with ``prompt``, ``num_seq`` and
            ``max_tokens`` attributes; parsed from the command line when None.
        model_path: path to a checkpoint dict holding 'config' (model config)
            and 'model' (state_dict) entries. The default preserves the
            previously hard-coded path, so existing callers are unaffected.
    """
    if args is None:
        args = parse_args()

    # pick the best available device
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = 'mps'  # for apple macbook GPUs
    print(f'using device: {device}')

    # Fix: map_location='cpu' lets a checkpoint that was saved on a CUDA
    # machine load on a CPU/MPS-only host; without it torch.load tries to
    # restore tensors onto the original (possibly unavailable) device.
    # NOTE(review): weights_only=False unpickles arbitrary objects (needed
    # here to restore the config object) -- only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
    print(f"loaded model from: {model_path}")

    model = GPT(config=checkpoint['config'])
    model.load_state_dict(checkpoint['model'])
    model = model.to(device)  # move weights to the selected device after loading
    token_encoder = tiktoken.get_encoding('gpt2')
    generator = GPT2Inference(model, token_encoder, device)

    generator.generate_sequences(args.prompt, args.num_seq, args.max_tokens)


if __name__ == '__main__':
    inference()