show where to add reg loss
train.py (changed):
```diff
@@ -56,17 +56,15 @@ from flax.jax_utils import unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from transformers import (
-    CONFIG_MAPPING,
     FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
-    AutoConfig,
     AutoTokenizer,
-    FlaxAutoModelForCausalLM,
     HfArgumentParser,
     TrainingArguments,
     is_tensorboard_available,
 )
 from transformers.testing_utils import CaptureLogger
 
+from model.t5_vae import Funnel_T5_VAE_Model
 from model.config import T5_VAE_Config
 
 
@@ -526,10 +524,11 @@ def main():
     # Setup train state
     state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
 
-    def loss_fn(logits, labels):
+    def loss_fn(logits, labels, latent_codes):
         shift_logits = logits[..., :-1, :]
         shift_labels = labels[..., 1:]
         loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
+        # TODO add reg loss here
         return loss.mean()
 
     # Define gradient update step fn
@@ -538,8 +537,8 @@ def main():
 
         def compute_loss(params):
             labels = batch.pop("labels")
-            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
-            loss = loss_fn(logits, labels)
+            logits, latent_codes = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[:2]
+            loss = loss_fn(logits, labels, latent_codes)
             return loss
 
         grad_fn = jax.value_and_grad(compute_loss)
@@ -556,8 +555,8 @@ def main():
     # Define eval fn
     def eval_step(params, batch):
         labels = batch.pop("labels")
-        logits = model(**batch, params=params, train=False)[0]
-        loss = loss_fn(logits, labels)
+        logits, latent_codes = model(**batch, params=params, train=False)[:2]
+        loss = loss_fn(logits, labels, latent_codes)
 
         # summarize metrics
         metrics = {"loss": loss}
```
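The commit deliberately leaves the `# TODO add reg loss here` marker unfilled: `loss_fn` now receives the model's `latent_codes`, but still returns only the cross-entropy term. As a sketch of what could go there: since `loss_fn` sees sampled codes rather than posterior parameters, an MMD penalty against a standard-normal prior (MMD-VAE style) is one regularizer that works with just the codes; a KL term would instead need the posterior mean and variance. Everything below beyond the original `loss_fn` body is an assumption for illustration, not part of the commit: the `imq_kernel` and `mmd_loss` helpers, the `rng` and `reg_weight` arguments, the bandwidth heuristic, and the assumed `(batch, latent_dim)` shape of `latent_codes`.

```python
import jax
import jax.numpy as jnp
import optax
from flax.training.common_utils import onehot


def imq_kernel(x, y, c):
    # Inverse multiquadratic kernel, evaluated pairwise between two
    # batches of latent codes; a common choice for MMD-VAEs.
    sq_dists = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return c / (c + sq_dists)


def mmd_loss(latent_codes, rng):
    # Penalize the distance between the batch of latent codes and
    # samples drawn from a standard-normal prior.
    prior = jax.random.normal(rng, latent_codes.shape)
    c = 2.0 * latent_codes.shape[-1]  # bandwidth heuristic (assumed)
    return (
        imq_kernel(latent_codes, latent_codes, c).mean()
        + imq_kernel(prior, prior, c).mean()
        - 2.0 * imq_kernel(latent_codes, prior, c).mean()
    )


def loss_fn(logits, labels, latent_codes, rng, reg_weight=1.0):
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]
    loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
    # Regularization term replacing the TODO; reg_weight is a tunable
    # hyperparameter (hypothetical default here).
    return loss.mean() + reg_weight * mmd_loss(latent_codes, rng)
```

With this variant, `compute_loss` and `eval_step` would also need to pass a PRNG key for the prior samples (e.g. a split of `dropout_rng`), and reporting the regularization term separately in `metrics` alongside the cross-entropy would make `reg_weight` easier to tune.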