ALMOST WORKING
Browse files- ag_news_clm.sh +18 -0
- model/encoders.py +4 -2
- model/outputs.py +52 -0
- model/t5_vae.py +8 -6
- model/vae.py +3 -11
- run_clm_flax.py +3 -1
- train.py +32 -26
ag_news_clm.sh
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# test CLM works
|
| 2 |
+
export RUN_NAME=test_clm
|
| 3 |
+
|
| 4 |
+
./venv/bin/python run_clm_flax.py \
|
| 5 |
+
--model_name_or_path="gpt2" \
|
| 6 |
+
--output_dir="output/${RUN_NAME}" \
|
| 7 |
+
--overwrite_output_dir \
|
| 8 |
+
--dataset_name="ag_news" \
|
| 9 |
+
--do_train --do_eval \
|
| 10 |
+
--save_steps="2500" \
|
| 11 |
+
--eval_steps="2500" \
|
| 12 |
+
--block_size="128" \
|
| 13 |
+
--per_device_train_batch_size="1" \
|
| 14 |
+
--per_device_eval_batch_size="1" \
|
| 15 |
+
--learning_rate="5e-3" --warmup_steps="1000" \
|
| 16 |
+
--adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
|
| 17 |
+
--overwrite_output_dir \
|
| 18 |
+
--num_train_epochs="20" \
|
model/encoders.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import logging
|
|
|
|
| 2 |
import flax.linen as nn
|
| 3 |
|
| 4 |
logger = logging.getLogger(__name__)
|
|
@@ -14,8 +15,9 @@ class Encoder(nn.Module):
|
|
| 14 |
@nn.compact
|
| 15 |
def __call__(self, encoding):
|
| 16 |
latent_tokens = nn.Dense(self.latent_size)(encoding)
|
| 17 |
-
raw_latent_code = latent_tokens[:, : self.
|
| 18 |
-
|
|
|
|
| 19 |
return latent_code # (batch, latent_tokens_per_sequence, latent_token_dim)
|
| 20 |
|
| 21 |
|
|
|
|
| 1 |
import logging
|
| 2 |
+
import jax.numpy as jnp
|
| 3 |
import flax.linen as nn
|
| 4 |
|
| 5 |
logger = logging.getLogger(__name__)
|
|
|
|
| 15 |
@nn.compact
|
| 16 |
def __call__(self, encoding):
|
| 17 |
latent_tokens = nn.Dense(self.latent_size)(encoding)
|
| 18 |
+
raw_latent_code = latent_tokens[:, : self.n_latent_tokens, :]
|
| 19 |
+
# TODO: confirm intent — jnp.tanh is element-wise, so each latent value is squashed independently (not across the batch).
|
| 20 |
+
latent_code = jnp.tanh(raw_latent_code)
|
| 21 |
return latent_code # (batch, latent_tokens_per_sequence, latent_token_dim)
|
| 22 |
|
| 23 |
|
model/outputs.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import flax
|
| 2 |
import jaxlib.xla_extension as jax_xla
|
| 3 |
|
|
@@ -14,6 +16,56 @@ class TransformerVAE_Output(ModelOutput):
|
|
| 14 |
Latent codes representing encoded sequences.
|
| 15 |
remade_encoder_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, n_tokens, model_dim)`):
|
| 16 |
Reconstructed encoder hidden states representing sequences.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
"""
|
| 18 |
latent_codes: jax_xla.DeviceArray = None
|
| 19 |
remade_encoder_hidden_state: jax_xla.DeviceArray = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple
|
| 2 |
+
|
| 3 |
import flax
|
| 4 |
import jaxlib.xla_extension as jax_xla
|
| 5 |
|
|
|
|
| 16 |
Latent codes representing encoded sequences.
|
| 17 |
remade_encoder_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, n_tokens, model_dim)`):
|
| 18 |
Reconstructed encoder hidden states representing sequences.
|
| 19 |
+
|
| 20 |
+
(std Seq2Seq) Args:
|
| 21 |
+
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
| 22 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 23 |
+
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
| 24 |
+
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
|
| 25 |
+
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
|
| 26 |
+
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
| 27 |
+
|
| 28 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 29 |
+
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
|
| 30 |
+
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 31 |
+
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
| 32 |
+
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 33 |
+
|
| 34 |
+
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
|
| 35 |
+
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 36 |
+
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
| 37 |
+
sequence_length, sequence_length)`.
|
| 38 |
+
|
| 39 |
+
Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
| 40 |
+
self-attention heads.
|
| 41 |
+
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 42 |
+
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
| 43 |
+
sequence_length, sequence_length)`.
|
| 44 |
+
|
| 45 |
+
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
| 46 |
+
weighted average in the cross-attention heads.
|
| 47 |
+
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
| 48 |
+
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
| 49 |
+
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 50 |
+
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
| 51 |
+
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 52 |
+
|
| 53 |
+
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
|
| 54 |
+
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 55 |
+
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
| 56 |
+
sequence_length, sequence_length)`.
|
| 57 |
+
|
| 58 |
+
Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
| 59 |
+
self-attention heads.
|
| 60 |
"""
|
| 61 |
latent_codes: jax_xla.DeviceArray = None
|
| 62 |
remade_encoder_hidden_state: jax_xla.DeviceArray = None
|
| 63 |
+
# seq2seq
|
| 64 |
+
logits: jax_xla.DeviceArray = None
|
| 65 |
+
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
|
| 66 |
+
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
| 67 |
+
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
| 68 |
+
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
| 69 |
+
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
|
| 70 |
+
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
| 71 |
+
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
model/t5_vae.py
CHANGED
|
@@ -35,12 +35,14 @@ class FlaxT5_VAE_ForAutoencodingModule(nn.Module):
|
|
| 35 |
self,
|
| 36 |
input_ids=None,
|
| 37 |
attention_mask=None,
|
|
|
|
| 38 |
latent_codes=None,
|
| 39 |
output_attentions=None,
|
| 40 |
output_hidden_states=None,
|
| 41 |
return_dict=None,
|
| 42 |
deterministic: bool = True,
|
| 43 |
):
|
|
|
|
| 44 |
"""
|
| 45 |
Adapted from `FlaxT5ForConditionalGenerationModule`
|
| 46 |
"""
|
|
@@ -75,16 +77,16 @@ class FlaxT5_VAE_ForAutoencodingModule(nn.Module):
|
|
| 75 |
|
| 76 |
sequence_output = decoder_outputs[0]
|
| 77 |
|
| 78 |
-
if self.config.tie_word_embeddings:
|
| 79 |
# Rescale output before projecting on vocab
|
| 80 |
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
|
| 81 |
-
sequence_output = sequence_output * (self.
|
| 82 |
|
| 83 |
-
if self.config.tie_word_embeddings:
|
| 84 |
-
shared_embedding = self.shared.variables["params"]["embedding"]
|
| 85 |
-
lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
|
| 86 |
else:
|
| 87 |
-
lm_logits = self.lm_head(sequence_output)
|
| 88 |
|
| 89 |
if not return_dict:
|
| 90 |
return (lm_logits,) + decoder_outputs[1:] + encoder_outputs
|
|
|
|
| 35 |
self,
|
| 36 |
input_ids=None,
|
| 37 |
attention_mask=None,
|
| 38 |
+
encoder_outputs=None,
|
| 39 |
latent_codes=None,
|
| 40 |
output_attentions=None,
|
| 41 |
output_hidden_states=None,
|
| 42 |
return_dict=None,
|
| 43 |
deterministic: bool = True,
|
| 44 |
):
|
| 45 |
+
# TODO should I use None args when everything has to be computed anyway?
|
| 46 |
"""
|
| 47 |
Adapted from `FlaxT5ForConditionalGenerationModule`
|
| 48 |
"""
|
|
|
|
| 77 |
|
| 78 |
sequence_output = decoder_outputs[0]
|
| 79 |
|
| 80 |
+
if self.t5.config.tie_word_embeddings:
|
| 81 |
# Rescale output before projecting on vocab
|
| 82 |
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
|
| 83 |
+
sequence_output = sequence_output * (self.t5.config.d_model ** -0.5)
|
| 84 |
|
| 85 |
+
if self.t5.config.tie_word_embeddings:
|
| 86 |
+
shared_embedding = self.t5.shared.variables["params"]["embedding"]
|
| 87 |
+
lm_logits = self.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
|
| 88 |
else:
|
| 89 |
+
lm_logits = self.t5.lm_head(sequence_output)
|
| 90 |
|
| 91 |
if not return_dict:
|
| 92 |
return (lm_logits,) + decoder_outputs[1:] + encoder_outputs
|
model/vae.py
CHANGED
|
@@ -3,7 +3,6 @@ import flax.linen as nn
|
|
| 3 |
|
| 4 |
from model.encoders import VAE_ENCODER_MODELS
|
| 5 |
from model.decoders import VAE_DECODER_MODELS
|
| 6 |
-
from model.outputs import TransformerVAE_Output
|
| 7 |
from model.config import T5_VAE_Config
|
| 8 |
|
| 9 |
|
|
@@ -18,21 +17,14 @@ class VAE(nn.Module):
|
|
| 18 |
|
| 19 |
def setup(self):
|
| 20 |
self.encoder = VAE_ENCODER_MODELS[self.config.vae_encoder_model](self.config.latent_size, self.config.n_latent_tokens)
|
| 21 |
-
self.decoder = VAE_DECODER_MODELS[self.config.vae_decoder_model](self.config.t5.d_model,
|
| 22 |
|
| 23 |
def __call__(self, encoding=None, latent_codes=None):
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
# return latent_codes for regulariser loss
|
| 27 |
-
return TransformerVAE_Output(
|
| 28 |
-
latent_codes,
|
| 29 |
-
self.decoder(latent_codes),
|
| 30 |
-
)
|
| 31 |
|
| 32 |
def encode(self, encoding):
|
| 33 |
-
assert encoding.shape[1:] == self.input_shape
|
| 34 |
return self.encoder(encoding)
|
| 35 |
|
| 36 |
def decode(self, latent):
|
| 37 |
-
assert latent.shape[1:] == self.input_shape
|
| 38 |
return self.decoder(latent)
|
|
|
|
| 3 |
|
| 4 |
from model.encoders import VAE_ENCODER_MODELS
|
| 5 |
from model.decoders import VAE_DECODER_MODELS
|
|
|
|
| 6 |
from model.config import T5_VAE_Config
|
| 7 |
|
| 8 |
|
|
|
|
| 17 |
|
| 18 |
def setup(self):
|
| 19 |
self.encoder = VAE_ENCODER_MODELS[self.config.vae_encoder_model](self.config.latent_size, self.config.n_latent_tokens)
|
| 20 |
+
self.decoder = VAE_DECODER_MODELS[self.config.vae_decoder_model](self.config.t5.d_model, self.config.n_latent_tokens)
|
| 21 |
|
| 22 |
def __call__(self, encoding=None, latent_codes=None):
|
| 23 |
+
latent_codes = self.encode(encoding)
|
| 24 |
+
return self.decode(latent_codes), latent_codes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
def encode(self, encoding):
|
|
|
|
| 27 |
return self.encoder(encoding)
|
| 28 |
|
| 29 |
def decode(self, latent):
|
|
|
|
| 30 |
return self.decoder(latent)
|
run_clm_flax.py
CHANGED
|
@@ -405,7 +405,7 @@ def main():
|
|
| 405 |
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
| 406 |
for k, t in concatenated_examples.items()
|
| 407 |
}
|
| 408 |
-
result["
|
| 409 |
return result
|
| 410 |
|
| 411 |
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
|
|
@@ -421,6 +421,8 @@ def main():
|
|
| 421 |
num_proc=data_args.preprocessing_num_workers,
|
| 422 |
load_from_cache_file=not data_args.overwrite_cache,
|
| 423 |
)
|
|
|
|
|
|
|
| 424 |
|
| 425 |
if training_args.do_train:
|
| 426 |
if "train" not in tokenized_datasets:
|
|
|
|
| 405 |
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
| 406 |
for k, t in concatenated_examples.items()
|
| 407 |
}
|
| 408 |
+
result["label"] = result["input_ids"].copy()
|
| 409 |
return result
|
| 410 |
|
| 411 |
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
|
|
|
|
| 421 |
num_proc=data_args.preprocessing_num_workers,
|
| 422 |
load_from_cache_file=not data_args.overwrite_cache,
|
| 423 |
)
|
| 424 |
+
import pdb
|
| 425 |
+
pdb.set_trace()
|
| 426 |
|
| 427 |
if training_args.do_train:
|
| 428 |
if "train" not in tokenized_datasets:
|
train.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
Pre-training/Fine-tuning seq2seq models on autoencoding a dataset.
|
| 3 |
|
| 4 |
TODO:
|
|
|
|
| 5 |
- [x] Don't make decoder input ids.
|
| 6 |
- [ ] Add reg loss
|
| 7 |
- [x] calculate MMD loss
|
|
@@ -15,7 +16,7 @@
|
|
| 15 |
use_extra_logs (:obj:`bool`, `optional`, defaults to False):
|
| 16 |
Store extra logs during each training inference.
|
| 17 |
|
| 18 |
-
- [ ] Send the
|
| 19 |
'''
|
| 20 |
import logging
|
| 21 |
import math
|
|
@@ -379,6 +380,10 @@ def main():
|
|
| 379 |
)
|
| 380 |
return output
|
| 381 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
tokenized_datasets = dataset.map(
|
| 383 |
tokenize_function,
|
| 384 |
batched=True,
|
|
@@ -394,22 +399,23 @@ def main():
|
|
| 394 |
)
|
| 395 |
block_size = min(data_args.block_size, tokenizer.model_max_length)
|
| 396 |
|
| 397 |
-
#
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
|
|
|
| 413 |
|
| 414 |
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
|
| 415 |
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
|
|
@@ -419,7 +425,7 @@ def main():
|
|
| 419 |
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
| 420 |
|
| 421 |
lm_datasets = tokenized_datasets.map(
|
| 422 |
-
|
| 423 |
batched=True,
|
| 424 |
num_proc=data_args.preprocessing_num_workers,
|
| 425 |
load_from_cache_file=not data_args.overwrite_cache,
|
|
@@ -516,8 +522,8 @@ def main():
|
|
| 516 |
x_size = x.shape[0]
|
| 517 |
y_size = y.shape[0]
|
| 518 |
dim = x.shape[1]
|
| 519 |
-
tiled_x = jnp.repeat(jnp.reshape(x, (x_size, 1, dim)), y_size, axis
|
| 520 |
-
tiled_y = jnp.repeat(jnp.reshape(y, (1, y_size, dim)), x_size, axis
|
| 521 |
return jnp.exp(-jnp.mean((tiled_x - tiled_y) ** 2, axis=2) / dim * 1.0)
|
| 522 |
|
| 523 |
def compute_mmd(x, y):
|
|
@@ -526,16 +532,16 @@ def main():
|
|
| 526 |
xy_kernel = compute_kernel(x, y)
|
| 527 |
return jnp.mean(x_kernel) + jnp.mean(y_kernel) - 2 * jnp.mean(xy_kernel)
|
| 528 |
|
| 529 |
-
def regulariser_loss(latent_codes):
|
| 530 |
-
true_samples =
|
| 531 |
return compute_mmd(true_samples, latent_codes)
|
| 532 |
|
| 533 |
-
def loss_fn(logits, labels, latent_codes):
|
| 534 |
shift_logits = logits[..., :-1, :]
|
| 535 |
shift_labels = labels[..., 1:]
|
| 536 |
loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
|
| 537 |
|
| 538 |
-
reg_loss = regulariser_loss(latent_codes)
|
| 539 |
return loss.mean() + reg_loss.mean()
|
| 540 |
|
| 541 |
# Define gradient update step fn
|
|
@@ -544,8 +550,8 @@ def main():
|
|
| 544 |
|
| 545 |
def compute_loss(params):
|
| 546 |
labels = batch.pop("labels")
|
| 547 |
-
|
| 548 |
-
loss = loss_fn(logits, labels, latent_codes)
|
| 549 |
return loss
|
| 550 |
|
| 551 |
grad_fn = jax.value_and_grad(compute_loss)
|
|
|
|
| 2 |
Pre-training/Fine-tuning seq2seq models on autoencoding a dataset.
|
| 3 |
|
| 4 |
TODO:
|
| 5 |
+
- [ ] Get this running.
|
| 6 |
- [x] Don't make decoder input ids.
|
| 7 |
- [ ] Add reg loss
|
| 8 |
- [x] calculate MMD loss
|
|
|
|
| 16 |
use_extra_logs (:obj:`bool`, `optional`, defaults to False):
|
| 17 |
Store extra logs during each training inference.
|
| 18 |
|
| 19 |
+
- [ ] Send the schedule time to the compute_loss method and calculate a coefficient based on that.
|
| 20 |
'''
|
| 21 |
import logging
|
| 22 |
import math
|
|
|
|
| 380 |
)
|
| 381 |
return output
|
| 382 |
|
| 383 |
+
# remove dataset tasks
|
| 384 |
+
for k in dataset.keys():
|
| 385 |
+
dataset[k].info.task_templates = []
|
| 386 |
+
|
| 387 |
tokenized_datasets = dataset.map(
|
| 388 |
tokenize_function,
|
| 389 |
batched=True,
|
|
|
|
| 399 |
)
|
| 400 |
block_size = min(data_args.block_size, tokenizer.model_max_length)
|
| 401 |
|
| 402 |
+
# Truncate or pad each input sequence to exactly block_size tokens.
|
| 403 |
+
pad_token_id = tokenizer.pad_token_id
|
| 404 |
+
|
| 405 |
+
def limit_length(examples):
|
| 406 |
+
examples["labels"] = examples["input_ids"].copy()
|
| 407 |
+
|
| 408 |
+
for i, input_ids in enumerate(examples["input_ids"]):
|
| 409 |
+
if len(input_ids) > block_size:
|
| 410 |
+
for k in examples.keys():
|
| 411 |
+
examples[k][i] = examples[k][i][:block_size]
|
| 412 |
+
elif len(input_ids) < block_size:
|
| 413 |
+
delta = block_size - len(input_ids)
|
| 414 |
+
examples['input_ids'][i] = examples['input_ids'][i] + [pad_token_id] * delta
|
| 415 |
+
examples['attention_mask'][i] = examples['attention_mask'][i] + [0] * delta
|
| 416 |
+
examples['labels'][i] = examples['labels'][i] + [-100] * delta
|
| 417 |
+
|
| 418 |
+
return examples
|
| 419 |
|
| 420 |
# Note that with `batched=True`, this map processes 1,000 texts together; limit_length truncates or pads every
|
| 421 |
# sequence in each of those groups of 1,000 texts to block_size. You can adjust that batch_size here but a higher value might be slower
|
|
|
|
| 425 |
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
| 426 |
|
| 427 |
lm_datasets = tokenized_datasets.map(
|
| 428 |
+
limit_length,
|
| 429 |
batched=True,
|
| 430 |
num_proc=data_args.preprocessing_num_workers,
|
| 431 |
load_from_cache_file=not data_args.overwrite_cache,
|
|
|
|
| 522 |
x_size = x.shape[0]
|
| 523 |
y_size = y.shape[0]
|
| 524 |
dim = x.shape[1]
|
| 525 |
+
tiled_x = jnp.repeat(jnp.reshape(x, (x_size, 1, dim)), y_size, axis=1)
|
| 526 |
+
tiled_y = jnp.repeat(jnp.reshape(y, (1, y_size, dim)), x_size, axis=0)
|
| 527 |
return jnp.exp(-jnp.mean((tiled_x - tiled_y) ** 2, axis=2) / dim * 1.0)
|
| 528 |
|
| 529 |
def compute_mmd(x, y):
|
|
|
|
| 532 |
xy_kernel = compute_kernel(x, y)
|
| 533 |
return jnp.mean(x_kernel) + jnp.mean(y_kernel) - 2 * jnp.mean(xy_kernel)
|
| 534 |
|
| 535 |
+
def regulariser_loss(latent_codes, rng: jax.random.PRNGKey):
|
| 536 |
+
true_samples = jax.random.normal(rng, latent_codes.shape())
|
| 537 |
return compute_mmd(true_samples, latent_codes)
|
| 538 |
|
| 539 |
+
def loss_fn(logits, labels, latent_codes, rng: jax.random.PRNGKey):
|
| 540 |
shift_logits = logits[..., :-1, :]
|
| 541 |
shift_labels = labels[..., 1:]
|
| 542 |
loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
|
| 543 |
|
| 544 |
+
reg_loss = regulariser_loss(latent_codes, rng)
|
| 545 |
return loss.mean() + reg_loss.mean()
|
| 546 |
|
| 547 |
# Define gradient update step fn
|
|
|
|
| 550 |
|
| 551 |
def compute_loss(params):
|
| 552 |
labels = batch.pop("labels")
|
| 553 |
+
outputs = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)
|
| 554 |
+
loss = loss_fn(outputs.logits, labels, outputs.latent_codes, state.dropout_rng)
|
| 555 |
return loss
|
| 556 |
|
| 557 |
grad_fn = jax.value_and_grad(compute_loss)
|