Spaces:
Runtime error
Runtime error
Pedro Cuenca
committed on
Commit
·
9f522b8
1
Parent(s):
290e443
Accept changes suggested by linter.
Browse files
- src/dalle_mini/model/modeling.py +7 -3
- tools/train/train.py +3 -1
src/dalle_mini/model/modeling.py
CHANGED
|
@@ -569,14 +569,18 @@ class DalleBart(FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration):
|
|
| 569 |
"""
|
| 570 |
Initializes from a wandb artifact, or delegates loading to the superclass.
|
| 571 |
"""
|
| 572 |
-
if
|
|
|
|
|
|
|
| 573 |
# wandb artifact
|
| 574 |
artifact = wandb.Api().artifact(pretrained_model_name_or_path)
|
| 575 |
-
|
| 576 |
# we download everything, including opt_state, so we can resume training if needed
|
| 577 |
# see also: #120
|
| 578 |
pretrained_model_name_or_path = artifact.download()
|
| 579 |
|
| 580 |
-
model = super(DalleBart, cls).from_pretrained(
|
|
|
|
|
|
|
| 581 |
model.config.resolved_name_or_path = pretrained_model_name_or_path
|
| 582 |
return model
|
|
|
|
| 569 |
"""
|
| 570 |
Initializes from a wandb artifact, or delegates loading to the superclass.
|
| 571 |
"""
|
| 572 |
+
if ":" in pretrained_model_name_or_path and not os.path.isdir(
|
| 573 |
+
pretrained_model_name_or_path
|
| 574 |
+
):
|
| 575 |
# wandb artifact
|
| 576 |
artifact = wandb.Api().artifact(pretrained_model_name_or_path)
|
| 577 |
+
|
| 578 |
# we download everything, including opt_state, so we can resume training if needed
|
| 579 |
# see also: #120
|
| 580 |
pretrained_model_name_or_path = artifact.download()
|
| 581 |
|
| 582 |
+
model = super(DalleBart, cls).from_pretrained(
|
| 583 |
+
pretrained_model_name_or_path, *model_args, **kwargs
|
| 584 |
+
)
|
| 585 |
model.config.resolved_name_or_path = pretrained_model_name_or_path
|
| 586 |
return model
|
tools/train/train.py
CHANGED
|
@@ -437,7 +437,9 @@ def main():
|
|
| 437 |
if training_args.resume_from_checkpoint is not None:
|
| 438 |
# load model
|
| 439 |
model = DalleBart.from_pretrained(
|
| 440 |
-
training_args.resume_from_checkpoint,
|
|
|
|
|
|
|
| 441 |
)
|
| 442 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
| 443 |
print(model.params)
|
|
|
|
| 437 |
if training_args.resume_from_checkpoint is not None:
|
| 438 |
# load model
|
| 439 |
model = DalleBart.from_pretrained(
|
| 440 |
+
training_args.resume_from_checkpoint,
|
| 441 |
+
dtype=getattr(jnp, model_args.dtype),
|
| 442 |
+
abstract_init=True,
|
| 443 |
)
|
| 444 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
| 445 |
print(model.params)
|