Fraser committed on
Commit
92871f7
·
1 Parent(s): 6f4a0d9

ALMOST WORKING

Browse files
Files changed (7) hide show
  1. ag_news_clm.sh +18 -0
  2. model/encoders.py +4 -2
  3. model/outputs.py +52 -0
  4. model/t5_vae.py +8 -6
  5. model/vae.py +3 -11
  6. run_clm_flax.py +3 -1
  7. train.py +32 -26
ag_news_clm.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # test CLM works
2
+ export RUN_NAME=test_clm
3
+
4
+ ./venv/bin/python run_clm_flax.py \
5
+ --model_name_or_path="gpt2" \
6
+ --output_dir="output/${RUN_NAME}" \
7
+ --overwrite_output_dir \
8
+ --dataset_name="ag_news" \
9
+ --do_train --do_eval \
10
+ --save_steps="2500" \
11
+ --eval_steps="2500" \
12
+ --block_size="128" \
13
+ --per_device_train_batch_size="1" \
14
+ --per_device_eval_batch_size="1" \
15
+ --learning_rate="5e-3" --warmup_steps="1000" \
16
+ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
17
+ --overwrite_output_dir \
18
+ --num_train_epochs="20" \
model/encoders.py CHANGED
@@ -1,4 +1,5 @@
1
  import logging
 
2
  import flax.linen as nn
3
 
4
  logger = logging.getLogger(__name__)
@@ -14,8 +15,9 @@ class Encoder(nn.Module):
14
  @nn.compact
15
  def __call__(self, encoding):
16
  latent_tokens = nn.Dense(self.latent_size)(encoding)
17
- raw_latent_code = latent_tokens[:, : self.n_tokens, :]
18
- latent_code = nn.Tanh()(raw_latent_code)
 
19
  return latent_code # (batch, latent_tokens_per_sequence, latent_token_dim)
20
 
21
 
 
1
  import logging
2
+ import jax.numpy as jnp
3
  import flax.linen as nn
4
 
5
  logger = logging.getLogger(__name__)
 
15
  @nn.compact
16
  def __call__(self, encoding):
17
  latent_tokens = nn.Dense(self.latent_size)(encoding)
18
+ raw_latent_code = latent_tokens[:, : self.n_latent_tokens, :]
19
+ # TODO does this just apply tanh to each latent token? Or across the whole batch
20
+ latent_code = jnp.tanh(raw_latent_code)
21
  return latent_code # (batch, latent_tokens_per_sequence, latent_token_dim)
22
 
23
 
model/outputs.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import flax
2
  import jaxlib.xla_extension as jax_xla
3
 
@@ -14,6 +16,56 @@ class TransformerVAE_Output(ModelOutput):
14
  Latent codes representing encoded sequences.
15
  remade_encoder_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_tokens, model_dim)`):
16
  Reconstructed encoder hidden states representing sequences.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  """
18
  latent_codes: jax_xla.DeviceArray = None
19
  remade_encoder_hidden_state: jax_xla.DeviceArray = None
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
  import flax
4
  import jaxlib.xla_extension as jax_xla
5
 
 
16
  Latent codes representing encoded sequences.
17
  remade_encoder_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_tokens, model_dim)`):
18
  Reconstructed encoder hidden states representing sequences.
19
+
20
+ (std Seq2Seq) Args:
21
+ logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
22
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
23
+ past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
24
+ Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
25
+ tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
26
+ tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
27
+
28
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
29
+ blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
30
+ decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
31
+ Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
32
+ layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
33
+
34
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
35
+ decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
36
+ Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
37
+ sequence_length, sequence_length)`.
38
+
39
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
40
+ self-attention heads.
41
+ cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
42
+ Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
43
+ sequence_length, sequence_length)`.
44
+
45
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
46
+ weighted average in the cross-attention heads.
47
+ encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
48
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
49
+ encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
50
+ Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
51
+ layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
52
+
53
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
54
+ encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
55
+ Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
56
+ sequence_length, sequence_length)`.
57
+
58
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
59
+ self-attention heads.
60
  """
61
  latent_codes: jax_xla.DeviceArray = None
62
  remade_encoder_hidden_state: jax_xla.DeviceArray = None
63
+ # seq2seq
64
+ logits: jax_xla.DeviceArray = None
65
+ past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
66
+ decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
67
+ decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
68
+ cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
69
+ encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
70
+ encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
71
+ encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
model/t5_vae.py CHANGED
@@ -35,12 +35,14 @@ class FlaxT5_VAE_ForAutoencodingModule(nn.Module):
35
  self,
36
  input_ids=None,
37
  attention_mask=None,
 
38
  latent_codes=None,
39
  output_attentions=None,
40
  output_hidden_states=None,
41
  return_dict=None,
42
  deterministic: bool = True,
43
  ):
 
44
  """
45
  Adapted from `FlaxT5ForConditionalGenerationModule`
46
  """
@@ -75,16 +77,16 @@ class FlaxT5_VAE_ForAutoencodingModule(nn.Module):
75
 
76
  sequence_output = decoder_outputs[0]
77
 
78
- if self.config.tie_word_embeddings:
79
  # Rescale output before projecting on vocab
80
  # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
81
- sequence_output = sequence_output * (self.config.t5.d_model ** -0.5)
82
 
83
- if self.config.tie_word_embeddings:
84
- shared_embedding = self.shared.variables["params"]["embedding"]
85
- lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
86
  else:
87
- lm_logits = self.lm_head(sequence_output)
88
 
89
  if not return_dict:
90
  return (lm_logits,) + decoder_outputs[1:] + encoder_outputs
 
35
  self,
36
  input_ids=None,
37
  attention_mask=None,
38
+ encoder_outputs=None,
39
  latent_codes=None,
40
  output_attentions=None,
41
  output_hidden_states=None,
42
  return_dict=None,
43
  deterministic: bool = True,
44
  ):
45
+ # TODO should I use None args when everything has to be computed anyway?
46
  """
47
  Adapted from `FlaxT5ForConditionalGenerationModule`
48
  """
 
77
 
78
  sequence_output = decoder_outputs[0]
79
 
80
+ if self.t5.config.tie_word_embeddings:
81
  # Rescale output before projecting on vocab
82
  # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
83
+ sequence_output = sequence_output * (self.t5.config.d_model ** -0.5)
84
 
85
+ if self.t5.config.tie_word_embeddings:
86
+ shared_embedding = self.t5.shared.variables["params"]["embedding"]
87
+ lm_logits = self.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
88
  else:
89
+ lm_logits = self.t5.lm_head(sequence_output)
90
 
91
  if not return_dict:
92
  return (lm_logits,) + decoder_outputs[1:] + encoder_outputs
model/vae.py CHANGED
@@ -3,7 +3,6 @@ import flax.linen as nn
3
 
4
  from model.encoders import VAE_ENCODER_MODELS
5
  from model.decoders import VAE_DECODER_MODELS
6
- from model.outputs import TransformerVAE_Output
7
  from model.config import T5_VAE_Config
8
 
9
 
@@ -18,21 +17,14 @@ class VAE(nn.Module):
18
 
19
  def setup(self):
20
  self.encoder = VAE_ENCODER_MODELS[self.config.vae_encoder_model](self.config.latent_size, self.config.n_latent_tokens)
21
- self.decoder = VAE_DECODER_MODELS[self.config.vae_decoder_model](self.config.t5.d_model, self.config.n_latent_tokens)
22
 
23
  def __call__(self, encoding=None, latent_codes=None):
24
- if latent_codes is None:
25
- latent_codes = self.encode(encoding)
26
- # return latent_codes for regulariser loss
27
- return TransformerVAE_Output(
28
- latent_codes,
29
- self.decoder(latent_codes),
30
- )
31
 
32
  def encode(self, encoding):
33
- assert encoding.shape[1:] == self.input_shape
34
  return self.encoder(encoding)
35
 
36
  def decode(self, latent):
37
- assert latent.shape[1:] == self.input_shape
38
  return self.decoder(latent)
 
3
 
4
  from model.encoders import VAE_ENCODER_MODELS
5
  from model.decoders import VAE_DECODER_MODELS
 
6
  from model.config import T5_VAE_Config
7
 
8
 
 
17
 
18
  def setup(self):
19
  self.encoder = VAE_ENCODER_MODELS[self.config.vae_encoder_model](self.config.latent_size, self.config.n_latent_tokens)
20
+ self.decoder = VAE_DECODER_MODELS[self.config.vae_decoder_model](self.config.t5.d_model, self.config.n_latent_tokens)
21
 
22
  def __call__(self, encoding=None, latent_codes=None):
23
+ latent_codes = self.encode(encoding)
24
+ return self.decode(latent_codes), latent_codes
 
 
 
 
 
25
 
26
  def encode(self, encoding):
 
27
  return self.encoder(encoding)
28
 
29
  def decode(self, latent):
 
30
  return self.decoder(latent)
run_clm_flax.py CHANGED
@@ -405,7 +405,7 @@ def main():
405
  k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
406
  for k, t in concatenated_examples.items()
407
  }
408
- result["labels"] = result["input_ids"].copy()
409
  return result
410
 
411
  # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
@@ -421,6 +421,8 @@ def main():
421
  num_proc=data_args.preprocessing_num_workers,
422
  load_from_cache_file=not data_args.overwrite_cache,
423
  )
 
 
424
 
425
  if training_args.do_train:
426
  if "train" not in tokenized_datasets:
 
405
  k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
406
  for k, t in concatenated_examples.items()
407
  }
408
+ result["label"] = result["input_ids"].copy()
409
  return result
410
 
411
  # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
 
421
  num_proc=data_args.preprocessing_num_workers,
422
  load_from_cache_file=not data_args.overwrite_cache,
423
  )
424
+ import pdb
425
+ pdb.set_trace()
426
 
427
  if training_args.do_train:
428
  if "train" not in tokenized_datasets:
train.py CHANGED
@@ -2,6 +2,7 @@
2
  Pre-training/Fine-tuning seq2seq models on autoencoding a dataset.
3
 
4
  TODO:
 
5
  - [x] Don't make decoder input ids.
6
  - [ ] Add reg loss
7
  - [x] calculate MMD loss
@@ -15,7 +16,7 @@
15
  use_extra_logs (:obj:`bool`, `optional`, defaults to False):
16
  Store extra logs during each training inference.
17
 
18
- - [ ] Send the scedule time to the compute_loss method and calculate a coefficient based on that.
19
  '''
20
  import logging
21
  import math
@@ -379,6 +380,10 @@ def main():
379
  )
380
  return output
381
 
 
 
 
 
382
  tokenized_datasets = dataset.map(
383
  tokenize_function,
384
  batched=True,
@@ -394,22 +399,23 @@ def main():
394
  )
395
  block_size = min(data_args.block_size, tokenizer.model_max_length)
396
 
397
- # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
398
- def group_texts(examples):
399
- # Concatenate all texts.
400
- concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
401
- total_length = len(concatenated_examples[list(examples.keys())[0]])
402
- # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
403
- # customize this part to your needs.
404
- if total_length >= block_size:
405
- total_length = (total_length // block_size) * block_size
406
- # Split by chunks of max_len.
407
- result = {
408
- k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
409
- for k, t in concatenated_examples.items()
410
- }
411
- result["labels"] = result["input_ids"].copy()
412
- return result
 
413
 
414
  # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
415
  # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
@@ -419,7 +425,7 @@ def main():
419
  # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
420
 
421
  lm_datasets = tokenized_datasets.map(
422
- group_texts,
423
  batched=True,
424
  num_proc=data_args.preprocessing_num_workers,
425
  load_from_cache_file=not data_args.overwrite_cache,
@@ -516,8 +522,8 @@ def main():
516
  x_size = x.shape[0]
517
  y_size = y.shape[0]
518
  dim = x.shape[1]
519
- tiled_x = jnp.repeat(jnp.reshape(x, (x_size, 1, dim)), y_size, axis = 1)
520
- tiled_y = jnp.repeat(jnp.reshape(y, (1, y_size, dim)), x_size, axis = 0)
521
  return jnp.exp(-jnp.mean((tiled_x - tiled_y) ** 2, axis=2) / dim * 1.0)
522
 
523
  def compute_mmd(x, y):
@@ -526,16 +532,16 @@ def main():
526
  xy_kernel = compute_kernel(x, y)
527
  return jnp.mean(x_kernel) + jnp.mean(y_kernel) - 2 * jnp.mean(xy_kernel)
528
 
529
- def regulariser_loss(latent_codes):
530
- true_samples = jnp.random.randn(latent_codes.shape())
531
  return compute_mmd(true_samples, latent_codes)
532
 
533
- def loss_fn(logits, labels, latent_codes):
534
  shift_logits = logits[..., :-1, :]
535
  shift_labels = labels[..., 1:]
536
  loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
537
 
538
- reg_loss = regulariser_loss(latent_codes)
539
  return loss.mean() + reg_loss.mean()
540
 
541
  # Define gradient update step fn
@@ -544,8 +550,8 @@ def main():
544
 
545
  def compute_loss(params):
546
  labels = batch.pop("labels")
547
- logits, latent_codes = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[:2]
548
- loss = loss_fn(logits, labels, latent_codes)
549
  return loss
550
 
551
  grad_fn = jax.value_and_grad(compute_loss)
 
2
  Pre-training/Fine-tuning seq2seq models on autoencoding a dataset.
3
 
4
  TODO:
5
+ - [ ] Get this running.
6
  - [x] Don't make decoder input ids.
7
  - [ ] Add reg loss
8
  - [x] calculate MMD loss
 
16
  use_extra_logs (:obj:`bool`, `optional`, defaults to False):
17
  Store extra logs during each training inference.
18
 
19
+ - [ ] Send the schedule time to the compute_loss method and calculate a coefficient based on that.
20
  '''
21
  import logging
22
  import math
 
380
  )
381
  return output
382
 
383
+ # remove dataset tasks
384
+ for k in dataset.keys():
385
+ dataset[k].info.task_templates = []
386
+
387
  tokenized_datasets = dataset.map(
388
  tokenize_function,
389
  batched=True,
 
399
  )
400
  block_size = min(data_args.block_size, tokenizer.model_max_length)
401
 
402
+ # Limits each input sequence to size block_size.
403
+ pad_token_id = tokenizer.pad_token_id
404
+
405
+ def limit_length(examples):
406
+ examples["labels"] = examples["input_ids"].copy()
407
+
408
+ for i, input_ids in enumerate(examples["input_ids"]):
409
+ if len(input_ids) > block_size:
410
+ for k in examples.keys():
411
+ examples[k][i] = examples[k][i][:block_size]
412
+ elif len(input_ids) < block_size:
413
+ delta = block_size - len(input_ids)
414
+ examples['input_ids'][i] = examples['input_ids'][i] + [pad_token_id] * delta
415
+ examples['attention_mask'][i] = examples['attention_mask'][i] + [0] * delta
416
+ examples['labels'][i] = examples['labels'][i] + [-100] * delta
417
+
418
+ return examples
419
 
420
  # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
421
  # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
 
425
  # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
426
 
427
  lm_datasets = tokenized_datasets.map(
428
+ limit_length,
429
  batched=True,
430
  num_proc=data_args.preprocessing_num_workers,
431
  load_from_cache_file=not data_args.overwrite_cache,
 
522
  x_size = x.shape[0]
523
  y_size = y.shape[0]
524
  dim = x.shape[1]
525
+ tiled_x = jnp.repeat(jnp.reshape(x, (x_size, 1, dim)), y_size, axis=1)
526
+ tiled_y = jnp.repeat(jnp.reshape(y, (1, y_size, dim)), x_size, axis=0)
527
  return jnp.exp(-jnp.mean((tiled_x - tiled_y) ** 2, axis=2) / dim * 1.0)
528
 
529
  def compute_mmd(x, y):
 
532
  xy_kernel = compute_kernel(x, y)
533
  return jnp.mean(x_kernel) + jnp.mean(y_kernel) - 2 * jnp.mean(xy_kernel)
534
 
535
+ def regulariser_loss(latent_codes, rng: jax.random.PRNGKey):
536
+ true_samples = jax.random.normal(rng, latent_codes.shape())
537
  return compute_mmd(true_samples, latent_codes)
538
 
539
+ def loss_fn(logits, labels, latent_codes, rng: jax.random.PRNGKey):
540
  shift_logits = logits[..., :-1, :]
541
  shift_labels = labels[..., 1:]
542
  loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
543
 
544
+ reg_loss = regulariser_loss(latent_codes, rng)
545
  return loss.mean() + reg_loss.mean()
546
 
547
  # Define gradient update step fn
 
550
 
551
  def compute_loss(params):
552
  labels = batch.pop("labels")
553
+ outputs = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)
554
+ loss = loss_fn(outputs.logits, labels, outputs.latent_codes, state.dropout_rng)
555
  return loss
556
 
557
  grad_fn = jax.value_and_grad(compute_loss)