Hexa09 committed on
Commit
321051a
·
verified ·
1 Parent(s): f13a98d

Upload folder using huggingface_hub

Browse files
Files changed (2)
  1. src/hf_model.py +1 -0
  2. src/model.py +19 -1
src/hf_model.py CHANGED
@@ -20,6 +20,7 @@ from .model import HexaTransformer as CoreTransformer
20
 
21
  class HexaModel(PreTrainedModel):
22
  config_class = HexaHFConfig
 
23
 
24
  def __init__(self, config):
25
  super().__init__(config)
 
20
 
21
  class HexaModel(PreTrainedModel):
22
  config_class = HexaHFConfig
23
+ _supports_gradient_checkpointing = True
24
 
25
  def __init__(self, config):
26
  super().__init__(config)
src/model.py CHANGED
@@ -97,6 +97,8 @@ class HexaTransformer(nn.Module):
97
  self.pos_emb = RotaryEmbedding(config.dim_head)
98
 
99
  # Transformer Layers
 
 
100
  self.layers = nn.ModuleList([])
101
  for _ in range(config.depth):
102
  self.layers.append(TransformerBlock(
@@ -136,7 +138,23 @@ class HexaTransformer(nn.Module):
136
 
137
  # Transformer Pass
138
  for layer in self.layers:
139
- x = layer(x, mask=mask, rope_emb=rope_emb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  x = self.norm_final(x)
142
 
 
97
  self.pos_emb = RotaryEmbedding(config.dim_head)
98
 
99
  # Transformer Layers
100
+ self.gradient_checkpointing = False # Default
101
+
102
  self.layers = nn.ModuleList([])
103
  for _ in range(config.depth):
104
  self.layers.append(TransformerBlock(
 
138
 
139
  # Transformer Pass
140
  for layer in self.layers:
141
+ if self.training and self.gradient_checkpointing:
142
+ def create_custom_forward(module):
143
+ def custom_forward(*inputs):
144
+ return module(*inputs)
145
+ return custom_forward
146
+
147
+ # Checkpoint requires inputs to have requires_grad=True for at least one input.
148
+ # x usually has it.
149
+ x = torch.utils.checkpoint.checkpoint(
150
+ create_custom_forward(layer),
151
+ x,
152
+ mask,
153
+ rope_emb,
154
+ use_reentrant=False
155
+ )
156
+ else:
157
+ x = layer(x, mask=mask, rope_emb=rope_emb)
158
 
159
  x = self.norm_final(x)
160