Karthik15 committed on
Commit 9d30381 · verified · 1 Parent(s): 35dd74c

Update model.py

Files changed (1)
  1. model.py +158 -163
model.py CHANGED
@@ -1,163 +1,158 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import math
- from dataclasses import dataclass
- import numpy as np
- from tqdm.auto import tqdm
- from contextlib import nullcontext
- import os
-
- class LayerNorm(nn.Module):
-     def __init__(self, ndim, bias):
-         super().__init__()
-         self.weight = nn.Parameter(torch.ones(ndim))
-         self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
-     def forward(self, x):
-         return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
-
- class CausalSelfAttention(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         assert config.n_embd % config.n_head == 0
-         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
-         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
-         self.attn_dropout = nn.Dropout(config.dropout)
-         self.resid_dropout = nn.Dropout(config.dropout)
-         self.n_head = config.n_head
-         self.n_embd = config.n_embd
-         self.flash = hasattr(F, 'scaled_dot_product_attention')
-         if not self.flash:
-             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
-                                  .view(1, 1, config.block_size, config.block_size))
-
-     def forward(self, x):
-         B, T, C = x.size()
-         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
-         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
-         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
-         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
-
-         if self.flash:
-             y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True)
-         else:
-             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-             att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
-             att = F.softmax(att, dim=-1)
-             att = self.attn_dropout(att)
-             y = att @ v
-
-         y = y.transpose(1, 2).contiguous().view(B, T, C)
-         y = self.resid_dropout(self.c_proj(y))
-         return y
-
- class MLP(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
-         self.gelu = nn.GELU()
-         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
-         self.dropout = nn.Dropout(config.dropout)
-     def forward(self, x):
-         return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
-
- class Block(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.ln1 = LayerNorm(config.n_embd, config.bias)
-         self.attn = CausalSelfAttention(config)
-         self.ln2 = LayerNorm(config.n_embd, config.bias)
-         self.mlp = MLP(config)
-     def forward(self, x):
-         x = x + self.attn(self.ln1(x))
-         x = x + self.mlp(self.ln2(x))
-         return x
-
- @dataclass
- class GPTConfig:
-     block_size: int
-     vocab_size: int
-     n_layer: int
-     n_head: int
-     n_embd: int
-     dropout: float = 0.0
-     bias: bool = True
-
- class GPT(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.config = config
-         self.transformer = nn.ModuleDict(dict(
-             wte=nn.Embedding(config.vocab_size, config.n_embd),
-             wpe=nn.Embedding(config.block_size, config.n_embd),
-             drop=nn.Dropout(config.dropout),
-             h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
-             ln_f=LayerNorm(config.n_embd, config.bias),
-         ))
-         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-         self.transformer.wte.weight = self.lm_head.weight # weight tying
-
-         self.apply(self._init_weights)
-         for pn, p in self.named_parameters():
-             if pn.endswith('c_proj.weight'):
-                 nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
-
-     def _init_weights(self, module):
-         if isinstance(module, nn.Linear):
-             nn.init.normal_(module.weight, mean=0.0, std=0.02)
-             if module.bias is not None:
-                 nn.init.zeros_(module.bias)
-         elif isinstance(module, nn.Embedding):
-             nn.init.normal_(module.weight, mean=0.0, std=0.02)
-
-     def forward(self, idx, targets=None):
-         device = idx.device
-         b, t = idx.size()
-         assert t <= self.config.block_size
-         pos = torch.arange(0, t, dtype=torch.long, device=device)
-
-         tok_emb = self.transformer.wte(idx)
-         pos_emb = self.transformer.wpe(pos)
-         x = self.transformer.drop(tok_emb + pos_emb)
-         for block in self.transformer.h:
-             x = block(x)
-         x = self.transformer.ln_f(x)
-
-         if targets is not None:
-             logits = self.lm_head(x)
-             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
-             return logits, loss
-         else:
-             logits = self.lm_head(x[:, [-1], :])
-             return logits, None
-
-     @torch.no_grad()
-     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
-         """
-         Generate tokens given a conditioning sequence.
-         idx: Tensor of shape (B, T)
-         """
-         for _ in range(max_new_tokens):
-             idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
-             logits, _ = self(idx_cond)
-             logits = logits[:, -1, :] / temperature
-             if top_k is not None:
-                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-                 logits[logits < v[:, [-1]]] = -float('Inf')
-             probs = F.softmax(logits, dim=-1)
-             idx_next = torch.multinomial(probs, num_samples=1)
-             idx = torch.cat((idx, idx_next), dim=1)
-         return idx
-
-
- config = GPTConfig(
-     vocab_size=50257, # use the tokenizer's vocab size
-     block_size=128, # or whatever context size you're training with
-     n_layer=6,
-     n_head=6,
-     n_embd=384,
-     dropout=0.1,
-     bias=True
- )
-
- model = GPT(config)
 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ from transformers import PreTrainedModel, PretrainedConfig
+
+ # ------------------------------
+ # Hugging Face-Compatible Config
+ # ------------------------------
+ class GPTConfig(PretrainedConfig):
+     model_type = "custom-gpt"
+
+     def __init__(self, vocab_size, block_size, n_layer, n_head, n_embd, dropout=0.0, bias=True, **kwargs):
+         super().__init__(**kwargs)
+         self.vocab_size = vocab_size
+         self.block_size = block_size
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.n_embd = n_embd
+         self.dropout = dropout
+         self.bias = bias
+
+ # ------------------------------
+ # GPT Model with HF Integration
+ # ------------------------------
+ class GPT(PreTrainedModel):
+     config_class = GPTConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.transformer = nn.ModuleDict(dict(
+             wte=nn.Embedding(config.vocab_size, config.n_embd),
+             wpe=nn.Embedding(config.block_size, config.n_embd),
+             drop=nn.Dropout(config.dropout),
+             h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+             ln_f=LayerNorm(config.n_embd, config.bias),
+         ))
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+         self.transformer.wte.weight = self.lm_head.weight # weight tying
+
+         self.apply(self._init_weights)
+         for pn, p in self.named_parameters():
+             if pn.endswith('c_proj.weight'):
+                 nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, idx, targets=None):
+         device = idx.device
+         b, t = idx.size()
+         assert t <= self.config.block_size
+         pos = torch.arange(0, t, dtype=torch.long, device=device)
+
+         tok_emb = self.transformer.wte(idx)
+         pos_emb = self.transformer.wpe(pos)
+         x = self.transformer.drop(tok_emb + pos_emb)
+         for block in self.transformer.h:
+             x = block(x)
+         x = self.transformer.ln_f(x)
+
+         if targets is not None:
+             logits = self.lm_head(x)
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+             return logits, loss
+         else:
+             logits = self.lm_head(x[:, [-1], :])
+             return logits, None
+
+     @torch.no_grad()
+     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+         for _ in range(max_new_tokens):
+             idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+             logits, _ = self(idx_cond)
+             logits = logits[:, -1, :] / temperature
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = -float('Inf')
+             probs = F.softmax(logits, dim=-1)
+             idx_next = torch.multinomial(probs, num_samples=1)
+             idx = torch.cat((idx, idx_next), dim=1)
+         return idx
+
+ # ------------------------------
+ # Building Blocks
+ # ------------------------------
+ class LayerNorm(nn.Module):
+     def __init__(self, ndim, bias):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(ndim))
+         self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+     def forward(self, x):
+         return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
+
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+         self.attn_dropout = nn.Dropout(config.dropout)
+         self.resid_dropout = nn.Dropout(config.dropout)
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.flash = hasattr(F, 'scaled_dot_product_attention')
+         if not self.flash:
+             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
+                                  .view(1, 1, config.block_size, config.block_size))
+
+     def forward(self, x):
+         B, T, C = x.size()
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+
+         if self.flash:
+             y = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
+                                                dropout_p=self.attn_dropout.p if self.training else 0.0,
+                                                is_causal=True)
+         else:
+             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+             att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
+             att = F.softmax(att, dim=-1)
+             att = self.attn_dropout(att)
+             y = att @ v
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+         y = self.resid_dropout(self.c_proj(y))
+         return y
+
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+         self.gelu = nn.GELU()
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
+         self.dropout = nn.Dropout(config.dropout)
+     def forward(self, x):
+         return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
+
+ class Block(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.ln1 = LayerNorm(config.n_embd, config.bias)
+         self.attn = CausalSelfAttention(config)
+         self.ln2 = LayerNorm(config.n_embd, config.bias)
+         self.mlp = MLP(config)
+     def forward(self, x):
+         x = x + self.attn(self.ln1(x))
+         x = x + self.mlp(self.ln2(x))
+         return x
+
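
Because the updated GPT subclasses PreTrainedModel and GPTConfig subclasses PretrainedConfig, the model can be checkpointed and reloaded through the standard Hugging Face save/load API. Below is a minimal usage sketch, assuming this model.py is importable as the module "model"; the checkpoint directory name is hypothetical, and the hyperparameters mirror the module-level example removed from the previous version of the file.

from model import GPT, GPTConfig

# Same hyperparameters as the old module-level example
# (vocab_size matches the GPT-2 tokenizer).
config = GPTConfig(
    vocab_size=50257,
    block_size=128,
    n_layer=6,
    n_head=6,
    n_embd=384,
    dropout=0.1,
    bias=True,
)
model = GPT(config)

# PreTrainedModel provides save_pretrained / from_pretrained,
# so the custom architecture round-trips like any Hub model.
model.save_pretrained("./custom-gpt-checkpoint")  # hypothetical local path
reloaded = GPT.from_pretrained("./custom-gpt-checkpoint")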