fix: _tied_weights_keys as dict + post_init() for transformers 5.3.0
Browse files- modeling_unicosys.py +5 -1
modeling_unicosys.py
CHANGED
|
@@ -169,7 +169,8 @@ class UnicosysHypergraphModel(PreTrainedModel):
|
|
| 169 |
"""
|
| 170 |
|
| 171 |
config_class = UnicosysConfig
|
| 172 |
-
_tied_weights_keys =
|
|
|
|
| 173 |
|
| 174 |
def __init__(self, config: UnicosysConfig):
|
| 175 |
super().__init__(config)
|
|
@@ -205,6 +206,9 @@ class UnicosysHypergraphModel(PreTrainedModel):
|
|
| 205 |
# Initialize weights
|
| 206 |
self.apply(self._init_weights)
|
| 207 |
|
|
|
|
|
|
|
|
|
|
| 208 |
def _init_weights(self, module):
|
| 209 |
if isinstance(module, nn.Linear):
|
| 210 |
nn.init.xavier_uniform_(module.weight)
|
|
|
|
| 169 |
"""
|
| 170 |
|
| 171 |
config_class = UnicosysConfig
|
| 172 |
+
_tied_weights_keys = {}
|
| 173 |
+
supports_gradient_checkpointing = False
|
| 174 |
|
| 175 |
def __init__(self, config: UnicosysConfig):
|
| 176 |
super().__init__(config)
|
|
|
|
| 206 |
# Initialize weights
|
| 207 |
self.apply(self._init_weights)
|
| 208 |
|
| 209 |
+
# Required by transformers >= 5.x for tied weight tracking
|
| 210 |
+
self.post_init()
|
| 211 |
+
|
| 212 |
def _init_weights(self, module):
|
| 213 |
if isinstance(module, nn.Linear):
|
| 214 |
nn.init.xavier_uniform_(module.weight)
|