Update modeling_challenger.py
modeling_challenger.py  +43 -2  CHANGED
@@ -87,7 +87,7 @@ class FP8Linear(torch.nn.Module):

    def forward(self, x):
        """
-        Accepts x of shape (..., in_features) –
+        Accepts x of shape (..., in_features) – any leading dims.
        Flattens to 2-D, does the FP8 matmul, then restores the shape.
        """
        orig_shape = x.shape[:-1]  # e.g. (B, T)
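The docstring change above spells out the shape contract of FP8Linear.forward: flatten any leading dims to 2-D, run the FP8 matmul, then restore the original shape. A minimal sketch of just that shape handling (fp8_style_forward is a made-up name, and a plain torch matmul stands in for the FP8 kernel, whose scaling/quantization details are not part of this hunk):

import torch

def fp8_style_forward(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # x: (..., in_features) with any leading dims, e.g. (B, T, in_features)
    # weight: (out_features, in_features), as in a standard nn.Linear
    orig_shape = x.shape[:-1]               # e.g. (B, T)
    x_2d = x.reshape(-1, x.shape[-1])       # flatten to (B*T, in_features)
    out_2d = x_2d @ weight.t()              # stand-in for the FP8 matmul
    return out_2d.reshape(*orig_shape, -1)  # restore to (..., out_features)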
@@ -155,6 +155,47 @@ class Block(nn.Module):
    def forward(self, x):
        return x + self.attn(self.ln_1(x)) + self.mlp(self.ln_2(x))

+class GPT(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+        self.transformer = nn.ModuleDict(dict(
+            wte = nn.Embedding(config.vocab_size, config.n_embd),
+            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+            ln_f = RMSNorm(config.n_embd),
+        ))
+        self.lm_head = FP8Linear(config.n_embd, config.vocab_size, bias=False)
+        self.lm_head.NANOGPT_SCALE_INIT = 1
+
+        # init params
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, FP8Linear):
+            std = 0.02
+            if hasattr(module, 'NANOGPT_SCALE_INIT'):
+                std *= (2 * self.config.n_layer) ** -0.5
+            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+    def forward(self, idx, targets=None):
+        x = self.transformer.wte(idx)  # token embeddings of shape (B, T, n_embd)
+        # forward the blocks of the transformer
+        for block in self.transformer.h:
+            x = block(x)
+        # forward the final layernorm and the classifier
+        x = self.transformer.ln_f(x)
+        logits = self.lm_head(x).float()  # (B, T, vocab_size)
+        loss = None
+        if targets is not None:
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+        return logits, loss
+
 class ChallengerModel(PreTrainedModel):
    config_class = ChallengerConfig

@@ -166,4 +207,4 @@ class ChallengerModel(PreTrainedModel):
        logits, loss = self.model(input_ids, labels)
        if labels is not None:
            return {"loss": loss, "logits": logits}
-        return {"logits": logits}
+        return {"logits": logits}
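For context, a hedged usage sketch of the new GPT module and the ChallengerModel wrapper touched in the last hunk. The config values below are assumptions: only vocab_size, n_embd and n_layer are referenced in this diff, and ChallengerConfig, Block, and the wrapper's forward signature may involve fields not shown here.

import torch

# Assumed config values; ChallengerConfig likely carries more fields than these.
config = ChallengerConfig(vocab_size=50304, n_embd=768, n_layer=12)

model = GPT(config)
idx = torch.randint(0, config.vocab_size, (2, 128))  # (B, T) token ids
logits, loss = model(idx, targets=idx)               # training: loss is a scalar tensor
logits, loss = model(idx)                            # inference: loss is None

# Through the HF-style wrapper (assuming forward(input_ids, labels=None)),
# the same call returns a dict, as shown in the last hunk:
wrapper = ChallengerModel(config)
out = wrapper(idx, idx)   # {"loss": ..., "logits": ...}
out = wrapper(idx)        # {"logits": ...}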