working saving & loading

Changed files:
- ag_news.sh +1 -0
- ag_news_load.sh +20 -0
- model/config.py +4 -5
- tests/test_t5_vae.py +15 -42
- train.py +9 -8
ag_news.sh
CHANGED

@@ -17,3 +17,4 @@ export RUN_NAME=test
     --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
     --overwrite_output_dir \
     --num_train_epochs="20" \
+    --push_to_hub \
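The only change here is the new --push_to_hub flag. Assuming train.py's training arguments mirror transformers' TrainingArguments (as the HF Flax example scripts do), the flag makes the trainer mirror everything saved to output_dir to a Hugging Face Hub repository. A minimal sketch in API form:

from transformers import TrainingArguments

# Sketch only: train.py may define its own TrainingArguments dataclass,
# but the transformers version exposes the flag under this exact name.
args = TrainingArguments(output_dir="output/test", push_to_hub=True)
print(args.push_to_hub)  # True: checkpoints written to output_dir get pushed to the Hub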
ag_news_load.sh
ADDED

@@ -0,0 +1,20 @@
+export RUN_NAME=test
+
+./venv/bin/python train.py \
+    --model_name_or_path="output/test" \
+    --t5_model_name_or_path="t5-base" \
+    --output_dir="output/from_save/${RUN_NAME}" \
+    --overwrite_output_dir \
+    --dataset_name="ag_news" \
+    --do_train --do_eval \
+    --n_latent_tokens 6 \
+    --latent_token_size 32 \
+    --save_steps="2500" \
+    --eval_steps="2500" \
+    --block_size="32" \
+    --per_device_train_batch_size="1" \
+    --per_device_eval_batch_size="1" \
+    --learning_rate="5e-3" --warmup_steps="1000" \
+    --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
+    --overwrite_output_dir \
+    --num_train_epochs="20" \
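This new script exercises the load path: --model_name_or_path="output/test" points at a previously saved checkpoint, so training resumes from the serialized T5-VAE weights instead of initializing fresh from t5-base. In API form this is roughly the following sketch (the import path is hypothetical; the class name matches the one used in tests/test_t5_vae.py below):

# Hypothetical module path; the tests use FlaxT5VaeForAutoencoding, which is
# presumably what train.py resolves for --model_name_or_path checkpoints.
from model.t5_vae import FlaxT5VaeForAutoencoding

model = FlaxT5VaeForAutoencoding.from_pretrained("output/test")  # weights from a prior run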
model/config.py
CHANGED

@@ -46,6 +46,7 @@ class T5VaeConfig(PretrainedConfig):
         cache_dir=None,
         tie_word_embeddings=True,
         # T5 config
+        t5=dict(),
         vocab_size=32128,
         d_model=512,
         d_kv=64,

@@ -86,12 +87,10 @@
         if t5_model_name_or_path:
             self.t5 = AutoConfig.from_pretrained(t5_model_name_or_path, cache_dir=cache_dir)
             assertEqual(self.t5.model_type, "t5", "Need t5 model type for transformer_decoder.")
-            if num_layers:
-                self.t5.num_layers = num_layers
-            if num_heads:
-                self.t5.num_heads = num_heads
             self.t5.decoder_start_token_id = decoder_start_token_id
-
+        elif t5:
+            # use for loading a config
+            self.t5 = T5Config(**t5)
         else:
             self.t5 = T5Config(
                 vocab_size=vocab_size,
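The new t5=dict() parameter is what makes reloading work: save_pretrained serializes the nested T5Config as a plain dict inside config.json, so when from_pretrained re-invokes __init__, the new elif t5: branch rebuilds a T5Config from that dict instead of falling through to the default constructor. A minimal round-trip sketch, assuming T5VaeConfig is importable from model/config.py:

from model.config import T5VaeConfig  # assumed import path for this repo

config = T5VaeConfig(t5_model_name_or_path="t5-base")
config.save_pretrained("output/test")                  # config.json now contains "t5": {...}
reloaded = T5VaeConfig.from_pretrained("output/test")  # re-enters __init__ with t5={...}
assert reloaded.t5.d_model == config.t5.d_model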
tests/test_t5_vae.py
CHANGED

@@ -294,51 +294,21 @@ class FlaxT5VaeModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
         for jitted_output, output in zip(jitted_outputs, outputs):
             self.assertEqual(jitted_output.shape, output.shape)
 
-
-    def test_save_load_from_base(self):
+    def test_save_and_load(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        base_class = FLAX_MODEL_MAPPING[config.__class__]
 
-        for model_class in self.all_model_classes:
-            if model_class == base_class:
-                continue
-
-            model = base_class(config)
-            base_params = flatten_dict(unfreeze(model.params))
-
-            # check that all base model weights are loaded correctly
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                head_model = model_class.from_pretrained(tmpdirname)
-
-                base_param_from_head = flatten_dict(unfreeze(head_model.params))
+        model = FlaxT5VaeForAutoencoding(config)
+        model_params = flatten_dict(unfreeze(model.params))
 
-                for key in base_param_from_head.keys():
-                    max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
-                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+        # check that all base model weights are loaded correctly
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+            head_model = FlaxT5VaeForAutoencoding.from_pretrained(tmpdirname)
+            new_params = flatten_dict(unfreeze(head_model.params))
 
-
-    def test_save_load_to_base(self):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        base_class = FLAX_MODEL_MAPPING[config.__class__]
-
-        for model_class in self.all_model_classes:
-            if model_class == base_class:
-                continue
-
-            model = model_class(config)
-            base_params_from_head = flatten_dict(unfreeze(model.params))
-
-            # check that all base model weights are loaded correctly
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                base_model = base_class.from_pretrained(tmpdirname)
-
-                base_params = flatten_dict(unfreeze(base_model.params))
-
-                for key in base_params_from_head.keys():
-                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
-                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+            for key in new_params.keys():
+                max_diff = (model_params[key] - new_params[key]).sum().item()
+                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
 
 
 ## Copied training methods
@@ -354,8 +324,8 @@ def compute_mmd(x, y):
     x_kernel = compute_kernel(x, x)
     y_kernel = compute_kernel(y, y)
     xy_kernel = compute_kernel(x, y)
-
     return jnp.mean(x_kernel) + jnp.mean(y_kernel) - 2 * jnp.mean(xy_kernel)
+
 def regulariser_loss(latent_codes, rng):
     true_samples = jax.random.normal(rng, latent_codes.shape)
     return compute_mmd(true_samples, latent_codes)
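For context, compute_kernel and compute_mmd in the hunk above implement a maximum-mean-discrepancy penalty between latent codes and samples from a standard normal. A self-contained sketch, assuming the common InfoVAE-style RBF kernel (the actual compute_kernel body is not shown in this diff):

import jax
import jax.numpy as jnp

def compute_kernel(x, y):
    # Pairwise RBF kernel between rows of x [n, d] and y [m, d].
    dim = x.shape[1]
    diff = x[:, None, :] - y[None, :, :]                 # [n, m, d]
    return jnp.exp(-jnp.mean(diff ** 2, axis=-1) / dim)  # [n, m]

def compute_mmd(x, y):
    # MMD estimate: E[k(x,x)] + E[k(y,y)] - 2 E[k(x,y)], reduced to a scalar.
    return jnp.mean(compute_kernel(x, x)) + jnp.mean(compute_kernel(y, y)) - 2 * jnp.mean(compute_kernel(x, y))

def regulariser_loss(latent_codes, rng):
    # Pull latent codes toward N(0, I) by penalizing their MMD to true samples.
    true_samples = jax.random.normal(rng, latent_codes.shape)
    return compute_mmd(true_samples, latent_codes)

latents = jax.random.normal(jax.random.PRNGKey(0), (8, 32))  # [batch, latent_size]
print(regulariser_loss(latents, jax.random.PRNGKey(1)))      # scalar

Because compute_mmd already averages over the batch dimension, it is applied to the whole [batch, latent] matrix at once; that is also why train.py's regulariser_loss (further down) drops its jax.vmap wrapper.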
@@ -403,3 +373,6 @@ class FlaxT5VaeModelIntegrationTests(unittest.TestCase):
         outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
         logits, latent_codes = outputs[0], outputs[1]
         loss = loss_fn(logits, labels, latent_codes, jax.random.PRNGKey(42))
+        import pdb
+        pdb.set_trace()
+        pass
train.py
CHANGED

@@ -156,6 +156,9 @@ class DataTrainingArguments:
             "Default to the model max input length for single sentence inputs (take into account special tokens)."
         },
     )
+    streaming: bool = field(
+        default=False, metadata={"help": "Stream the dataset."}
+    )
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )

@@ -293,7 +296,7 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         dataset = load_dataset(
-            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
+            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, streaming=data_args.streaming, keep_in_memory=False
         )
 
     if "validation" not in dataset.keys():
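The flag is forwarded directly to datasets.load_dataset. With streaming=True the call returns an iterable dataset that yields examples lazily over the network instead of downloading and caching the whole corpus first:

from datasets import load_dataset

# Streaming returns an IterableDataset; examples arrive on demand.
stream = load_dataset("ag_news", split="train", streaming=True)
print(next(iter(stream)))  # first example, without a full download

Note that map-style conveniences such as len(dataset) are unavailable on streamed datasets, so later code paths may still require the default non-streaming mode.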
@@ -344,10 +347,6 @@ def main():
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
         )
-    elif model_args.model_name_or_path:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-        )
     elif model_args.t5_model_name_or_path:
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.t5_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer

@@ -363,6 +362,7 @@ def main():
             model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
         )
         # TODO assert token embedding size == len(tokenizer)
+        assert model.params['t5']['shared'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size."
     else:
         vocab_size = len(tokenizer)
         config.t5.vocab_size = vocab_size
@@ -563,7 +563,8 @@ def main():
 
     def regulariser_loss(latent_codes, rng):
         true_samples = jax.random.normal(rng, latent_codes.shape)
-        return jax.vmap(compute_mmd)(true_samples, latent_codes)
+        # return jax.vmap(compute_mmd)(true_samples, latent_codes)
+        return compute_mmd(true_samples, latent_codes)
 
     def loss_fn(logits, labels, latent_codes, regulariser_rng):
         shift_logits = logits[..., :-1, :]
@@ -594,7 +595,7 @@ def main():
         return new_state, metrics
 
     # Define eval fn
-    def eval_step(params, batch):
+    def eval_step(params, rng, batch):
         labels = batch.pop("labels")
         logits, latent_codes = model(**batch, params=params, train=False)[:2]
         loss = loss_fn(logits, labels, latent_codes, rng)

@@ -660,7 +661,7 @@ def main():
     for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
         # Model forward
         batch = next(eval_loader)
-        metrics = p_eval_step(state.params, batch)
+        metrics = p_eval_step(state.params, state.dropout_rng, batch)
         eval_metrics.append(metrics)
 
     # normalize eval metrics
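Threading rng through eval_step, rather than closing over a host-side value, matters once the step is pmapped: loss_fn draws the MMD reference samples from that key, so each device needs its own. A sketch of the likely pairing, assuming state.dropout_rng holds per-device keys the way the HF Flax examples set them up:

import jax

# eval_step now takes (params, rng, batch); the rng rides along the device axis.
p_eval_step = jax.pmap(eval_step, axis_name="batch")

# Assumed setup elsewhere in train.py, mirroring the HF Flax examples:
#   dropout_rngs = jax.random.split(rng, jax.local_device_count())
# so state.dropout_rng has shape [num_devices, 2] and pmap hands one key per device.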
|