Update modeling.py
Browse files- modeling.py +5 -1
modeling.py
CHANGED
|
@@ -114,10 +114,14 @@ class MARModel(PreTrainedModel):
|
|
| 114 |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 115 |
config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
| 116 |
model = cls(config)
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
| 118 |
model.model.load_state_dict(state_dict)
|
| 119 |
return model
|
| 120 |
|
|
|
|
| 121 |
def save_pretrained(self, save_directory):
|
| 122 |
# we will save to safetensors
|
| 123 |
os.makedirs(save_directory, exist_ok=True)
|
|
|
|
| 114 |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 115 |
config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
| 116 |
model = cls(config)
|
| 117 |
+
safetensors_path = os.path.join(pretrained_model_name_or_path, "checkpoint-last.safetensors")
|
| 118 |
+
if not os.path.exists(safetensors_path):
|
| 119 |
+
raise FileNotFoundError(f"safetensors file not found at {safetensors_path}")
|
| 120 |
+
state_dict = torch.load(safetensors_path, map_location='cpu')
|
| 121 |
model.model.load_state_dict(state_dict)
|
| 122 |
return model
|
| 123 |
|
| 124 |
+
|
| 125 |
def save_pretrained(self, save_directory):
|
| 126 |
# we will save to safetensors
|
| 127 |
os.makedirs(save_directory, exist_ok=True)
|