Mohamed Hassan Ashmawy committed on
Commit 5e488a4 · verified · 1 Parent(s): d46a28c

Upload model.py with huggingface_hub

Files changed (1)
  1. model.py +311 -0
model.py ADDED
@@ -0,0 +1,311 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import math
+ from dataclasses import dataclass
+ from contextlib import nullcontext
+ from typing import Literal
+
+
+ class CausalSelfAttention(nn.Module):
+     # A causal self-attention layer that supports both flash attention and standard attention.
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0  # Ensures the embedding dimension can be evenly split across attention heads.
+
+         # This linear layer projects input x into query (q), key (k), and value (v) vectors
+         # all at once (so the output is 3x the size).
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
+
+         # After attention is done, this layer projects the output back to the original embedding size.
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+
+         # Dropout applied to the attention weights (probabilities).
+         self.attn_dropout = nn.Dropout(config.dropout)
+
+         # Dropout applied after the final projection.
+         self.resid_dropout = nn.Dropout(config.dropout)
+
+         # Store values for easy access later.
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+
+         # Checks whether the efficient Flash Attention API is available in torch.nn.functional.
+         self.flash = hasattr(F, "scaled_dot_product_attention")
+
+         # If Flash Attention is not available, we create a lower triangular mask to ensure causality.
+         # This mask prevents the model from attending to future tokens in the sequence.
+         if not self.flash:
+             # register_buffer ensures this tensor is saved with the model but not updated by gradients.
+             self.register_buffer(
+                 "bias",
+                 torch.tril(torch.ones(config.block_size, config.block_size)).view(
+                     1, 1, config.block_size, config.block_size
+                 ),
+             )
+
+     def forward(self, x):
+         B, T, C = x.size()
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+
+         if self.flash:
+             y = F.scaled_dot_product_attention(
+                 q,
+                 k,
+                 v,
+                 attn_mask=None,
+                 dropout_p=self.attn_dropout.p if self.training else 0.0,
+                 is_causal=True,
+             )
+         else:
+             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+             att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
+             att = F.softmax(att, dim=-1)
+             att = self.attn_dropout(att)
+             y = att @ v
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+         y = self.resid_dropout(self.c_proj(y))
+         return y
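+
+ # Shape walkthrough for CausalSelfAttention.forward (illustrative values only, assuming
+ # n_embd=512 and n_head=8 as in the example config at the bottom of this file):
+ #     x: (B, T, 512) -> c_attn(x): (B, T, 1536) -> q, k, v each (B, T, 512)
+ #     -> reshaped to (B, 8, T, 64) per head -> attention output merged back to (B, T, 512)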
+
+ # --- User's Original LayerNorm ---
+ class LayerNorm(nn.Module):
+     def __init__(self, ndim, bias):
+         """
+         Initializes the LayerNorm module.
+         Args:
+             ndim (int): The number of features in the last dimension (e.g., the embedding size).
+             bias (bool): Whether to include a bias term in the normalization.
+         """
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(ndim))
+         self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+
+     def forward(self, x):
+         return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
+ # --- End User's Original LayerNorm ---
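+
+ # Note: F.layer_norm normalizes over the last `ndim` features with eps=1e-5 and then applies
+ # the learnable weight and (optional) bias, so passing bias=False drops the shift term entirely.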
+
+ # --- User's Original MLP ---
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+         self.gelu = nn.GELU()
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
+         self.dropout = nn.Dropout(config.dropout)
+
+     def forward(self, x):
+         return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
+ # --- End User's Original MLP ---
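+
+ # With the example config at the bottom of this file (n_embd=512), the MLP expands each
+ # token's features to 2048 (4 * n_embd), applies GELU, and projects back down to 512.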
+
+ # --- User's Original Block ---
+ class Block(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.ln1 = LayerNorm(config.n_embd, config.bias)
+         self.attn = CausalSelfAttention(config)
+         self.ln2 = LayerNorm(config.n_embd, config.bias)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln1(x))
+         x = x + self.mlp(self.ln2(x))
+         return x
+ # --- End User's Original Block ---
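+
+ # Each Block is a pre-norm residual unit: LayerNorm is applied before attention and the MLP,
+ # and their outputs are added back onto the residual stream x, preserving an identity path.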
+
+
+ # --- User's Original GPTConfig ---
+ @dataclass
+ class GPTConfig:
+     block_size: int
+     vocab_size: int
+     n_layer: int
+     n_head: int
+     n_embd: int
+     dropout: float = 0.0
+     bias: bool = True
+ # --- End User's Original GPTConfig ---
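+
+ # Example instantiation (illustrative values only; an actual config is defined at the bottom of this file):
+ #     GPTConfig(block_size=128, vocab_size=50257, n_layer=4, n_head=4, n_embd=256)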
+
+ # --- User's Original TrainingConfig ---
+ @dataclass
+ class TrainingConfig:
+     learning_rate: float = 1e-4  # kept small for more stable training
+     max_iters: int = 20000  # changed from 25000
+     warmup_steps: int = 1000  # smoother initial training, earlier 100
+     min_lr: float = 5e-4  # minimum learning rate for the scheduler
+     eval_iters: int = 500  # increased from 100
+     batch_size: int = 32  # changed from 16, better gradient estimate
+     block_size: int = 128  # changed from 64, to capture longer-range dependencies
+     gradient_accumulation_steps: int = 32  # reduced from 50
+     device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
+     device_type: Literal["cuda", "cpu"] = (
+         "cuda" if "cuda" in device else "cpu"
+     )  # for later use in torch.autocast
+     dtype: Literal["bfloat16", "float16"] = (
+         "bfloat16"
+         if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
+         else "float16"
+     )
+     ptdtype: torch.dtype = {
+         "float32": torch.float32,
+         "bfloat16": torch.bfloat16,
+         "float16": torch.float16,
+     }[dtype]
+     ctx: nullcontext[None] | torch.autocast = (
+         nullcontext()
+         if device_type == "cpu"
+         else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+     )
+ # --- End User's Original TrainingConfig ---
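+
+ # A minimal sketch of how these fields are typically consumed in a training step (illustrative only;
+ # `model`, `optimizer`, and `get_batch` are assumed to be defined elsewhere):
+ #     cfg = TrainingConfig()
+ #     X, Y = get_batch("train")                # (batch_size, block_size) token tensors
+ #     with cfg.ctx:                            # autocast on CUDA, no-op context on CPU
+ #         logits, loss = model(X, Y)
+ #     (loss / cfg.gradient_accumulation_steps).backward()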
+
+
+ class GPT(nn.Module):
+     """
+     The main GPT model, now with an optional QA head for Question Answering tasks.
+     The QA head predicts the start and end token indices of the answer span.
+     """
+     def __init__(self, config, is_qa_model=False):
+         super().__init__()
+         assert config.vocab_size is not None
+         assert config.block_size is not None
+         self.config = config
+         self.is_qa_model = is_qa_model
+
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(config.vocab_size, config.n_embd),
+             wpe = nn.Embedding(config.block_size, config.n_embd),
+             drop = nn.Dropout(config.dropout),
+             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+             ln_f = LayerNorm(config.n_embd, bias=config.bias),
+         ))
+
+         # Language modeling head (for pre-training)
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # QA head (for fine-tuning)
+         # This will predict start and end logits for the answer span
+         if self.is_qa_model:
+             self.qa_head = nn.Linear(config.n_embd, 2, bias=False)  # 2 outputs: start_logit, end_logit
+         else:
+             self.qa_head = None  # No QA head if not a QA model
+
+         # tie weights
+         self.transformer.wte.weight = self.lm_head.weight  # https://paperswithcode.com/method/weight-tying
+
+         # init all weights
+         self.apply(self._init_weights)
+         # apply special scaled init to the residual projections, per the GPT-2 paper
+         for pn, p in self.named_parameters():
+             if pn.endswith('c_proj.weight'):
+                 torch.nn.init.normal_(p, mean=0.0, std=0.02/((2 * config.n_layer)**0.5))
+
+         # report number of parameters
+         # n_params calculation will differ slightly if the QA head is present
+         n_params = sum(p.numel() for p in self.parameters())
+         # The non-embedding count subtracts only the positional embeddings; token embeddings are
+         # not subtracted because they are tied to the lm_head weights.
+         non_embedding_params = n_params - self.transformer.wpe.weight.numel()
+         print(f"Number of parameters: {non_embedding_params/1e6:.2f}M (excluding positional embeddings)")
+
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, input_ids, targets=None, attention_mask=None, token_type_ids=None):
+         device = input_ids.device
+         b, t = input_ids.size()
+         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+         pos = torch.arange(0, t, dtype=torch.long, device=device)  # shape (t)
+
+         # forward the GPT model itself
+         tok_emb = self.transformer.wte(input_ids)  # token embeddings of shape (b, t, n_embd)
+         pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
+         x = self.transformer.drop(tok_emb + pos_emb)
+         for block in self.transformer.h:
+             x = block(x)
+         x = self.transformer.ln_f(x)
+
+         if self.is_qa_model and self.qa_head is not None:
+             # For extractive QA we need logits for every token to predict the start/end of the
+             # answer span, so the QA head is applied to the full sequence output 'x',
+             # which has shape (batch_size, sequence_length, n_embd).
+             logits = self.qa_head(x)  # (batch_size, sequence_length, 2)
+             start_logits, end_logits = logits.split(1, dim=-1)
+             start_logits = start_logits.squeeze(-1).contiguous()  # (batch_size, sequence_length)
+             end_logits = end_logits.squeeze(-1).contiguous()  # (batch_size, sequence_length)
+
+             if targets is not None:
+                 # targets for QA are start_positions and end_positions
+                 start_positions, end_positions = targets[:, 0], targets[:, 1]
+
+                 # Restrict the logits to valid answer tokens: non-padding positions
+                 # (attention_mask == 1) that belong to the context (token_type_ids == 1).
+                 if attention_mask is not None and token_type_ids is not None:
+                     valid_tokens_mask = (attention_mask == 1) & (token_type_ids == 1)
+
+                     start_logits = start_logits.masked_fill(~valid_tokens_mask, float('-inf'))
+                     end_logits = end_logits.masked_fill(~valid_tokens_mask, float('-inf'))
+
+                 loss_fct = nn.CrossEntropyLoss(ignore_index=-100)  # Use -100 as ignore_index for consistency
+                 start_loss = loss_fct(start_logits, start_positions)
+                 end_loss = loss_fct(end_logits, end_positions)
+                 total_loss = (start_loss + end_loss) / 2
+                 return start_logits, end_logits, total_loss
+
+             return start_logits, end_logits, None  # For inference
+         else:  # Standard language model for pre-training or text generation
+             if targets is not None:
+                 # if we are given some targets (e.g. for training), calculate the loss
+                 logits = self.lm_head(x)
+                 loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)  # Use -100
+             else:
+                 # inference-time mini-optimization: only forward the lm_head on the very last position
+                 logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
+                 loss = None
+
+             return logits, loss
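+
+     # Calling convention for the QA branch (inferred from the forward pass above): `targets` is a
+     # (batch_size, 2) LongTensor of [start_position, end_position] per example, and the method
+     # returns (start_logits, end_logits, loss); with targets=None it returns
+     # (start_logits, end_logits, None) for inference.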
+
+     @torch.no_grad()
+     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+         """
+         Generate tokens given a conditioning sequence.
+         idx: Tensor of shape (B, T)
+         """
+         if self.is_qa_model:
+             print("Warning: generate method is not intended for QA models directly.")
+             print("Please use the QA forward pass for inference and post-processing.")
+             return idx  # Or raise an error
+
+         for _ in range(max_new_tokens):
+             # crop the conditioning sequence to the last block_size tokens if it has grown too long
+             idx_cond = (
+                 idx
+                 if idx.size(1) <= self.config.block_size
+                 else idx[:, -self.config.block_size :]
+             )
+             logits, _ = self(idx_cond)
+             logits = logits[:, -1, :] / temperature
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = -float("Inf")
+             probs = F.softmax(logits, dim=-1)
+             idx_next = torch.multinomial(probs, num_samples=1)
+             idx = torch.cat((idx, idx_next), dim=1)
+         return idx
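+
+     # Example call (illustrative; `idx` is a (B, T) LongTensor of token ids on the model's device):
+     #     out = model.generate(idx, max_new_tokens=50, temperature=0.8, top_k=40)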
+
+ # The `config` object used for pre-training is kept here in case other scripts import this module for its definition.
+ config = GPTConfig(
+     vocab_size=50257,  # use the tokenizer's vocab size
+     block_size=1024,  # or whatever context size you're training with
+     n_layer=8,
+     n_head=8,
+     n_embd=512,
+     dropout=0.1,
+     bias=True,
+ )
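+
+ # Minimal end-to-end usage sketch (illustrative only; assumes token ids are prepared elsewhere,
+ # since this file does not include tokenization or training code):
+ #     model = GPT(config)
+ #     idx = torch.zeros((1, 1), dtype=torch.long)  # a single placeholder start-token id
+ #     out = model.generate(idx, max_new_tokens=20, temperature=0.8, top_k=40)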