yagizdevre committed
Commit abbb688 · 1 Parent(s): b2a7b14
__pycache__/modeling_minitransformer.cpython-312.pyc CHANGED
Binary files a/__pycache__/modeling_minitransformer.cpython-312.pyc and b/__pycache__/modeling_minitransformer.cpython-312.pyc differ
 
modeling_minitransformer.py CHANGED
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 
 from transformers import PreTrainedModel, PretrainedConfig
 from .configuration_minitransformer import MiniTransformerConfig
+from transformers.modeling_outputs import CausalLMOutput
 try:
     from flash_attn import flash_attn_func
 except ImportError as e:
@@ -220,16 +221,38 @@ class MiniTransformer(PreTrainedModel):
         self.apply(self._init_weights)
         print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        tok_emb = self.tok_emb(x)
-        x = self.dropout(tok_emb)
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        labels: torch.Tensor = None,
+        **kwargs
+    ) -> CausalLMOutput:
+        # Compute embeddings
+        tok_emb = self.tok_emb(input_ids)
 
         for layer in self.layers:
-            x = layer(x, self.freqs_cis)
-
-        y_hat = self.lm_head(self.norm(x))
-
-        return y_hat
+            tok_emb = layer(tok_emb, self.freqs_cis)
+
+        # Normalize and project to vocabulary
+        tok_emb = self.norm(tok_emb)
+        logits = self.lm_head(tok_emb)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens predict the next token
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(
+                shift_logits.view(-1, shift_logits.size(-1)),
+                shift_labels.view(-1)
+            )
+
+        return CausalLMOutput(
+            loss=loss,
+            logits=logits,
+        )
+
 
     def _get_num_params(self):
         n_params = sum(p.numel() for p in self.parameters())
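
With this commit, forward follows the Hugging Face causal-LM convention: it takes input_ids (plus optional labels and extra keyword arguments) and returns a CausalLMOutput whose loss field is the next-token cross-entropy, computed by shifting the logits left and the labels right by one position. A minimal usage sketch, under assumptions not shown in this diff (that MiniTransformerConfig is constructible with defaults and exposes a vocab_size field):

    import torch
    from configuration_minitransformer import MiniTransformerConfig
    from modeling_minitransformer import MiniTransformer

    # Assumption: a default config is constructible and has a vocab_size field.
    config = MiniTransformerConfig()
    model = MiniTransformer(config)

    # For language modeling, labels are typically the input ids themselves;
    # forward() handles the one-position shift internally, so logits at step t
    # are scored against the token at step t + 1.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch, seq_len)
    out = model(input_ids=input_ids, labels=input_ids)

    print(out.loss)          # scalar cross-entropy loss (None when labels are omitted)
    print(out.logits.shape)  # torch.Size([2, 16, vocab_size])

Returning CausalLMOutput rather than a bare tensor is what lets the model plug into transformers tooling such as Trainer, which reads the loss field directly from the forward output.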