fix(seq2seq): memory issue
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -100,12 +100,6 @@ class ModelArguments:
             "help": "Pretrained config name or path if not the same as model_name"
         },
     )
-    tokenizer_name: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Pretrained tokenizer name or path if not the same as model_name"
-        },
-    )
     cache_dir: Optional[str] = field(
         default=None,
         metadata={
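The removed `tokenizer_name` field belonged to a dataclass that transformers' `HfArgumentParser` turns into command-line flags; without it, the tokenizer is presumably resolved from `model_name_or_path` instead. A minimal sketch of that pattern, with an illustrative (not complete) field set:

```python
# Minimal sketch of the dataclass -> CLI-flags pattern these fields come from.
# The field set here is illustrative, not the script's full ModelArguments.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained model name or path"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to cache downloaded weights"},
    )


parser = HfArgumentParser(ModelArguments)
# e.g. `python run_seq2seq_flax.py --model_name_or_path facebook/bart-large`
(model_args,) = parser.parse_args_into_dataclasses()
```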
@@ -539,11 +533,6 @@ def main():
         )

     else:
-        base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
-            model_args.model_name_or_path,
-            seed=training_args.seed,
-            dtype=getattr(jnp, model_args.dtype),
-        )
         # Set up our new model config
         config = BartConfig.from_pretrained(model_args.model_name_or_path)
         config.tie_word_embeddings = False
@@ -568,11 +557,6 @@ def main():
             config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
         )

-        # Use pre-trained weights for encoder
-        model.params["model"]["encoder"] = base_model.params["model"]["encoder"]
-        model.params["model"]["shared"] = base_model.params["model"]["shared"]
-        del base_model
-
     # Load tokenizer if it has not been set
     if tokenizer is None:
         tokenizer = AutoTokenizer.from_pretrained(
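The two hunks above are the core of the memory fix: the script no longer instantiates a second, fully pretrained seq2seq model just to copy its encoder and shared-embedding parameters into the freshly initialized one, so only one complete parameter tree has to live in host memory at a time. A rough sketch of the cost the removed code implied, with toy shapes rather than real BART dimensions:

```python
# Sketch: until `del base_model` ran, two full parameter trees coexisted,
# so peak host memory was roughly their sum. Shapes below are toy values.
import jax
import jax.numpy as jnp


def tree_bytes(params) -> int:
    """Total bytes across every array leaf of a parameter pytree."""
    return sum(leaf.size * leaf.dtype.itemsize for leaf in jax.tree_util.tree_leaves(params))


base_params = {"encoder": jnp.zeros((1024, 4096)), "shared": jnp.zeros((50264, 1024))}
new_params = {"encoder": jnp.zeros((1024, 4096)), "shared": jnp.zeros((50264, 1024))}

# While both trees are alive, the process pays for both at once.
peak = tree_bytes(base_params) + tree_bytes(new_params)
print(f"peak ~{peak / 1e6:.0f} MB for these toy shapes")
```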
@@ -960,12 +944,12 @@ def main():
             artifact.add_file(
                 str(Path(training_args.output_dir) / "training_state.json")
             )
-            wandb.run.log_artifact(artifact)
-
             # save some space
             c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
             c.cleanup(wandb.util.from_human_size("5GB"))

+            wandb.run.log_artifact(artifact)
+
         # save to the hub
         if training_args.push_to_hub:
             model.save_pretrained(
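The final hunk reorders the Weights & Biases bookkeeping: the artifacts cache is now trimmed before `wandb.run.log_artifact` is called rather than after, which should keep stale cache entries from competing with the new checkpoint upload for disk space. A self-contained sketch of the resulting pattern; project and file names are made up, while the cache-cleanup calls mirror the ones in the diff (which target the wandb SDK of that era):

```python
# Sketch of the reordered checkpoint-logging pattern: trim the local
# artifacts cache first, then upload. Names here are illustrative.
from pathlib import Path

import wandb

run = wandb.init(project="seq2seq-demo")  # hypothetical project name

artifact = wandb.Artifact("training-state", type="checkpoint")
artifact.add_file(str(Path("output") / "training_state.json"))

# Keep the local artifacts cache under ~5GB before uploading,
# matching the order the commit moves to.
cache = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
cache.cleanup(wandb.util.from_human_size("5GB"))

run.log_artifact(artifact)
```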