Update modeling_custom_seq2seq_llm.py
modeling_custom_seq2seq_llm.py (+27 -27)
@@ -1228,33 +1228,33 @@ class CustomSeq2SeqLLM(PreTrainedModel):
         torch_filepath = os.path.join(save_directory, "pytorch_model.bin")
         torch.save(cpu_state_dict, torch_filepath)

-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        config = kwargs.pop("config", None)
-        state_dict = kwargs.pop("state_dict", None)
-
-        if config is None:
-            config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
-
-        model = cls(config)
-
-        if state_dict is None:
-            # Try loading safetensors first
-            safe_filepath = os.path.join(pretrained_model_name_or_path, "model.safetensors")
-            if os.path.exists(safe_filepath):
-                from safetensors.torch import load_file
-                state_dict = load_file(safe_filepath)
-            else:
-                # Fall back to PyTorch format
-                torch_filepath = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
-                state_dict = torch.load(torch_filepath, map_location="cpu")
-
-        # Handle shared weights
-        if config.tie_word_embeddings and "lm_head.weight" not in state_dict:
-            state_dict["lm_head.weight"] = state_dict["shared.weight"]
-
-        model.load_state_dict(state_dict)
-        return model
+    # @classmethod
+    # def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+    #     config = kwargs.pop("config", None)
+    #     state_dict = kwargs.pop("state_dict", None)
+
+    #     if config is None:
+    #         config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+    #     model = cls(config)
+
+    #     if state_dict is None:
+    #         # Try loading safetensors first
+    #         safe_filepath = os.path.join(pretrained_model_name_or_path, "model.safetensors")
+    #         if os.path.exists(safe_filepath):
+    #             from safetensors.torch import load_file
+    #             state_dict = load_file(safe_filepath)
+    #         else:
+    #             # Fall back to PyTorch format
+    #             torch_filepath = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
+    #             state_dict = torch.load(torch_filepath, map_location="cpu")
+
+    #     # Handle shared weights
+    #     if config.tie_word_embeddings and "lm_head.weight" not in state_dict:
+    #         state_dict["lm_head.weight"] = state_dict["shared.weight"]
+
+    #     model.load_state_dict(state_dict)
+    #     return model

 class CustomEncoder(nn.Module):
     def __init__(self, config):
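
Note: with the override commented out, CustomSeq2SeqLLM.from_pretrained now resolves to the inherited transformers.PreTrainedModel.from_pretrained. A minimal round-trip sketch under that assumption; the checkpoint directory name is hypothetical, and the directory must contain a config.json alongside the pytorch_model.bin written by save_pretrained above:

# Sketch, not this repo's own test code: assumes the base-class loader
# and a hypothetical local checkpoint directory.
from modeling_custom_seq2seq_llm import CustomSeq2SeqLLM

model = CustomSeq2SeqLLM.from_pretrained("./checkpoints/custom-seq2seq")
model.eval()

One thing to re-verify after this change: the removed loader aliased lm_head.weight to shared.weight when config.tie_word_embeddings was set. The base loader only re-ties those weights through model.tie_weights(), which works if the class implements get_input_embeddings/get_output_embeddings, so checkpoints saved without lm_head.weight deserve a quick load test.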
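
If the manual path is still needed somewhere (tests, conversion scripts), the commented-out logic can also live as a standalone helper. A sketch with a hypothetical function name, lifted from the block above:

import os

import torch


def load_custom_state_dict(model_dir, tie_word_embeddings=False):
    # Prefer safetensors, fall back to the PyTorch pickle format,
    # mirroring the commented-out from_pretrained above.
    safe_filepath = os.path.join(model_dir, "model.safetensors")
    if os.path.exists(safe_filepath):
        from safetensors.torch import load_file
        state_dict = load_file(safe_filepath)
    else:
        torch_filepath = os.path.join(model_dir, "pytorch_model.bin")
        state_dict = torch.load(torch_filepath, map_location="cpu")

    # Re-alias the tied output projection, as the removed code did.
    if tie_word_embeddings and "lm_head.weight" not in state_dict:
        state_dict["lm_head.weight"] = state_dict["shared.weight"]

    return state_dict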