Fraser committed on
Commit
7633929
·
1 Parent(s): 0c4e401

working saving & loading

Browse files
Files changed (5) hide show
  1. ag_news.sh +1 -0
  2. ag_news_load.sh +20 -0
  3. model/config.py +4 -5
  4. tests/test_t5_vae.py +15 -42
  5. train.py +9 -8
ag_news.sh CHANGED
@@ -17,3 +17,4 @@ export RUN_NAME=test
17
  --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
18
  --overwrite_output_dir \
19
  --num_train_epochs="20" \
 
 
17
  --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
18
  --overwrite_output_dir \
19
  --num_train_epochs="20" \
20
+ --push_to_hub \
ag_news_load.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export RUN_NAME=test
2
+
3
+ ./venv/bin/python train.py \
4
+ --model_name_or_path="output/test" \
5
+ --t5_model_name_or_path="t5-base" \
6
+ --output_dir="output/from_save/${RUN_NAME}" \
7
+ --overwrite_output_dir \
8
+ --dataset_name="ag_news" \
9
+ --do_train --do_eval \
10
+ --n_latent_tokens 6 \
11
+ --latent_token_size 32 \
12
+ --save_steps="2500" \
13
+ --eval_steps="2500" \
14
+ --block_size="32" \
15
+ --per_device_train_batch_size="1" \
16
+ --per_device_eval_batch_size="1" \
17
+ --learning_rate="5e-3" --warmup_steps="1000" \
18
+ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
19
+ --overwrite_output_dir \
20
+ --num_train_epochs="20" \
model/config.py CHANGED
@@ -46,6 +46,7 @@ class T5VaeConfig(PretrainedConfig):
46
  cache_dir=None,
47
  tie_word_embeddings=True,
48
  # T5 config
 
49
  vocab_size=32128,
50
  d_model=512,
51
  d_kv=64,
@@ -86,12 +87,10 @@ class T5VaeConfig(PretrainedConfig):
86
  if t5_model_name_or_path:
87
  self.t5 = AutoConfig.from_pretrained(t5_model_name_or_path, cache_dir=cache_dir)
88
  assertEqual(self.t5.model_type, "t5", "Need t5 model type for transformer_decoder.")
89
- if num_layers:
90
- self.t5.num_layers = num_layers
91
- if num_heads:
92
- self.t5.num_heads = num_heads
93
  self.t5.decoder_start_token_id = decoder_start_token_id
94
- self.t5.n_positions = self.set_seq_size
 
 
95
  else:
96
  self.t5 = T5Config(
97
  vocab_size=vocab_size,
 
46
  cache_dir=None,
47
  tie_word_embeddings=True,
48
  # T5 config
49
+ t5=dict(),
50
  vocab_size=32128,
51
  d_model=512,
52
  d_kv=64,
 
87
  if t5_model_name_or_path:
88
  self.t5 = AutoConfig.from_pretrained(t5_model_name_or_path, cache_dir=cache_dir)
89
  assertEqual(self.t5.model_type, "t5", "Need t5 model type for transformer_decoder.")
 
 
 
 
90
  self.t5.decoder_start_token_id = decoder_start_token_id
91
+ elif t5:
92
+ # use for loading a config
93
+ self.t5 = T5Config(**t5)
94
  else:
95
  self.t5 = T5Config(
96
  vocab_size=vocab_size,
tests/test_t5_vae.py CHANGED
@@ -294,51 +294,21 @@ class FlaxT5VaeModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unitte
294
  for jitted_output, output in zip(jitted_outputs, outputs):
295
  self.assertEqual(jitted_output.shape, output.shape)
296
 
297
- # overwrite since special base model prefix is used
298
- def test_save_load_from_base(self):
299
  config, _ = self.model_tester.prepare_config_and_inputs_for_common()
300
- base_class = FLAX_MODEL_MAPPING[config.__class__]
301
 
302
- for model_class in self.all_model_classes:
303
- if model_class == base_class:
304
- continue
305
-
306
- model = base_class(config)
307
- base_params = flatten_dict(unfreeze(model.params))
308
-
309
- # check that all base model weights are loaded correctly
310
- with tempfile.TemporaryDirectory() as tmpdirname:
311
- model.save_pretrained(tmpdirname)
312
- head_model = model_class.from_pretrained(tmpdirname)
313
-
314
- base_param_from_head = flatten_dict(unfreeze(head_model.params))
315
 
316
- for key in base_param_from_head.keys():
317
- max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
318
- self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
 
 
319
 
320
- # overwrite since special base model prefix is used
321
- def test_save_load_to_base(self):
322
- config, _ = self.model_tester.prepare_config_and_inputs_for_common()
323
- base_class = FLAX_MODEL_MAPPING[config.__class__]
324
-
325
- for model_class in self.all_model_classes:
326
- if model_class == base_class:
327
- continue
328
-
329
- model = model_class(config)
330
- base_params_from_head = flatten_dict(unfreeze(model.params))
331
-
332
- # check that all base model weights are loaded correctly
333
- with tempfile.TemporaryDirectory() as tmpdirname:
334
- model.save_pretrained(tmpdirname)
335
- base_model = base_class.from_pretrained(tmpdirname)
336
-
337
- base_params = flatten_dict(unfreeze(base_model.params))
338
-
339
- for key in base_params_from_head.keys():
340
- max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
341
- self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
342
 
343
 
344
  ## Copied training methods
@@ -354,8 +324,8 @@ def compute_mmd(x, y):
354
  x_kernel = compute_kernel(x, x)
355
  y_kernel = compute_kernel(y, y)
356
  xy_kernel = compute_kernel(x, y)
357
-
358
  return jnp.mean(x_kernel) + jnp.mean(y_kernel) - 2 * jnp.mean(xy_kernel)
 
359
  def regulariser_loss(latent_codes, rng):
360
  true_samples = jax.random.normal(rng, latent_codes.shape)
361
  return compute_mmd(true_samples, latent_codes)
@@ -403,3 +373,6 @@ class FlaxT5VaeModelIntegrationTests(unittest.TestCase):
403
  outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
404
  logits, latent_codes = outputs[0], outputs[1]
405
  loss = loss_fn(logits, labels, latent_codes, jax.random.PRNGKey(42))
 
 
 
 
294
  for jitted_output, output in zip(jitted_outputs, outputs):
295
  self.assertEqual(jitted_output.shape, output.shape)
296
 
297
+ def test_save_and_load(self):
 
298
  config, _ = self.model_tester.prepare_config_and_inputs_for_common()
 
299
 
300
+ model = FlaxT5VaeForAutoencoding(config)
301
+ model_params = flatten_dict(unfreeze(model.params))
 
 
 
 
 
 
 
 
 
 
 
302
 
303
+ # check that all base model weights are loaded correctly
304
+ with tempfile.TemporaryDirectory() as tmpdirname:
305
+ model.save_pretrained(tmpdirname)
306
+ head_model = FlaxT5VaeForAutoencoding.from_pretrained(tmpdirname)
307
+ new_params = flatten_dict(unfreeze(head_model.params))
308
 
309
+ for key in new_params.keys():
310
+ max_diff = (model_params[key] - new_params[key]).sum().item()
311
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
 
314
  ## Copied training methods
 
324
  x_kernel = compute_kernel(x, x)
325
  y_kernel = compute_kernel(y, y)
326
  xy_kernel = compute_kernel(x, y)
 
327
  return jnp.mean(x_kernel) + jnp.mean(y_kernel) - 2 * jnp.mean(xy_kernel)
328
+
329
  def regulariser_loss(latent_codes, rng):
330
  true_samples = jax.random.normal(rng, latent_codes.shape)
331
  return compute_mmd(true_samples, latent_codes)
 
373
  outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
374
  logits, latent_codes = outputs[0], outputs[1]
375
  loss = loss_fn(logits, labels, latent_codes, jax.random.PRNGKey(42))
376
+ import pdb
377
+ pdb.set_trace()
378
+ pass
train.py CHANGED
@@ -156,6 +156,9 @@ class DataTrainingArguments:
156
  "Default to the model max input length for single sentence inputs (take into account special tokens)."
157
  },
158
  )
 
 
 
159
  overwrite_cache: bool = field(
160
  default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
161
  )
@@ -293,7 +296,7 @@ def main():
293
  if data_args.dataset_name is not None:
294
  # Downloading and loading a dataset from the hub.
295
  dataset = load_dataset(
296
- data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
297
  )
298
 
299
  if "validation" not in dataset.keys():
@@ -344,10 +347,6 @@ def main():
344
  tokenizer = AutoTokenizer.from_pretrained(
345
  model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
346
  )
347
- elif model_args.model_name_or_path:
348
- tokenizer = AutoTokenizer.from_pretrained(
349
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
350
- )
351
  elif model_args.t5_model_name_or_path:
352
  tokenizer = AutoTokenizer.from_pretrained(
353
  model_args.t5_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
@@ -363,6 +362,7 @@ def main():
363
  model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
364
  )
365
  # TODO assert token embedding size == len(tokenizer)
 
366
  else:
367
  vocab_size = len(tokenizer)
368
  config.t5.vocab_size = vocab_size
@@ -563,7 +563,8 @@ def main():
563
 
564
  def regulariser_loss(latent_codes, rng):
565
  true_samples = jax.random.normal(rng, latent_codes.shape)
566
- return jax.vmap(compute_mmd)(true_samples, latent_codes)
 
567
 
568
  def loss_fn(logits, labels, latent_codes, regulariser_rng):
569
  shift_logits = logits[..., :-1, :]
@@ -594,7 +595,7 @@ def main():
594
  return new_state, metrics
595
 
596
  # Define eval fn
597
- def eval_step(params, batch):
598
  labels = batch.pop("labels")
599
  logits, latent_codes = model(**batch, params=params, train=False)[:2]
600
  loss = loss_fn(logits, labels, latent_codes, rng)
@@ -660,7 +661,7 @@ def main():
660
  for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
661
  # Model forward
662
  batch = next(eval_loader)
663
- metrics = p_eval_step(state.params, batch)
664
  eval_metrics.append(metrics)
665
 
666
  # normalize eval metrics
 
156
  "Default to the model max input length for single sentence inputs (take into account special tokens)."
157
  },
158
  )
159
+ streaming: bool = field(
160
+ default=False, metadata={"help": "Stream the dataset."}
161
+ )
162
  overwrite_cache: bool = field(
163
  default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
164
  )
 
296
  if data_args.dataset_name is not None:
297
  # Downloading and loading a dataset from the hub.
298
  dataset = load_dataset(
299
+ data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, streaming=data_args.streaming, keep_in_memory=False
300
  )
301
 
302
  if "validation" not in dataset.keys():
 
347
  tokenizer = AutoTokenizer.from_pretrained(
348
  model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
349
  )
 
 
 
 
350
  elif model_args.t5_model_name_or_path:
351
  tokenizer = AutoTokenizer.from_pretrained(
352
  model_args.t5_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
 
362
  model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
363
  )
364
  # TODO assert token embedding size == len(tokenizer)
365
+ assert(model.params['t5']['shared'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size.")
366
  else:
367
  vocab_size = len(tokenizer)
368
  config.t5.vocab_size = vocab_size
 
563
 
564
  def regulariser_loss(latent_codes, rng):
565
  true_samples = jax.random.normal(rng, latent_codes.shape)
566
+ # return jax.vmap(compute_mmd)(true_samples, latent_codes)
567
+ return compute_mmd(true_samples, latent_codes)
568
 
569
  def loss_fn(logits, labels, latent_codes, regulariser_rng):
570
  shift_logits = logits[..., :-1, :]
 
595
  return new_state, metrics
596
 
597
  # Define eval fn
598
+ def eval_step(params, rng, batch):
599
  labels = batch.pop("labels")
600
  logits, latent_codes = model(**batch, params=params, train=False)[:2]
601
  loss = loss_fn(logits, labels, latent_codes, rng)
 
661
  for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
662
  # Model forward
663
  batch = next(eval_loader)
664
+ metrics = p_eval_step(state.params, state.dropout_rng, batch)
665
  eval_metrics.append(metrics)
666
 
667
  # normalize eval metrics