Add gradient checkpointing support
Browse files- modeling_llada.py +63 -5
modeling_llada.py
CHANGED
|
@@ -1094,10 +1094,68 @@ class LLaDABlockGroup(nn.ModuleList):
|
|
| 1094 |
block.set_activation_checkpointing(strategy)
|
| 1095 |
|
| 1096 |
|
| 1097 |
-
class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1098 |
def __init__(self, config: ModelConfig, init_params: bool = True):
|
| 1099 |
-
super().__init__()
|
| 1100 |
-
self.
|
| 1101 |
self.__cache = BufferCache()
|
| 1102 |
|
| 1103 |
# Validate config.
|
|
@@ -1166,7 +1224,7 @@ class LLaDAModel(nn.Module):
|
|
| 1166 |
)
|
| 1167 |
# When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
|
| 1168 |
if init_params and self.config.init_device != "meta":
|
| 1169 |
-
self.
|
| 1170 |
self.__num_fwd_flops: Optional[int] = None
|
| 1171 |
|
| 1172 |
# Warm up cache.
|
|
@@ -1455,7 +1513,7 @@ def create_model_config_from_pretrained_config(config: LLaDAConfig):
|
|
| 1455 |
return model_config
|
| 1456 |
|
| 1457 |
|
| 1458 |
-
class LLaDAModelLM(
|
| 1459 |
"""
|
| 1460 |
Extremely barebones HF model wrapper.
|
| 1461 |
"""
|
|
|
|
| 1094 |
block.set_activation_checkpointing(strategy)
|
| 1095 |
|
| 1096 |
|
| 1097 |
+
class LLaDAPreTrainedModel(PreTrainedModel):
    """
    Minimal HF-compatible base class.

    Centralizes parameter initialization and wires the Hugging Face
    gradient-checkpointing hooks (`gradient_checkpointing_enable`) into the
    LLaDA activation-checkpointing machinery
    (`set_activation_checkpointing`).
    """

    config_class = LLaDAConfig
    base_model_prefix = "model"
    _no_split_modules = ["LLaDALlamaBlock"]
    _supports_gradient_checkpointing = True  # backward compat (transformers < 4.38)
    supports_gradient_checkpointing = True  # transformers >= 4.38

    def __init__(self, config, *model_args, **model_kwargs):
        # `config` may be the project-internal ModelConfig rather than an HF
        # config; wrap it so `PreTrainedModel.__init__` always receives an
        # HF-style config (anything exposing `to_dict`).
        hf_config = config
        if not hasattr(hf_config, "to_dict"):
            hf_config = LLaDAConfig(**config.__dict__)
        super().__init__(hf_config, *model_args, **model_kwargs)

    def _init_weights(self, module):
        """
        Initialize `module` at most once via its own `reset_parameters`, then
        mark the whole subtree as initialized so recursive initialization
        (e.g. HF's `post_init` -> `apply`) does not re-run it on children.
        """
        if getattr(module, "_llada_params_initialized", False):
            return
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()
        # NOTE(review): assumes reset_parameters() covers the full subtree —
        # confirm against LLaDAModel.reset_parameters before relying on this.
        for child in module.modules():
            child._llada_params_initialized = True

    @staticmethod
    def _toggle_checkpointing(
        core, enable: bool, gradient_checkpointing_func: Callable
    ) -> None:
        """Flip activation checkpointing on a `LLaDAModel` instance."""
        core._gradient_checkpointing_func = gradient_checkpointing_func
        core.gradient_checkpointing = enable
        strategy = ActivationCheckpointingStrategy.whole_layer if enable else None
        core.set_activation_checkpointing(strategy)

    def _set_gradient_checkpointing(
        self,
        enable: bool = True,
        gradient_checkpointing_func: Optional[Callable] = None,
    ):
        """
        New-format hook expected by `PreTrainedModel.gradient_checkpointing_enable`.
        Only LLaDAModel (the heavy transformer) actually toggles checkpointing.

        :param enable: turn checkpointing on (True) or off (False).
        :param gradient_checkpointing_func: checkpoint implementation; defaults
            to `torch.utils.checkpoint.checkpoint` when None.
        """
        from torch.utils.checkpoint import checkpoint

        if gradient_checkpointing_func is None:
            gradient_checkpointing_func = checkpoint

        # When called on the HF wrapper (LLaDAModelLM), reach into the inner LLaDAModel.
        target = self.model if isinstance(self, LLaDAModelLM) else self

        if isinstance(target, LLaDAModel):
            self._toggle_checkpointing(target, enable, gradient_checkpointing_func)
            return

        # Fallback: walk modules to find the core model.
        for module in self.modules():
            if isinstance(module, LLaDAModel):
                self._toggle_checkpointing(module, enable, gradient_checkpointing_func)
                break
+
|
| 1155 |
+
class LLaDAModel(LLaDAPreTrainedModel):
|
| 1156 |
def __init__(self, config: ModelConfig, init_params: bool = True):
|
| 1157 |
+
super().__init__(config)
|
| 1158 |
+
self.gradient_checkpointing: bool = False
|
| 1159 |
self.__cache = BufferCache()
|
| 1160 |
|
| 1161 |
# Validate config.
|
|
|
|
| 1224 |
)
|
| 1225 |
# When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
|
| 1226 |
if init_params and self.config.init_device != "meta":
|
| 1227 |
+
self.post_init()
|
| 1228 |
self.__num_fwd_flops: Optional[int] = None
|
| 1229 |
|
| 1230 |
# Warm up cache.
|
|
|
|
| 1513 |
return model_config
|
| 1514 |
|
| 1515 |
|
| 1516 |
+
class LLaDAModelLM(LLaDAPreTrainedModel):
|
| 1517 |
"""
|
| 1518 |
Extremely barebones HF model wrapper.
|
| 1519 |
"""
|