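"""Sample text continuations from a trained GPT-2 checkpoint.

Example invocation (filename assumed; flags are defined in parse_args below):
    python inference.py --prompt "Hello, I am a language model," --num_seq 5 --max_tokens 50
"""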
| import torch |
| import torch.nn.functional as F |
| import tiktoken |
| from dataclasses import dataclass |
|
|
| from model import GPT |
|
|
|
|
| class GPT2Inference: |
| """ To generate text sequences using a trained GPT2 model """ |
|
|
| def __init__(self, model, token_encoder, device): |
| self.model = model |
| self.token_encoder = token_encoder |
| self.device = device |
        # autocast needs a device *type* ('cuda'/'cpu'), not a full device string like 'cuda:0';
        # on 'mps' this falls back to 'cpu', which effectively disables autocast for mps tensors
        self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
|
|
    def generate_sequences(self, prompt, num_seq=5, max_tokens=50):
        """Print num_seq continuations of prompt, each max_tokens tokens long (prompt included)."""
        self.model.eval()
        # encode the prompt once, then replicate it: one row per sequence to generate
        tokens = self.token_encoder.encode(prompt)
        tokens = torch.tensor(tokens, dtype=torch.long)
        gen_tokens = tokens.unsqueeze(0).repeat(num_seq, 1).to(self.device)

        # dedicated RNG so sampling is reproducible and independent of global seeding
        sample_rng = torch.Generator(device=self.device).manual_seed(42)

        while gen_tokens.shape[-1] < max_tokens:
            with torch.no_grad():
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, _ = self.model(gen_tokens)
                # only the logits at the last position predict the next token
                logits = logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)
                # top-k sampling (k=50): restrict to the 50 most likely tokens;
                # multinomial renormalizes the kept probabilities before sampling
                topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
                ix = torch.multinomial(topk_probs, num_samples=1, generator=sample_rng)
                next_tok = torch.gather(topk_indices, -1, ix)
                gen_tokens = torch.cat([gen_tokens, next_tok], dim=1)

        for i in range(num_seq):
            # slice guards against prompts that are already longer than max_tokens
            seq = gen_tokens[i, :max_tokens].tolist()
            print(f"> sample {i}: {self.token_encoder.decode(seq)}")
|
|
|
|
def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt', type=str, default="Hello, I am a language model,", help='prompt to continue')
    parser.add_argument('--num_seq', type=int, default=5, help='number of sequences to sample')
    parser.add_argument('--max_tokens', type=int, default=50, help='total length of each sequence, prompt included')
    return parser.parse_args()
|
|
|
|
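# GPT-2 (124M) hyperparameters mirroring the training-time config. inference()
# below builds the model from the config stored in the checkpoint itself, so this
# dataclass mainly documents the expected architecture (and may be needed to
# unpickle checkpoints whose saved config references a GPTConfig class).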
| @dataclass |
| class GPTConfig: |
| context_length: int = 1024 |
| vocab_size: int = 50257 |
| num_layers: int = 12 |
| embd_size: int = 768 |
| num_heads: int = 12 |
|
|
|
|
| def inference(args=None): |
| if args is None: |
| args = parse_args() |
|
|
    # prefer CUDA, then Apple-silicon MPS, falling back to CPU
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = 'mps'
    print(f'using device: {device}')
|
|
    model_path = './logs/model_95364.pt'
    # map_location lets a checkpoint saved on GPU load on CPU/MPS machines;
    # weights_only=False because the checkpoint pickles a config object alongside the weights
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    print(f"loaded model from: {model_path}")
|
|
    # rebuild the architecture from the saved config, then load the trained weights
    model = GPT(config=checkpoint['config'])
    model.load_state_dict(checkpoint['model'])
    model = model.to(device)
    token_encoder = tiktoken.get_encoding('gpt2')  # GPT-2 BPE tokenizer
    generator = GPT2Inference(model, token_encoder, device)
|
|
| generator.generate_sequences(args.prompt, args.num_seq, args.max_tokens) |
|
|
|
|
| if __name__ == '__main__': |
| inference() |
|
|