Commit abbb688
Parent(s): b2a7b14

fixes
__pycache__/modeling_minitransformer.cpython-312.pyc  CHANGED

Binary files a/__pycache__/modeling_minitransformer.cpython-312.pyc and b/__pycache__/modeling_minitransformer.cpython-312.pyc differ
modeling_minitransformer.py  CHANGED
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 
 from transformers import PreTrainedModel, PretrainedConfig
 from .configuration_minitransformer import MiniTransformerConfig
+from transformers.modeling_outputs import CausalLMOutput
 try:
     from flash_attn import flash_attn_func
 except ImportError as e:
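
Note: CausalLMOutput (imported above) is the standard transformers output dataclass exposing loss and logits fields, readable by attribute or by string key. A minimal sketch of its behavior, not part of the commit:

import torch
from transformers.modeling_outputs import CausalLMOutput

# ModelOutput subclass: fields are accessible by attribute or by key;
# fields left as None are skipped when iterating over the output.
out = CausalLMOutput(loss=None, logits=torch.zeros(2, 16, 32000))
print(out.logits.shape)             # torch.Size([2, 16, 32000])
print(out["logits"] is out.logits)  # True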
@@ -220,16 +221,38 @@ class MiniTransformer(PreTrainedModel):
         self.apply(self._init_weights)
         print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))
 
-    def forward(
-
-
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        labels: torch.Tensor = None,
+        **kwargs
+    ) -> CausalLMOutput:
+        # Compute embeddings
+        tok_emb = self.tok_emb(input_ids)
 
         for layer in self.layers:
-
-
-
-
-
+            tok_emb = layer(tok_emb, self.freqs_cis)
+
+        # Normalize and project to vocabulary
+        tok_emb = self.norm(tok_emb)
+        logits = self.lm_head(tok_emb)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens predict the next token
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(
+                shift_logits.view(-1, shift_logits.size(-1)),
+                shift_labels.view(-1)
+            )
+
+        return CausalLMOutput(
+            loss=loss,
+            logits=logits,
+        )
+
 
     def _get_num_params(self):
         n_params = sum(p.numel() for p in self.parameters())
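
The new forward() contract makes the model usable with the stock transformers Trainer: when labels are passed, the shifted next-token cross-entropy is computed internally and returned on the output. A hypothetical usage sketch; MiniTransformerConfig being default-constructible and exposing vocab_size are assumptions, not shown in this commit:

import torch

config = MiniTransformerConfig()   # assumption: default-constructible
model = MiniTransformer(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch, seq)

# Passing labels triggers the internal shifted loss: logits at position t
# are scored against the label at position t + 1, so the model can be
# trained directly with input_ids reused as its own labels.
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss)          # scalar loss tensor
print(out.logits.shape)  # torch.Size([2, 16, vocab_size])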