Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- src/hf_model.py +1 -0
- src/model.py +19 -1
src/hf_model.py
CHANGED
|
@@ -20,6 +20,7 @@ from .model import HexaTransformer as CoreTransformer
|
|
| 20 |
|
| 21 |
class HexaModel(PreTrainedModel):
|
| 22 |
config_class = HexaHFConfig
|
|
|
|
| 23 |
|
| 24 |
def __init__(self, config):
|
| 25 |
super().__init__(config)
|
|
|
|
| 20 |
|
| 21 |
class HexaModel(PreTrainedModel):
|
| 22 |
config_class = HexaHFConfig
|
| 23 |
+
_supports_gradient_checkpointing = True
|
| 24 |
|
| 25 |
def __init__(self, config):
|
| 26 |
super().__init__(config)
|
src/model.py
CHANGED
|
@@ -97,6 +97,8 @@ class HexaTransformer(nn.Module):
|
|
| 97 |
self.pos_emb = RotaryEmbedding(config.dim_head)
|
| 98 |
|
| 99 |
# Transformer Layers
|
|
|
|
|
|
|
| 100 |
self.layers = nn.ModuleList([])
|
| 101 |
for _ in range(config.depth):
|
| 102 |
self.layers.append(TransformerBlock(
|
|
@@ -136,7 +138,23 @@ class HexaTransformer(nn.Module):
|
|
| 136 |
|
| 137 |
# Transformer Pass
|
| 138 |
for layer in self.layers:
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
x = self.norm_final(x)
|
| 142 |
|
|
|
|
| 97 |
self.pos_emb = RotaryEmbedding(config.dim_head)
|
| 98 |
|
| 99 |
# Transformer Layers
|
| 100 |
+
self.gradient_checkpointing = False # Default
|
| 101 |
+
|
| 102 |
self.layers = nn.ModuleList([])
|
| 103 |
for _ in range(config.depth):
|
| 104 |
self.layers.append(TransformerBlock(
|
|
|
|
| 138 |
|
| 139 |
# Transformer Pass
|
| 140 |
for layer in self.layers:
|
| 141 |
+
if self.training and self.gradient_checkpointing:
|
| 142 |
+
def create_custom_forward(module):
|
| 143 |
+
def custom_forward(*inputs):
|
| 144 |
+
return module(*inputs)
|
| 145 |
+
return custom_forward
|
| 146 |
+
|
| 147 |
+
# Checkpoint requires inputs to have requires_grad=True for at least one input.
|
| 148 |
+
# x usually has it.
|
| 149 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 150 |
+
create_custom_forward(layer),
|
| 151 |
+
x,
|
| 152 |
+
mask,
|
| 153 |
+
rope_emb,
|
| 154 |
+
use_reentrant=False
|
| 155 |
+
)
|
| 156 |
+
else:
|
| 157 |
+
x = layer(x, mask=mask, rope_emb=rope_emb)
|
| 158 |
|
| 159 |
x = self.norm_final(x)
|
| 160 |
|