Spaces:
Running
on
Zero
Running
on
Zero
disable reshard after forward (#56)
Browse files
Co-authored-by: Srini Iyer <sviyer@meta.com>
bytelatent/transformer.py
CHANGED
|
@@ -146,16 +146,16 @@ def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
|
|
| 146 |
group_plan.append(("output", True))
|
| 147 |
else:
|
| 148 |
for i in range(model_args.n_layers_local_encoder):
|
| 149 |
-
group_plan.append((f"local_encoder.layers.{i}", True))
|
| 150 |
-
group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
|
| 151 |
for i in range(model_args.n_layers_local_decoder):
|
| 152 |
-
group_plan.append((f"local_decoder.layers.{i}", True))
|
| 153 |
-
group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
|
| 154 |
for i in range(model_args.n_layers_global):
|
| 155 |
-
group_plan.append((f"global_transformer.layers.{i}", True))
|
| 156 |
|
| 157 |
for i in range(len(model_args.encoder_hash_byte_group_size)):
|
| 158 |
-
group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
|
| 159 |
|
| 160 |
return group_plan
|
| 161 |
|
|
|
|
| 146 |
group_plan.append(("output", True))
|
| 147 |
else:
|
| 148 |
for i in range(model_args.n_layers_local_encoder):
|
| 149 |
+
group_plan.append((f"local_encoder.layers.{i}", False))
|
| 150 |
+
group_plan.append((f"local_encoder.cross_attn_layers.{i}", False))
|
| 151 |
for i in range(model_args.n_layers_local_decoder):
|
| 152 |
+
group_plan.append((f"local_decoder.layers.{i}", False))
|
| 153 |
+
group_plan.append((f"local_decoder.cross_attn_layers.{i}", False))
|
| 154 |
for i in range(model_args.n_layers_global):
|
| 155 |
+
group_plan.append((f"global_transformer.layers.{i}", False))
|
| 156 |
|
| 157 |
for i in range(len(model_args.encoder_hash_byte_group_size)):
|
| 158 |
+
group_plan.append((f"encoder_hash_tok_embedding.{i}", False))
|
| 159 |
|
| 160 |
return group_plan
|
| 161 |
|