peeyushsinghal committed · verified
Commit 6285ada · 1 Parent(s): cbcb2a6

Gradio App

Files changed (7)
  1. app.py +74 -0
  2. data.py +32 -0
  3. infer.py +117 -0
  4. main.py +167 -0
  5. model.py +221 -0
  6. requirements.txt +4 -0
  7. utils.py +15 -0
app.py ADDED
@@ -0,0 +1,74 @@
+ import gradio as gr
+ from infer import load_model_from_checkpoint, generate_text, InferenceConfig
+ from utils import get_device
+ from main import GPTConfig, Config
+ from torch.serialization import add_safe_globals
+ from dataclasses import dataclass
+
+ import warnings
+
+ # Suppress FutureWarnings
+ warnings.simplefilter(action="ignore", category=FutureWarning)
+
+
+ @dataclass
+ class AppConfig:
+     model_path: str = "checkpoint/model_final.pth"
+     num_return_sequences: int = 5
+     max_length: int = 50  # max length of the generated text
+     tokenizer: str = "gpt2"
+
+
+ config = AppConfig()
+
+ device = get_device()
+ # Allow the custom config dataclasses inside the checkpoint to be unpickled safely
+ add_safe_globals([Config, GPTConfig])
+
+ model = load_model_from_checkpoint(config.model_path, device=device)
+
+
+ def generate(prompt, num_sequences):
+     if not prompt:
+         return "Please enter a prompt."
+
+     generated_texts = generate_text(
+         model=model,
+         prompt=prompt,
+         max_length=config.max_length,  # use the app-level max_length
+         num_return_sequences=int(num_sequences),  # slider values may arrive as floats
+         device=device,
+     )
+
+     # Format output with sequence numbers
+     formatted_output = ""
+     for i, text in enumerate(generated_texts, 1):
+         formatted_output += f"**Sequence {i}**:\n{text}\n\n"
+
+     return formatted_output
+
+
+ # Create Gradio interface
+ iface = gr.Interface(
+     fn=generate,
+     inputs=[
+         gr.Textbox(
+             lines=3,
+             placeholder="Enter your prompt here...",
+             label="Prompt",
+         ),
+         gr.Slider(
+             minimum=1,
+             maximum=5,
+             step=1,
+             value=3,
+             label="Number of Sequences",
+         ),
+     ],
+     outputs=gr.Textbox(
+         lines=10,
+         label="Generated Text",
+     ),
+     title="Text Generation with GPT",
+     description="Enter a prompt and select the number of sequences to generate different variations of text.",
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
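
A quick way to exercise the app outside the UI — a sketch, assuming the checkpoint exists at checkpoint/model_final.pth as configured above; importing app loads the model once at import time, and the Gradio server only starts when app.py is run directly (by default on http://127.0.0.1:7860):

# Hypothetical local session:
#   $ python app.py        # starts the Gradio UI
# Or call the same generate() function programmatically:
from app import generate
print(generate("Once upon a time", num_sequences=2))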
data.py ADDED
@@ -0,0 +1,32 @@
+ import tiktoken
+ import torch
+
+
+ class DataLoaderLite:
+     def __init__(self, B, T, file_path, model_type):
+         self.B = B
+         self.T = T
+
+         # at init load tokens from disk and store them in memory
+         with open(file_path, "r") as f:
+             text = f.read()
+         enc = tiktoken.get_encoding(model_type)
+         tokens = enc.encode(text)
+         self.tokens = torch.tensor(tokens)
+         print(f"loaded {len(self.tokens)} tokens")
+         print(f"1 epoch = {len(self.tokens) // (B * T)} batches")
+
+         # state
+         self.current_position = 0
+
+     def next_batch(self):
+         B, T = self.B, self.T
+         buf = self.tokens[self.current_position : self.current_position + B * T + 1]
+         x = (buf[:-1]).view(B, T)  # inputs
+         y = (buf[1:]).view(B, T)  # targets
+         # advance the position in the tensor
+         self.current_position += B * T
+         # if loading the next batch would be out of bounds, reset
+         if self.current_position + (B * T + 1) > len(self.tokens):
+             self.current_position = 0
+         return x, y
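
A minimal smoke test for DataLoaderLite — a sketch assuming a plain-text corpus at data/input.txt (the default path used in main.py) and the gpt2 tiktoken encoding:

from data import DataLoaderLite

loader = DataLoaderLite(B=4, T=32, file_path="data/input.txt", model_type="gpt2")
x, y = loader.next_batch()
print(x.shape, y.shape)  # expected: torch.Size([4, 32]) torch.Size([4, 32])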
infer.py ADDED
@@ -0,0 +1,117 @@
+ import tiktoken
+ from dataclasses import dataclass
+ import torch
+ from utils import get_device, set_seed
+ from main import GPTConfig, Config
+ from torch.serialization import add_safe_globals
+
+ from model import GPT
+
+ import warnings
+
+ # Suppress FutureWarnings
+ warnings.simplefilter(action="ignore", category=FutureWarning)
+
+
+ @dataclass
+ class InferenceConfig:
+     model_path: str = "../checkpoint/model_final.pth"
+     num_return_sequences: int = 5
+     max_length: int = 100  # max length of the generated text
+     tokenizer: str = "gpt2"
+
+
+ config = InferenceConfig()
+
+
+ def encode(text, device, config=config):
+     enc = tiktoken.get_encoding(config.tokenizer)
+     enc_tensor = torch.tensor(enc.encode(text), dtype=torch.long, device=device)
+     enc_tensor = enc_tensor.unsqueeze(0)
+     return enc_tensor
+
+
+ def decode(tokens): ...  # placeholder; decoding is done inline in generate_text
+
+
+ def generate_text(
+     model,
+     prompt,
+     max_length=config.max_length,
+     num_return_sequences=config.num_return_sequences,
+     device=get_device(),
+ ):
+     tokenizer = tiktoken.get_encoding(config.tokenizer)
+     input_ids = tokenizer.encode(prompt)
+     input_ids = torch.tensor(input_ids, dtype=torch.long, device=device)
+     input_ids = input_ids.unsqueeze(0).repeat(num_return_sequences, 1)
+     input_ids = input_ids.to(device)
+
+     # Calculate final length
+     final_length = input_ids.shape[1] + max_length
+
+     # Generate text by sampling one token at a time from the model's distribution
+     with torch.no_grad():
+         while input_ids.shape[1] < final_length:
+             logits = model(input_ids)[0]
+             next_token_logits = logits[:, -1, :]
+             probs = torch.softmax(next_token_logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+             input_ids = torch.cat([input_ids, next_token], dim=-1)
+
+     generated_text = []
+     for i in range(num_return_sequences):
+         tokens = input_ids[i].tolist()
+         text = tokenizer.decode(tokens)
+         generated_text.append(text)
+
+     return generated_text
+
+
+ def load_model_from_checkpoint(model_path, device):
+     # Add Config and GPTConfig to safe globals so the checkpoint can be unpickled
+     add_safe_globals([Config, GPTConfig])
+
+     try:
+         # First try with weights_only=True
+         checkpoint = torch.load(model_path, map_location=device, weights_only=True)
+     except Exception:
+         # If that fails, fall back to a full (unrestricted) load
+         checkpoint = torch.load(model_path, map_location=device)
+
+     # Get the model configuration from the saved GPTConfig
+     model_config = checkpoint["model_config"]
+
+     # Create a new model with this configuration
+     model = GPT(model_config)
+
+     # Load the state dict
+     model.load_state_dict(checkpoint["model_state_dict"])
+
+     # Move to device and set to eval mode
+     model.to(device)
+     model.eval()
+     return model
+
+
+ def inference():
+     device = get_device()
+
+     try:
+         model = load_model_from_checkpoint(config.model_path, device=device)
+         print("Model loaded successfully")
+         # print(model)
+         # return model
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         return None
+
+     context = "To be or not to be, that is the question. "
+     generated_text = generate_text(model, context)
+     for text in generated_text:
+         print(text)
+
+
+ if __name__ == "__main__":
+     inference()
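
Standalone generation with these helpers — a sketch assuming a checkpoint produced by main.py; registering Config and GPTConfig as safe globals mirrors what app.py does so the pickled dataclasses load under weights_only=True:

from torch.serialization import add_safe_globals
from main import Config, GPTConfig
from infer import load_model_from_checkpoint, generate_text
from utils import get_device

device = get_device()
add_safe_globals([Config, GPTConfig])
model = load_model_from_checkpoint("checkpoint/model_final.pth", device=device)

# sample three continuations of up to 50 new tokens each
for text in generate_text(model, "To be or not to be", max_length=50,
                          num_return_sequences=3, device=device):
    print(text)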
main.py ADDED
@@ -0,0 +1,167 @@
+ import os
+
+ from model import GPT
+ from data import DataLoaderLite
+ import torch
+ from dataclasses import dataclass  # for dataclass, config class
+ from utils import get_device, set_seed
+
+
+ @dataclass
+ class Config:
+     model_name: str = "gpt2"
+     seed: int = 1337
+     max_return_sequences: int = 5
+     file_path: str = "data/input.txt"
+     max_length: int = 30
+     B: int = 8  # batch size
+     T: int = 512  # sequence length
+     lr: float = 1e-4  # learning rate
+     epochs: int = 5000
+     interval: int = 100
+     moving_avg_window: int = 100
+     best_loss: float = float("inf")
+     checkpoint_dir: str = "checkpoint"
+     target_loss: float = 0.099999
+
+
+ @dataclass
+ class GPTConfig:
+     block_size: int = 1024  # max sequence length
+     vocab_size: int = 50257  # number of tokens: 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
+     n_layer: int = 12  # number of layers
+     n_head: int = 12  # number of heads
+     n_embd: int = 768  # embedding dimension
+     dropout: float = 0.1  # dropout rate
+     bias: bool = True  # use bias in attention and feedforward
+
+
+ def load_model(model_type=None):
+     if model_type is not None:
+         model = GPT.from_pretrained(model_type=model_type)
+     else:
+         model_config = GPTConfig()
+         model = GPT(model_config)
+     return model
+
+
+ def count_parameters(model):
+     """Count trainable parameters"""
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+ def print_model_summary(model):
+     """Print model architecture and parameter count"""
+     print("\nModel Architecture:")
+     print("=" * 50)
+     print(f"Block Size (Context Length): {model.config.block_size}")
+     print(f"Vocabulary Size: {model.config.vocab_size}")
+     print(f"Number of Layers: {model.config.n_layer}")
+     print(f"Number of Heads: {model.config.n_head}")
+     print(f"Embedding Dimension: {model.config.n_embd}")
+     print(f"Dropout: {model.config.dropout}")
+
+     # Calculate approximate parameter counts for each component
+     token_emb = model.config.vocab_size * model.config.n_embd
+     pos_emb = model.config.block_size * model.config.n_embd
+
+     # Per-layer parameters (weights only; biases and layernorms omitted)
+     attn_params = 4 * model.config.n_embd * model.config.n_embd  # Q, K, V, and output projection
+     mlp_params = 8 * model.config.n_embd * model.config.n_embd  # MLP with 4x expansion
+     layer_params = attn_params + mlp_params
+
+     print("\nParameter Counts:")
+     print("-" * 50)
+     print(f"Token Embeddings: {token_emb:,}")
+     print(f"Position Embeddings: {pos_emb:,}")
+     print(f"Per Layer: {layer_params:,}")
+     print(f"All Layers: {layer_params * model.config.n_layer:,}")
+     print(f"Total Trainable Parameters: {count_parameters(model):,}")
+
+     # Estimated model size
+     model_size_mb = count_parameters(model) * 4 / (1024 * 1024)  # 4 bytes per parameter
+     half_precision_size = model_size_mb / 2
+     print("\nEstimated Model Size:")
+     print(f"Full Precision (MB): {model_size_mb:.2f}")
+     print(f"Half Precision (MB): {half_precision_size:.2f}")
+     print("=" * 50 + "\n")
+
+
+ def main():
+     config = Config()
+
+     # set up device
+     device = get_device()
+     print(f"Using device: {device}")
+
+     # set seed
+     set_seed(config.seed)
+
+     # load model
+     # model = load_model(config.model_name)  # from pretrained
+     model = load_model()  # from scratch
+     # Print model summary
+     print_model_summary(model)
+     model.to(device)
+
+     # load dataset
+     train_loader = DataLoaderLite(
+         B=config.B, T=config.T, file_path=config.file_path, model_type=config.model_name
+     )
+     # print(train_loader.next_batch())  # check if data is loaded correctly
+
+     # make sure the checkpoint directory exists before saving
+     os.makedirs(config.checkpoint_dir, exist_ok=True)
+
+     # train model
+     optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=1e-1)
+     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs, eta_min=1e-5)
+     losses = []
+     best_loss = config.best_loss
+
+     for i in range(config.epochs):
+         x, y = train_loader.next_batch()
+         optimizer.zero_grad()
+         _, loss = model(x.to(device), y.to(device))
+         loss.backward()
+         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip gradients to prevent exploding gradients
+         optimizer.step()
+         scheduler.step()
+         losses.append(loss.item())
+
+         # Calculate moving average loss
+         avg_loss = sum(losses[-config.moving_avg_window:]) / min(config.moving_avg_window, len(losses))
+
+         if loss.item() < best_loss and i > config.interval - 1:
+             best_loss = loss.item()
+             torch.save({"model_state_dict": model.state_dict(),
+                         "config": config,
+                         "model_config": GPTConfig()},
+                        f"{config.checkpoint_dir}/model_best.pth")
+             print(f"Model saved at step {i}")
+
+         if i % config.interval == 0:
+             print(f"step {i}, loss: {loss.item():.4f}, best loss: {best_loss:.4f}, moving avg loss: {avg_loss:.4f}, lr: {scheduler.get_last_lr()[0]:.2e}")
+
+         if avg_loss < config.target_loss:
+             print(f"---Target loss reached at step {i}")
+             torch.save({"model_state_dict": model.state_dict(),
+                         "config": config,
+                         "model_config": GPTConfig()},
+                        f"{config.checkpoint_dir}/model_final.pth")
+             break
+
+     print(f"Training completed. Best loss: {best_loss:.4f}, final loss: {loss.item():.4f}")
+     # save model
+     torch.save({"model_state_dict": model.state_dict(),
+                 "config": config,
+                 "model_config": GPTConfig()},
+                f"{config.checkpoint_dir}/model_final.pth")
+     print(f"Model saved to {config.checkpoint_dir}/model_final.pth")
+
+     # inference
+     # print(model)
+     return model
+
+
+ if __name__ == "__main__":
+     model = main()
+     print(model)
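
The checkpoints written above are plain dictionaries with three keys; a sketch of inspecting one after training, assuming checkpoint/model_final.pth exists:

import torch
from torch.serialization import add_safe_globals
from main import Config, GPTConfig

add_safe_globals([Config, GPTConfig])
ckpt = torch.load("checkpoint/model_final.pth", map_location="cpu", weights_only=True)
print(ckpt.keys())           # dict_keys(['model_state_dict', 'config', 'model_config'])
print(ckpt["model_config"])  # the GPTConfig that infer.py uses to rebuild the model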
model.py ADDED
@@ -0,0 +1,221 @@
+ import os
+ import math
+ import time
+ import inspect
+ from dataclasses import dataclass
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+
+ class CausalSelfAttention(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         # key, query, value projections for all heads, but in a batch
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+         # output projection
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+         self.c_proj.NANOGPT_SCALE_INIT = 1
+         # regularization
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.register_buffer(
+             "bias",
+             torch.tril(torch.ones(config.block_size, config.block_size)).view(
+                 1, 1, config.block_size, config.block_size
+             ),
+         )
+
+     def forward(self, x):
+         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
+         # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
+         qkv = self.c_attn(x)
+         q, k, v = qkv.split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
+         att = F.softmax(att, dim=-1)
+         y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+         # output projection
+         y = self.c_proj(y)
+         return y
+
+
+ class MLP(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
+         self.gelu = nn.GELU(approximate="tanh")
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
+         self.c_proj.NANOGPT_SCALE_INIT = 1
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_proj(x)
+         return x
+
+
+ class Block(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embd)
+         self.attn = CausalSelfAttention(config)
+         self.ln_2 = nn.LayerNorm(config.n_embd)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ @dataclass
+ class GPTConfig:
+     block_size: int = 1024  # max sequence length
+     vocab_size: int = 50257  # number of tokens: 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
+     n_layer: int = 12  # number of layers
+     n_head: int = 12  # number of heads
+     n_embd: int = 768  # embedding dimension
+
+
+ class GPT(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         self.transformer = nn.ModuleDict(
+             dict(
+                 wte=nn.Embedding(config.vocab_size, config.n_embd),
+                 wpe=nn.Embedding(config.block_size, config.n_embd),
+                 h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+                 ln_f=nn.LayerNorm(config.n_embd),
+             )
+         )
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # weight sharing between token embedding and output head
+         self.transformer.wte.weight = self.lm_head.weight
+
+         # weight initialization
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             std = 0.02
+             # residual projections are flagged with NANOGPT_SCALE_INIT and get a scaled-down init
+             if hasattr(module, "NANOGPT_SCALE_INIT"):
+                 std *= (2 * self.config.n_layer) ** -0.5
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, idx, targets=None):
+         # idx is of shape (B, T)
+         B, T = idx.size()
+         assert (
+             T <= self.config.block_size
+         ), f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
+         # forward the token and position embeddings
+         pos = torch.arange(0, T, dtype=torch.long, device=idx.device)  # shape (T)
+         pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (T, n_embd)
+         tok_emb = self.transformer.wte(idx)  # token embeddings of shape (B, T, n_embd)
+         x = tok_emb + pos_emb
+         # forward the blocks of the transformer
+         for block in self.transformer.h:
+             x = block(x)
+         # forward the final layernorm and the classifier
+         x = self.transformer.ln_f(x)
+         logits = self.lm_head(x)  # (B, T, vocab_size)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+         return logits, loss
+
+     @classmethod
+     def from_pretrained(cls, model_type):
+         """Loads pretrained GPT-2 model weights from huggingface"""
+         assert model_type in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"}
+         from transformers import GPT2LMHeadModel
+
+         print("loading weights from pretrained gpt: %s" % model_type)
+
+         # n_layer, n_head and n_embd are determined from model_type
+         config_args = {
+             "gpt2": dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
+             "gpt2-medium": dict(n_layer=24, n_head=16, n_embd=1024),  # 350M params
+             "gpt2-large": dict(n_layer=36, n_head=20, n_embd=1280),  # 774M params
+             "gpt2-xl": dict(n_layer=48, n_head=25, n_embd=1600),  # 1558M params
+         }[model_type]
+         config_args["vocab_size"] = 50257  # always 50257 for GPT model checkpoints
+         config_args["block_size"] = 1024  # always 1024 for GPT model checkpoints
+         # create a from-scratch initialized minGPT model
+         config = GPTConfig(**config_args)
+         model = GPT(config)
+         sd = model.state_dict()
+         sd_keys = sd.keys()
+         sd_keys = [
+             k for k in sd_keys if not k.endswith(".attn.bias")
+         ]  # discard this mask / buffer, not a param
+
+         # init a huggingface/transformers model
+         model_hf = GPT2LMHeadModel.from_pretrained(model_type)
+         sd_hf = model_hf.state_dict()
+
+         # copy while ensuring all of the parameters are aligned and match in names and shapes
+         sd_keys_hf = sd_hf.keys()
+         sd_keys_hf = [
+             k for k in sd_keys_hf if not k.endswith(".attn.masked_bias")
+         ]  # ignore these, just a buffer
+         sd_keys_hf = [
+             k for k in sd_keys_hf if not k.endswith(".attn.bias")
+         ]  # same, just the mask (buffer)
+         transposed = [
+             "attn.c_attn.weight",
+             "attn.c_proj.weight",
+             "mlp.c_fc.weight",
+             "mlp.c_proj.weight",
+         ]
+         # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
+         # this means that we have to transpose these weights when we import them
+         assert len(sd_keys_hf) == len(
+             sd_keys
+         ), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
+         for k in sd_keys_hf:
+             if any(k.endswith(w) for w in transposed):
+                 # special treatment for the Conv1D weights we need to transpose
+                 assert sd_hf[k].shape[::-1] == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k].t())
+             else:
+                 # vanilla copy over the other parameters
+                 assert sd_hf[k].shape == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k])
+
+         return model
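
A quick shape check for GPT using the GPTConfig defaults defined in this file — a sketch with random token ids; the batch and sequence sizes are arbitrary:

import torch
from model import GPT, GPTConfig

model = GPT(GPTConfig())
idx = torch.randint(0, 50257, (2, 16))      # dummy token ids, shape (B=2, T=16)
targets = torch.randint(0, 50257, (2, 16))
logits, loss = model(idx, targets)
print(logits.shape, loss.item())            # torch.Size([2, 16, 50257]) and a scalar cross-entropy loss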
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ transformers
+ tiktoken
+ gradio
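
These dependencies can be installed with pip install -r requirements.txt; versions are not pinned, so the latest releases are assumed.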
utils.py ADDED
@@ -0,0 +1,15 @@
+ import torch
+
+
+ def get_device():
+     if torch.cuda.is_available():
+         return "cuda"
+     elif torch.backends.mps.is_available():
+         return "mps"
+     else:
+         return "cpu"
+
+
+ def set_seed(seed):
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)