Fraser committed
Commit 03c20bd · 1 Parent(s): 923329d

show where to add reg loss

Files changed (1): train.py (+7 -8)
--- a/train.py
+++ b/train.py
@@ -56,17 +56,15 @@ from flax.jax_utils import unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from transformers import (
-    CONFIG_MAPPING,
     FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
-    AutoConfig,
     AutoTokenizer,
-    FlaxAutoModelForCausalLM,
     HfArgumentParser,
     TrainingArguments,
     is_tensorboard_available,
 )
 from transformers.testing_utils import CaptureLogger
 
+from model.t5_vae import Funnel_T5_VAE_Model
 from model.config import T5_VAE_Config
 
 
@@ -526,10 +524,11 @@ def main():
     # Setup train state
     state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
 
-    def loss_fn(logits, labels):
+    def loss_fn(logits, labels, latent_codes):
         shift_logits = logits[..., :-1, :]
         shift_labels = labels[..., 1:]
         loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
+        # TODO add reg loss here
         return loss.mean()
 
     # Define gradient update step fn
@@ -538,8 +537,8 @@ def main():
 
         def compute_loss(params):
            labels = batch.pop("labels")
-            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
-            loss = loss_fn(logits, labels)
+            logits, latent_codes = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[:2]
+            loss = loss_fn(logits, labels, latent_codes)
            return loss
 
        grad_fn = jax.value_and_grad(compute_loss)
@@ -556,8 +555,8 @@ def main():
     # Define eval fn
     def eval_step(params, batch):
        labels = batch.pop("labels")
-        logits = model(**batch, params=params, train=False)[0]
-        loss = loss_fn(logits, labels)
+        logits, latent_codes = model(**batch, params=params, train=False)[:2]
+        loss = loss_fn(logits, labels, latent_codes)
 
        # summarize metrics
        metrics = {"loss": loss}
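
The TODO in loss_fn marks where a latent regularizer would join the reconstruction loss. Below is a minimal sketch of one way to fill it in, assuming an MMD-style regularizer (as in MMD-VAE/InfoVAE) that matches the batch of latent codes to samples from a unit-Gaussian prior. This is not the repository's implementation: the gaussian_kernel and mmd_loss helpers, the reg_weight coefficient, the assumed (batch, latent_dim) shape of latent_codes, and the extra rng argument are all illustrative assumptions.

# Illustrative sketch only: fills the "# TODO add reg loss here" slot with an
# MMD regularizer. Helper names, reg_weight, and the rng argument are
# assumptions, not part of this commit.
import jax
import jax.numpy as jnp
import optax
from flax.training.common_utils import onehot


def gaussian_kernel(x, y):
    # RBF kernel between two batches of codes, each of shape (batch, latent_dim).
    dim = x.shape[-1]
    sq_dists = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dists / dim)


def mmd_loss(latent_codes, rng):
    # Maximum mean discrepancy between the latent codes and N(0, I) samples.
    prior_samples = jax.random.normal(rng, latent_codes.shape)
    return (
        gaussian_kernel(prior_samples, prior_samples).mean()
        + gaussian_kernel(latent_codes, latent_codes).mean()
        - 2 * gaussian_kernel(prior_samples, latent_codes).mean()
    )


def loss_fn(logits, labels, latent_codes, rng, reg_weight=1.0):
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]
    loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
    # Regularization term replacing the TODO; reg_weight is a hyperparameter.
    return loss.mean() + reg_weight * mmd_loss(latent_codes, rng)

If loss_fn grows an rng argument like this, the call sites in compute_loss and eval_step would need to thread a PRNG key through (for example, reusing dropout_rng during training). A KL term against a diagonal-Gaussian posterior is the usual alternative when the encoder emits means and log-variances rather than deterministic codes.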