Yujivus committed on
Commit 3082b32 · verified · 1 Parent(s): 92f1584

Upload modeling_baseline.py with huggingface_hub

Files changed (1)
  1. modeling_baseline.py +12 -21
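
The commit title says the file was pushed with huggingface_hub. Below is a minimal sketch of that kind of upload call; it is an assumption about how the commit was produced, not something shown on this page, and the repo_id is a placeholder because the target repository is not named here.

from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login` or the HF_TOKEN env var
api.upload_file(
    path_or_fileobj="modeling_baseline.py",   # local file to push
    path_in_repo="modeling_baseline.py",      # destination path inside the repo
    repo_id="<user>/<repo>",                  # placeholder; the actual repo id is not shown on this page
    repo_type="model",
    commit_message="Upload modeling_baseline.py with huggingface_hub",
)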
modeling_baseline.py CHANGED
@@ -2,16 +2,11 @@
 import torch
 import torch.nn as nn
 from transformers import PreTrainedModel
-
-try:
-    from .configuration_baseline import BaselineConfig
-except ImportError:
-    from configuration_baseline import BaselineConfig
-
+from .configuration_baseline import BaselineConfig
 try:
     from x_transformers import TransformerWrapper, Encoder
 except ImportError:
-    raise ImportError("To use this model, you must run: pip install x-transformers")
+    raise ImportError("pip install x-transformers")
 
 class BaselineModel(PreTrainedModel):
     config_class = BaselineConfig
@@ -19,13 +14,11 @@ class BaselineModel(PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.config = config
-
-        # Instantiate the x_transformers model
         self.model = TransformerWrapper(
             num_tokens=config.vocab_size,
             max_seq_len=config.seq_len,
-            use_abs_pos_emb=False,  # RoPE is enabled in Encoder
-            tie_embedding=True,     # Matches training
+            use_abs_pos_emb=False,
+            tie_embedding=True,
             attn_layers=Encoder(
                 dim=config.d_model,
                 depth=config.depth,
@@ -33,21 +26,19 @@ class BaselineModel(PreTrainedModel):
                 layer_dropout=config.dropout,
                 attn_dropout=config.dropout,
                 ff_dropout=config.dropout,
-                rotary_pos_emb=True,   # Matches training
-                attn_flash=True,       # Matches training
+                rotary_pos_emb=True,
+                attn_flash=True,
                 use_scalenorm=False
             )
         )
+        # TIE FIX
+        if hasattr(self.model.token_emb, 'emb'):
+            self.model.to_logits.weight = self.model.token_emb.emb.weight
+        else:
+            self.model.to_logits.weight = self.model.token_emb.weight
 
     def forward(self, input_ids, labels=None, mask=None):
-        # x_transformers takes 'mask' argument if provided
        logits = self.model(input_ids, mask=mask)
-
-        loss = None
         if labels is not None:
-            loss_fct = nn.CrossEntropyLoss()
-            # Reshape for loss calculation
-            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
-            return {"loss": loss, "logits": logits}
-
+            return {"loss": nn.CrossEntropyLoss()(logits.view(-1, self.config.vocab_size), labels.view(-1)), "logits": logits}
         return logits
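
The new version keeps only the package-relative import of BaselineConfig, so the file is meant to be loaded as part of the repo package (for example through transformers' remote-code machinery) rather than run as a standalone module, and the "TIE FIX" block explicitly makes the to_logits projection share its weight with the token embedding. The usage sketch below is hedged: the repo id is a placeholder, and it assumes the repository also ships weights and a config.json whose auto_map points at BaselineModel, none of which is shown in this commit.

import torch
from transformers import AutoModel

# Placeholder repo id; loading custom modeling code requires trust_remote_code=True and
# assumes the repo's config.json registers BaselineModel under auto_map (not shown here).
model = AutoModel.from_pretrained("<user>/<repo>", trust_remote_code=True)

# After the tie fix, the output projection should share storage with the token embedding.
emb = model.model.token_emb
emb_weight = emb.emb.weight if hasattr(emb, "emb") else emb.weight
assert model.model.to_logits.weight is emb_weight

input_ids = torch.randint(0, model.config.vocab_size, (2, 16))

out = model(input_ids, labels=input_ids.clone())  # dict with "loss" and "logits" when labels are given
print(out["loss"].item(), out["logits"].shape)

logits = model(input_ids)                         # bare logits tensor when labels are omitted

The hasattr branch in the tie block exists because TransformerWrapper may expose the token embedding either as a wrapper module with an .emb attribute or as a bare embedding; the assignment covers both layouts.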