mup base shapes path
Browse files
- modeling_nt_bert.py (+4 -2)
modeling_nt_bert.py
CHANGED
|
@@ -78,11 +78,13 @@ class BertPreTrainedModel(PreTrainedModel):
|
|
| 78 |
|
| 79 |
# since we used MuP, need to reset values since they're not saved with the model
|
| 80 |
if os.path.exists("base_shapes.bsh") is False:
|
| 81 |
-
hf_hub_download(
|
| 82 |
"zpn/human_bp_bert", "base_shapes.bsh"
|
| 83 |
)
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
set_base_shapes(model,
|
| 86 |
|
| 87 |
return model
|
| 88 |
|
|
|
|
| 78 |
|
| 79 |
# since we used MuP, need to reset values since they're not saved with the model
|
| 80 |
if os.path.exists("base_shapes.bsh") is False:
|
| 81 |
+
path = hf_hub_download(
|
| 82 |
"zpn/human_bp_bert", "base_shapes.bsh"
|
| 83 |
)
|
| 84 |
+
else:
|
| 85 |
+
path = "base_shapes.bsh"
|
| 86 |
|
| 87 |
+
set_base_shapes(model, path, rescale_params=False)
|
| 88 |
|
| 89 |
return model
|
| 90 |
|