Fraser commited on
Commit
f98df9d
·
verified ·
1 Parent(s): 1c92536

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. config.json +170 -117
  2. model.safetensors +1 -1
  3. modeling_llada.py +5 -63
  4. modeling_recursive.py +9 -10
config.json CHANGED
@@ -1,157 +1,210 @@
1
  {
2
- "architectures": [
3
- "RecursiveMaskedLM"
4
- ],
5
- "auto_map": {
6
- "AutoConfig": "configuration_recursive.RecursiveMLMConfig",
7
- "AutoModel": "modeling_recursive.RecursiveMaskedLM"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  "base_model_config": {
10
- "_name_or_path": "Fraser/LLaDA-8B-Base-gg2m",
11
- "activation_type": "silu",
 
 
 
 
 
 
 
 
12
  "add_cross_attention": false,
13
- "alibi": false,
14
- "alibi_bias_max": 8.0,
15
  "architectures": [
16
  "LLaDAModelLM"
17
  ],
18
- "attention_dropout": 0.0,
19
- "attention_layer_norm": false,
20
- "attention_layer_norm_with_affine": true,
21
- "auto_map": {
22
- "AutoConfig": "configuration_llada.LLaDAConfig",
23
- "AutoModel": "modeling_llada.LLaDAModelLM",
24
- "AutoModelForCausalLM": "modeling_llada.LLaDAModelLM"
25
- },
26
- "bad_words_ids": null,
27
- "begin_suppress_tokens": null,
28
- "bias_for_layer_norm": false,
29
- "block_group_size": 1,
30
- "block_type": "llama",
31
- "bos_token_id": 75,
32
- "chunk_size_feed_forward": 0,
33
- "cross_attention_hidden_size": null,
34
- "d_model": 4096,
35
- "decoder_start_token_id": null,
36
- "diversity_penalty": 0.0,
37
- "do_sample": false,
38
- "dtype": "bfloat16",
39
- "early_stopping": false,
40
- "embedding_dropout": 0.0,
41
- "embedding_size": 85,
42
- "encoder_no_repeat_ngram_size": 0,
43
- "eos_token_id": 76,
44
- "exponential_decay_length_penalty": null,
45
  "finetuning_task": null,
46
- "flash_attention": false,
47
- "forced_bos_token_id": null,
48
- "forced_eos_token_id": null,
49
  "id2label": {
50
  "0": "LABEL_0",
51
  "1": "LABEL_1"
52
  },
53
- "include_bias": false,
54
- "include_qkv_bias": false,
55
- "init_cutoff_factor": null,
56
- "init_device": "meta",
57
- "init_fn": "mitchell",
58
- "init_std": 0.02,
59
- "input_emb_norm": false,
60
- "is_decoder": false,
61
- "is_encoder_decoder": false,
62
  "label2id": {
63
  "LABEL_0": 0,
64
  "LABEL_1": 1
65
  },
66
- "layer_norm_type": "rms",
67
- "layer_norm_with_affine": true,
68
- "length_penalty": 1.0,
69
- "mask_token_id": 78,
 
 
 
 
 
70
  "max_length": 20,
71
- "max_sequence_length": 4096,
72
  "min_length": 0,
73
- "mlp_hidden_size": 12288,
74
- "mlp_ratio": 4,
75
- "model_type": "llada",
76
- "multi_query_attention": null,
77
- "n_heads": 32,
78
- "n_kv_heads": 32,
79
- "n_layers": 32,
80
- "no_repeat_ngram_size": 0,
81
- "num_beam_groups": 1,
82
  "num_beams": 1,
 
 
 
 
 
 
 
 
 
83
  "num_return_sequences": 1,
84
- "output_attentions": false,
85
- "output_hidden_states": false,
86
  "output_scores": false,
87
- "pad_token_id": 76,
88
- "precision": "amp_bf16",
89
- "prefix": null,
90
- "problem_type": null,
91
- "pruned_heads": {},
92
- "remove_invalid_values": false,
93
- "repetition_penalty": 1.0,
94
- "residual_dropout": 0.0,
95
- "return_dict": true,
96
  "return_dict_in_generate": false,
97
- "rms_norm_eps": 1e-05,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  "rope": true,
99
  "rope_full_precision": true,
 
 
 
 
 
 
 
 
 
 
 
 
100
  "rope_theta": 500000.0,
 
 
 
101
  "scale_logits": false,
102
- "sep_token_id": null,
103
- "suppress_tokens": null,
104
- "task_specific_params": null,
105
- "temperature": 1.0,
 
 
 
 
 
 
 
 
 
 
 
 
106
  "tf_legacy_loss": false,
107
- "tie_encoder_decoder": false,
108
- "tie_word_embeddings": true,
109
- "tokenizer_class": null,
110
- "top_k": 50,
111
- "top_p": 1.0,
112
- "torchscript": false,
113
- "transformers_version": "4.57.0",
114
- "typical_p": 1.0,
115
  "use_bfloat16": false,
116
- "use_cache": false,
117
- "vocab_size": 85,
118
- "weight_tying": false
119
  },
120
- "bos_token_id": 75,
 
 
 
 
 
121
  "causal_strength": 1.0,
122
- "dtype": "bfloat16",
123
- "entropy_floor_max": 0.0,
124
  "entropy_target_max": 0.0,
125
- "eos_token_id": 76,
 
 
 
 
 
 
126
  "flow_matching_enabled": false,
127
  "flow_matching_lambda": 0.5,
128
- "flow_matching_mask_scale": false,
129
- "flow_matching_noise_scale": 2.0,
130
  "flow_matching_t_distribution": "logit_normal",
131
  "flow_matching_t_logit_mean": -0.4,
132
  "flow_matching_t_logit_std": 1.0,
133
- "flow_matching_t_max": 0.99,
134
  "flow_matching_t_min": 0.01,
135
- "gradient_steps": null,
136
- "iteration_rope_dim_fraction": 0.0,
137
- "loss_weight": "linear",
138
- "mask_token_id": 78,
139
- "model_type": "recursive-mlm",
140
- "noise_std_max": 0.0,
141
- "normalization": "softmax",
142
- "num_recursions": 4,
143
- "pad_token_id": 76,
144
- "schedule": "linear",
145
  "self_distillation_enabled": false,
146
  "self_distillation_lambda": 0.5,
147
- "self_distillation_teacher": "first",
148
- "self_distillation_temperature_distribution": "log_uniform",
149
- "self_distillation_temperature_max": 10.0,
150
  "self_distillation_temperature_min": 1.5,
151
- "smear_sigma_max": 0.0,
152
- "soft_embedding_ema_step": 1.0,
153
- "soft_embedding_method": "softmax",
154
- "temperature_max": 0.0,
155
- "transformers_version": "4.57.0",
156
- "use_recursion_checkpointing": true
157
- }
 
 
 
 
1
  {
2
+ "return_dict": true,
3
+ "output_hidden_states": false,
4
+ "torchscript": false,
5
+ "dtype": null,
6
+ "pruned_heads": {},
7
+ "tie_word_embeddings": false,
8
+ "chunk_size_feed_forward": 0,
9
+ "is_encoder_decoder": false,
10
+ "is_decoder": false,
11
+ "cross_attention_hidden_size": null,
12
+ "add_cross_attention": false,
13
+ "tie_encoder_decoder": false,
14
+ "architectures": ["RecursiveMaskedLM"],
15
+ "finetuning_task": null,
16
+ "id2label": {
17
+ "0": "LABEL_0",
18
+ "1": "LABEL_1"
19
+ },
20
+ "label2id": {
21
+ "LABEL_0": 0,
22
+ "LABEL_1": 1
23
  },
24
+ "task_specific_params": null,
25
+ "problem_type": null,
26
+ "tokenizer_class": null,
27
+ "prefix": null,
28
+ "bos_token_id": null,
29
+ "pad_token_id": null,
30
+ "eos_token_id": null,
31
+ "sep_token_id": null,
32
+ "decoder_start_token_id": null,
33
+ "max_length": 20,
34
+ "min_length": 0,
35
+ "do_sample": false,
36
+ "early_stopping": false,
37
+ "num_beams": 1,
38
+ "temperature": 1.0,
39
+ "top_k": 50,
40
+ "top_p": 1.0,
41
+ "typical_p": 1.0,
42
+ "repetition_penalty": 1.0,
43
+ "length_penalty": 1.0,
44
+ "no_repeat_ngram_size": 0,
45
+ "encoder_no_repeat_ngram_size": 0,
46
+ "bad_words_ids": null,
47
+ "num_return_sequences": 1,
48
+ "output_scores": false,
49
+ "return_dict_in_generate": false,
50
+ "forced_bos_token_id": null,
51
+ "forced_eos_token_id": null,
52
+ "remove_invalid_values": false,
53
+ "exponential_decay_length_penalty": null,
54
+ "suppress_tokens": null,
55
+ "begin_suppress_tokens": null,
56
+ "num_beam_groups": 1,
57
+ "diversity_penalty": 0.0,
58
+ "_name_or_path": "",
59
+ "transformers_version": "4.57.0",
60
+ "tf_legacy_loss": false,
61
+ "use_bfloat16": false,
62
  "base_model_config": {
63
+ "return_dict": true,
64
+ "output_hidden_states": false,
65
+ "torchscript": false,
66
+ "dtype": "bfloat16",
67
+ "pruned_heads": {},
68
+ "tie_word_embeddings": false,
69
+ "chunk_size_feed_forward": 0,
70
+ "is_encoder_decoder": false,
71
+ "is_decoder": false,
72
+ "cross_attention_hidden_size": null,
73
  "add_cross_attention": false,
74
+ "tie_encoder_decoder": false,
 
75
  "architectures": [
76
  "LLaDAModelLM"
77
  ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  "finetuning_task": null,
 
 
 
79
  "id2label": {
80
  "0": "LABEL_0",
81
  "1": "LABEL_1"
82
  },
 
 
 
 
 
 
 
 
 
83
  "label2id": {
84
  "LABEL_0": 0,
85
  "LABEL_1": 1
86
  },
87
+ "task_specific_params": null,
88
+ "problem_type": null,
89
+ "tokenizer_class": null,
90
+ "prefix": null,
91
+ "bos_token_id": null,
92
+ "pad_token_id": 76,
93
+ "eos_token_id": 76,
94
+ "sep_token_id": null,
95
+ "decoder_start_token_id": null,
96
  "max_length": 20,
 
97
  "min_length": 0,
98
+ "do_sample": false,
99
+ "early_stopping": false,
 
 
 
 
 
 
 
100
  "num_beams": 1,
101
+ "temperature": 1.0,
102
+ "top_k": 50,
103
+ "top_p": 1.0,
104
+ "typical_p": 1.0,
105
+ "repetition_penalty": 1.0,
106
+ "length_penalty": 1.0,
107
+ "no_repeat_ngram_size": 0,
108
+ "encoder_no_repeat_ngram_size": 0,
109
+ "bad_words_ids": null,
110
  "num_return_sequences": 1,
 
 
111
  "output_scores": false,
 
 
 
 
 
 
 
 
 
112
  "return_dict_in_generate": false,
113
+ "forced_bos_token_id": null,
114
+ "forced_eos_token_id": null,
115
+ "remove_invalid_values": false,
116
+ "exponential_decay_length_penalty": null,
117
+ "suppress_tokens": null,
118
+ "begin_suppress_tokens": null,
119
+ "num_beam_groups": 1,
120
+ "diversity_penalty": 0.0,
121
+ "_name_or_path": "Fraser/LLaDA-8B-Base-gg2m",
122
+ "transformers_version": "4.57.0",
123
+ "d_model": 4096,
124
+ "n_heads": 32,
125
+ "n_kv_heads": 32,
126
+ "n_layers": 32,
127
+ "mlp_ratio": 4,
128
+ "mlp_hidden_size": 12288,
129
+ "activation_type": "silu",
130
+ "block_type": "llama",
131
+ "block_group_size": 1,
132
+ "alibi": false,
133
+ "alibi_bias_max": 8.0,
134
  "rope": true,
135
  "rope_full_precision": true,
136
+ "flash_attention": false,
137
+ "attention_dropout": 0.0,
138
+ "multi_query_attention": null,
139
+ "attention_layer_norm": false,
140
+ "residual_dropout": 0.0,
141
+ "embedding_dropout": 0.0,
142
+ "input_emb_norm": false,
143
+ "layer_norm_type": "rms",
144
+ "layer_norm_with_affine": true,
145
+ "rms_norm_eps": 1e-05,
146
+ "attention_layer_norm_with_affine": true,
147
+ "max_sequence_length": 4096,
148
  "rope_theta": 500000.0,
149
+ "include_qkv_bias": false,
150
+ "include_bias": false,
151
+ "bias_for_layer_norm": false,
152
  "scale_logits": false,
153
+ "vocab_size": 85,
154
+ "embedding_size": 85,
155
+ "weight_tying": false,
156
+ "mask_token_id": 78,
157
+ "init_device": "meta",
158
+ "init_fn": "mitchell",
159
+ "init_std": 0.02,
160
+ "init_cutoff_factor": null,
161
+ "precision": "amp_bf16",
162
+ "auto_map": {
163
+ "AutoConfig": "configuration_llada.LLaDAConfig",
164
+ "AutoModelForCausalLM": "modeling_llada.LLaDAModelLM",
165
+ "AutoModel": "modeling_llada.LLaDAModelLM"
166
+ },
167
+ "model_type": "llada",
168
+ "use_cache": false,
169
  "tf_legacy_loss": false,
 
 
 
 
 
 
 
 
170
  "use_bfloat16": false,
171
+ "output_attentions": false
 
 
172
  },
173
+ "num_recursions": 4,
174
+ "normalization": "softmax",
175
+ "loss_weight": "linear",
176
+ "mask_token_id": 78,
177
+ "gradient_steps": null,
178
+ "schedule": "linear",
179
  "causal_strength": 1.0,
180
+ "temperature_max": 0.0,
 
181
  "entropy_target_max": 0.0,
182
+ "entropy_floor_max": 0.0,
183
+ "smear_sigma_max": 0.0,
184
+ "noise_std_max": 0.0,
185
+ "iteration_rope_dim_fraction": 0.0,
186
+ "use_recursion_checkpointing": true,
187
+ "soft_embedding_method": "softmax",
188
+ "soft_embedding_ema_step": 1.0,
189
  "flow_matching_enabled": false,
190
  "flow_matching_lambda": 0.5,
 
 
191
  "flow_matching_t_distribution": "logit_normal",
192
  "flow_matching_t_logit_mean": -0.4,
193
  "flow_matching_t_logit_std": 1.0,
 
194
  "flow_matching_t_min": 0.01,
195
+ "flow_matching_t_max": 0.99,
196
+ "flow_matching_noise_scale": 2.0,
197
+ "flow_matching_mask_scale": false,
 
 
 
 
 
 
 
198
  "self_distillation_enabled": false,
199
  "self_distillation_lambda": 0.5,
 
 
 
200
  "self_distillation_temperature_min": 1.5,
201
+ "self_distillation_temperature_max": 10.0,
202
+ "self_distillation_temperature_distribution": "log_uniform",
203
+ "self_distillation_teacher": "first",
204
+ "model_type": "recursive-mlm",
205
+ "output_attentions": false,
206
+ "auto_map": {
207
+ "AutoConfig": "configuration_recursive.RecursiveMLMConfig",
208
+ "AutoModel": "modeling_recursive.RecursiveMaskedLM"
209
+ }
210
+ }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8b879d9aaefbb88ca28a3babf863b2ab1ec1ef00bbf82ef7f3d7ddf7284dd968
3
  size 13960604928
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ebfe2c41adc724c0a164e5ea1efbaabc4c89a0da6c51e74e4d51502218219e0
3
  size 13960604928
modeling_llada.py CHANGED
@@ -1094,68 +1094,10 @@ class LLaDABlockGroup(nn.ModuleList):
1094
  block.set_activation_checkpointing(strategy)
1095
 
1096
 
1097
- class LLaDAPreTrainedModel(PreTrainedModel):
1098
- """
1099
- Minimal HF-compatible base to enable gradient checkpointing hooks and centralize
1100
- parameter initialization.
1101
- """
1102
-
1103
- config_class = LLaDAConfig
1104
- base_model_prefix = "model"
1105
- _no_split_modules = ["LLaDALlamaBlock"]
1106
- _supports_gradient_checkpointing = True # backward compat
1107
- supports_gradient_checkpointing = True # transformers >=4.38
1108
-
1109
- def __init__(self, config, *model_args, **model_kwargs):
1110
- hf_config = config
1111
- if not hasattr(hf_config, "to_dict"):
1112
- hf_config = LLaDAConfig(**config.__dict__)
1113
- super().__init__(hf_config, *model_args, **model_kwargs)
1114
-
1115
- def _init_weights(self, module):
1116
- if getattr(module, "_llada_params_initialized", False):
1117
- return
1118
- if hasattr(module, "reset_parameters"):
1119
- module.reset_parameters()
1120
- for child in module.modules():
1121
- setattr(child, "_llada_params_initialized", True)
1122
-
1123
- def _set_gradient_checkpointing(
1124
- self, enable: bool = True, gradient_checkpointing_func: Callable = None
1125
- ):
1126
- """
1127
- New-format hook expected by `PreTrainedModel.gradient_checkpointing_enable`.
1128
- Only LLaDAModel (the heavy transformer) actually toggles checkpointing.
1129
- """
1130
- from torch.utils.checkpoint import checkpoint
1131
-
1132
- if gradient_checkpointing_func is None:
1133
- gradient_checkpointing_func = checkpoint
1134
-
1135
- # When called on the HF wrapper (LLaDAModelLM), reach into the inner LLaDAModel.
1136
- target = self.model if isinstance(self, LLaDAModelLM) else self
1137
-
1138
- if isinstance(target, LLaDAModel):
1139
- target._gradient_checkpointing_func = gradient_checkpointing_func
1140
- target.gradient_checkpointing = enable
1141
- strategy = ActivationCheckpointingStrategy.whole_layer if enable else None
1142
- target.set_activation_checkpointing(strategy)
1143
- return
1144
-
1145
- # Fallback: walk modules to find the core model.
1146
- for module in self.modules():
1147
- if isinstance(module, LLaDAModel):
1148
- module._gradient_checkpointing_func = gradient_checkpointing_func
1149
- module.gradient_checkpointing = enable
1150
- strategy = ActivationCheckpointingStrategy.whole_layer if enable else None
1151
- module.set_activation_checkpointing(strategy)
1152
- break
1153
-
1154
-
1155
- class LLaDAModel(LLaDAPreTrainedModel):
1156
  def __init__(self, config: ModelConfig, init_params: bool = True):
1157
- super().__init__(config)
1158
- self.gradient_checkpointing: bool = False
1159
  self.__cache = BufferCache()
1160
 
1161
  # Validate config.
@@ -1224,7 +1166,7 @@ class LLaDAModel(LLaDAPreTrainedModel):
1224
  )
1225
  # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
1226
  if init_params and self.config.init_device != "meta":
1227
- self.post_init()
1228
  self.__num_fwd_flops: Optional[int] = None
1229
 
1230
  # Warm up cache.
@@ -1513,7 +1455,7 @@ def create_model_config_from_pretrained_config(config: LLaDAConfig):
1513
  return model_config
1514
 
1515
 
1516
- class LLaDAModelLM(LLaDAPreTrainedModel):
1517
  """
1518
  Extremely barebones HF model wrapper.
1519
  """
 
1094
  block.set_activation_checkpointing(strategy)
1095
 
1096
 
1097
+ class LLaDAModel(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1098
  def __init__(self, config: ModelConfig, init_params: bool = True):
1099
+ super().__init__()
1100
+ self.config = config
1101
  self.__cache = BufferCache()
1102
 
1103
  # Validate config.
 
1166
  )
1167
  # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
1168
  if init_params and self.config.init_device != "meta":
1169
+ self.reset_parameters()
1170
  self.__num_fwd_flops: Optional[int] = None
1171
 
1172
  # Warm up cache.
 
1455
  return model_config
1456
 
1457
 
1458
+ class LLaDAModelLM(PreTrainedModel):
1459
  """
1460
  Extremely barebones HF model wrapper.
1461
  """
modeling_recursive.py CHANGED
@@ -13,14 +13,6 @@ from transformers.utils import ModelOutput
13
 
14
  from .configuration_recursive import RecursiveMLMConfig
15
 
16
- # Register the custom LLaDA model so AutoConfig.for_model("llada") works
17
- # when constructing the base model from base_model_config.
18
- from .configuration_llada import LLaDAConfig
19
- from .modeling_llada import LLaDAModelLM
20
-
21
- AutoConfig.register("llada", LLaDAConfig)
22
- AutoModelForMaskedLM.register(LLaDAConfig, LLaDAModelLM)
23
-
24
 
25
  @dataclass
26
  class IterationMetrics(ModelOutput):
@@ -75,8 +67,15 @@ class RecursiveMaskedLM(PreTrainedModel):
75
  # to avoid reinitializing the pre-trained weights via _init_weights()
76
  self.mlm = base_model
77
  elif config.base_model_config is not None:
78
- base_config = AutoConfig.for_model(**config.base_model_config)
79
- self.mlm = AutoModelForMaskedLM.from_config(base_config)
 
 
 
 
 
 
 
80
  # Only call post_init() for freshly created models (needs weight init)
81
  self.post_init()
82
  else:
 
13
 
14
  from .configuration_recursive import RecursiveMLMConfig
15
 
 
 
 
 
 
 
 
 
16
 
17
  @dataclass
18
  class IterationMetrics(ModelOutput):
 
67
  # to avoid reinitializing the pre-trained weights via _init_weights()
68
  self.mlm = base_model
69
  elif config.base_model_config is not None:
70
+ model_type = config.base_model_config.get("model_type", "")
71
+ if model_type == "llada":
72
+ from .configuration_llada import LLaDAConfig
73
+ from .modeling_llada import LLaDAModelLM
74
+ base_config = LLaDAConfig.from_dict(config.base_model_config)
75
+ self.mlm = LLaDAModelLM(base_config)
76
+ else:
77
+ base_config = AutoConfig.for_model(**config.base_model_config)
78
+ self.mlm = AutoModelForMaskedLM.from_config(base_config)
79
  # Only call post_init() for freshly created models (needs weight init)
80
  self.post_init()
81
  else: