MaxiiMin committed on
Commit a67ae76 · verified · 1 Parent(s): 698b5ef

Update modeling_challenger.py

Files changed (1)
  1. modeling_challenger.py +43 -2
modeling_challenger.py CHANGED
@@ -87,7 +87,7 @@ class FP8Linear(torch.nn.Module):
 
     def forward(self, x):
         """
-        Accepts x of shape (..., in_features) – any leading dims.
+        Accepts x of shape (..., in_features) – any leading dims.
        Flattens to 2‑D, does the FP8 matmul, then restores the shape.
        """
        orig_shape = x.shape[:-1]  # e.g. (B, T)
@@ -155,6 +155,47 @@ class Block(nn.Module):
    def forward(self, x):
        return x + self.attn(self.ln_1(x)) + self.mlp(self.ln_2(x))
 
+class GPT(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+        self.transformer = nn.ModuleDict(dict(
+            wte = nn.Embedding(config.vocab_size, config.n_embd),
+            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+            ln_f = RMSNorm(config.n_embd),
+        ))
+        self.lm_head = FP8Linear(config.n_embd, config.vocab_size, bias=False)
+        self.lm_head.NANOGPT_SCALE_INIT = 1
+
+        # init params
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, FP8Linear):
+            std = 0.02
+            if hasattr(module, 'NANOGPT_SCALE_INIT'):
+                std *= (2 * self.config.n_layer) ** -0.5
+            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+    def forward(self, idx, targets=None):
+        x = self.transformer.wte(idx)  # token embeddings of shape (B, T, n_embd)
+        # forward the blocks of the transformer
+        for block in self.transformer.h:
+            x = block(x)
+        # forward the final layernorm and the classifier
+        x = self.transformer.ln_f(x)
+        logits = self.lm_head(x).float()  # (B, T, vocab_size)
+        loss = None
+        if targets is not None:
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+        return logits, loss
+
 class ChallengerModel(PreTrainedModel):
    config_class = ChallengerConfig
 
@@ -166,4 +207,4 @@ class ChallengerModel(PreTrainedModel):
        logits, loss = self.model(input_ids, labels)
        if labels is not None:
            return {"loss": loss, "logits": logits}
-        return {"logits": logits}
+        return {"logits": logits}