Spaces:
Runtime error
Runtime error
fix(train): update model name
Browse files
- tools/train/train.py +3 -4
tools/train/train.py
CHANGED
|
@@ -41,10 +41,9 @@ from flax.training import train_state
|
|
| 41 |
from flax.training.common_utils import get_metrics, onehot, shard_prng_key
|
| 42 |
from tqdm import tqdm
|
| 43 |
from transformers import AutoTokenizer, HfArgumentParser
|
| 44 |
-
from transformers.models.bart.modeling_flax_bart import BartConfig
|
| 45 |
|
| 46 |
from dalle_mini.data import Dataset
|
| 47 |
-
from dalle_mini.model import DalleBartConfig, <old model class name; lost in extraction>
|
| 48 |
|
| 49 |
logger = logging.getLogger(__name__)
|
| 50 |
|
|
@@ -418,7 +417,7 @@ def main():
|
|
| 418 |
|
| 419 |
# Load or create new model
|
| 420 |
if model_args.model_name_or_path:
|
| 421 |
-
model = <old model class name; lost in extraction>.from_pretrained(
|
| 422 |
model_args.model_name_or_path,
|
| 423 |
config=config,
|
| 424 |
seed=training_args.seed_model,
|
|
@@ -427,7 +426,7 @@ def main():
|
|
| 427 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
| 428 |
print(model.params)
|
| 429 |
else:
|
| 430 |
-
model = <old model class name; lost in extraction>(
|
| 431 |
config,
|
| 432 |
seed=training_args.seed_model,
|
| 433 |
dtype=getattr(jnp, model_args.dtype),
|
|
|
|
| 41 |
from flax.training.common_utils import get_metrics, onehot, shard_prng_key
|
| 42 |
from tqdm import tqdm
|
| 43 |
from transformers import AutoTokenizer, HfArgumentParser
|
|
|
|
| 44 |
|
| 45 |
from dalle_mini.data import Dataset
|
| 46 |
+
from dalle_mini.model import DalleBartConfig, DalleBart
|
| 47 |
|
| 48 |
logger = logging.getLogger(__name__)
|
| 49 |
|
|
|
|
| 417 |
|
| 418 |
# Load or create new model
|
| 419 |
if model_args.model_name_or_path:
|
| 420 |
+
model = DalleBart.from_pretrained(
|
| 421 |
model_args.model_name_or_path,
|
| 422 |
config=config,
|
| 423 |
seed=training_args.seed_model,
|
|
|
|
| 426 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
| 427 |
print(model.params)
|
| 428 |
else:
|
| 429 |
+
model = DalleBart(
|
| 430 |
config,
|
| 431 |
seed=training_args.seed_model,
|
| 432 |
dtype=getattr(jnp, model_args.dtype),
|