Fraser committed on
Commit
4232e2f
·
verified ·
1 Parent(s): 547f4ea

add grad ckpt

Browse files
Files changed (1) hide show
  1. modeling_llada.py +63 -5
modeling_llada.py CHANGED
@@ -1094,10 +1094,68 @@ class LLaDABlockGroup(nn.ModuleList):
1094
  block.set_activation_checkpointing(strategy)
1095
 
1096
 
1097
- class LLaDAModel(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1098
  def __init__(self, config: ModelConfig, init_params: bool = True):
1099
- super().__init__()
1100
- self.config = config
1101
  self.__cache = BufferCache()
1102
 
1103
  # Validate config.
@@ -1166,7 +1224,7 @@ class LLaDAModel(nn.Module):
1166
  )
1167
  # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
1168
  if init_params and self.config.init_device != "meta":
1169
- self.reset_parameters()
1170
  self.__num_fwd_flops: Optional[int] = None
1171
 
1172
  # Warm up cache.
@@ -1455,7 +1513,7 @@ def create_model_config_from_pretrained_config(config: LLaDAConfig):
1455
  return model_config
1456
 
1457
 
1458
- class LLaDAModelLM(PreTrainedModel):
1459
  """
1460
  Extremely barebones HF model wrapper.
1461
  """
 
1094
  block.set_activation_checkpointing(strategy)
1095
 
1096
 
1097
+ class LLaDAPreTrainedModel(PreTrainedModel):
1098
+ """
1099
+ Minimal HF-compatible base to enable gradient checkpointing hooks and centralize
1100
+ parameter initialization.
1101
+ """
1102
+
1103
+ config_class = LLaDAConfig
1104
+ base_model_prefix = "model"
1105
+ _no_split_modules = ["LLaDALlamaBlock"]
1106
+ _supports_gradient_checkpointing = True # backward compat
1107
+ supports_gradient_checkpointing = True # transformers >=4.38
1108
+
1109
+ def __init__(self, config, *model_args, **model_kwargs):
1110
+ hf_config = config
1111
+ if not hasattr(hf_config, "to_dict"):
1112
+ hf_config = LLaDAConfig(**config.__dict__)
1113
+ super().__init__(hf_config, *model_args, **model_kwargs)
1114
+
1115
+ def _init_weights(self, module):
1116
+ if getattr(module, "_llada_params_initialized", False):
1117
+ return
1118
+ if hasattr(module, "reset_parameters"):
1119
+ module.reset_parameters()
1120
+ for child in module.modules():
1121
+ setattr(child, "_llada_params_initialized", True)
1122
+
1123
+ def _set_gradient_checkpointing(
1124
+ self, enable: bool = True, gradient_checkpointing_func: Callable = None
1125
+ ):
1126
+ """
1127
+ New-format hook expected by `PreTrainedModel.gradient_checkpointing_enable`.
1128
+ Only LLaDAModel (the heavy transformer) actually toggles checkpointing.
1129
+ """
1130
+ from torch.utils.checkpoint import checkpoint
1131
+
1132
+ if gradient_checkpointing_func is None:
1133
+ gradient_checkpointing_func = checkpoint
1134
+
1135
+ # When called on the HF wrapper (LLaDAModelLM), reach into the inner LLaDAModel.
1136
+ target = self.model if isinstance(self, LLaDAModelLM) else self
1137
+
1138
+ if isinstance(target, LLaDAModel):
1139
+ target._gradient_checkpointing_func = gradient_checkpointing_func
1140
+ target.gradient_checkpointing = enable
1141
+ strategy = ActivationCheckpointingStrategy.whole_layer if enable else None
1142
+ target.set_activation_checkpointing(strategy)
1143
+ return
1144
+
1145
+ # Fallback: walk modules to find the core model.
1146
+ for module in self.modules():
1147
+ if isinstance(module, LLaDAModel):
1148
+ module._gradient_checkpointing_func = gradient_checkpointing_func
1149
+ module.gradient_checkpointing = enable
1150
+ strategy = ActivationCheckpointingStrategy.whole_layer if enable else None
1151
+ module.set_activation_checkpointing(strategy)
1152
+ break
1153
+
1154
+
1155
+ class LLaDAModel(LLaDAPreTrainedModel):
1156
  def __init__(self, config: ModelConfig, init_params: bool = True):
1157
+ super().__init__(config)
1158
+ self.gradient_checkpointing: bool = False
1159
  self.__cache = BufferCache()
1160
 
1161
  # Validate config.
 
1224
  )
1225
  # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
1226
  if init_params and self.config.init_device != "meta":
1227
+ self.post_init()
1228
  self.__num_fwd_flops: Optional[int] = None
1229
 
1230
  # Warm up cache.
 
1513
  return model_config
1514
 
1515
 
1516
+ class LLaDAModelLM(LLaDAPreTrainedModel):
1517
  """
1518
  Extremely barebones HF model wrapper.
1519
  """