Update modeling.py
Browse files- modeling.py +10 -10
modeling.py
CHANGED
|
@@ -110,16 +110,16 @@ class MARModel(PreTrainedModel):
|
|
| 110 |
# call the sample_tokens method from the MAR class
|
| 111 |
return self.model.sample_tokens(bsz, num_iter, cfg, cfg_schedule, labels, temperature, progress)
|
| 112 |
|
| 113 |
-
@classmethod
|
| 114 |
-
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
|
| 124 |
|
| 125 |
def save_pretrained(self, save_directory):
|
|
|
|
| 110 |
# call the sample_tokens method from the MAR class
|
| 111 |
return self.model.sample_tokens(bsz, num_iter, cfg, cfg_schedule, labels, temperature, progress)
|
| 112 |
|
| 113 |
+
# @classmethod
|
| 114 |
+
# def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 115 |
+
# config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
| 116 |
+
# model = cls(config)
|
| 117 |
+
# safetensors_path = os.path.join(pretrained_model_name_or_path, "checkpoint-last.safetensors")
|
| 118 |
+
# if not os.path.exists(safetensors_path):
|
| 119 |
+
# raise FileNotFoundError(f"safetensors file not found at {safetensors_path}")
|
| 120 |
+
# state_dict = torch.load(safetensors_path, map_location='cpu')
|
| 121 |
+
# model.model.load_state_dict(state_dict)
|
| 122 |
+
# return model
|
| 123 |
|
| 124 |
|
| 125 |
def save_pretrained(self, save_directory):
|