Spaces:
Runtime error
Runtime error
feat: don't ignore mismatched
Browse files
dev/seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -531,7 +531,6 @@ def main():
|
|
| 531 |
config=config,
|
| 532 |
seed=training_args.seed_model,
|
| 533 |
dtype=getattr(jnp, model_args.dtype),
|
| 534 |
-
ignore_mismatched_sizes=True,
|
| 535 |
)
|
| 536 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
| 537 |
print(model.params)
|
|
|
|
| 531 |
config=config,
|
| 532 |
seed=training_args.seed_model,
|
| 533 |
dtype=getattr(jnp, model_args.dtype),
|
|
|
|
| 534 |
)
|
| 535 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
| 536 |
print(model.params)
|