fix(model): apply gate fp32 only for mixtral (#1241)
Browse files* fix(model): apply gate fp32 only for mixtral
* Update src/axolotl/utils/models.py
* fix gate layer check
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
src/axolotl/utils/models.py
CHANGED
|
@@ -676,7 +676,7 @@ def load_model(
|
|
| 676 |
if not cfg.fsdp:
|
| 677 |
# FSDP doesn't like mixed Float and BFloat16
|
| 678 |
for name, module in model.named_modules():
|
| 679 |
-
if
|
| 680 |
module.to(torch.float32)
|
| 681 |
if model_config.model_type == "btlm":
|
| 682 |
# don't upcast lm_head for btlm
|
|
|
|
| 676 |
if not cfg.fsdp:
|
| 677 |
# FSDP doesn't like mixed Float and BFloat16
|
| 678 |
for name, module in model.named_modules():
|
| 679 |
+
if "norm" in name or name.endswith(".gate"):
|
| 680 |
module.to(torch.float32)
|
| 681 |
if model_config.model_type == "btlm":
|
| 682 |
# don't upcast lm_head for btlm
|