Fraser committed on
Commit
3497606
·
1 Parent(s): 1df308b

use latest

Browse files
Files changed (2) hide show
  1. run_clm_flax.py +18 -13
  2. train.py +18 -10
run_clm_flax.py CHANGED
@@ -1,6 +1,3 @@
1
- '''
2
- Using this to compare & make latest changes to train.py
3
- '''
4
  #!/usr/bin/env python
5
  # coding=utf-8
6
  # Copyright 2021 The HuggingFace Team All rights reserved.
@@ -401,7 +398,8 @@ def main():
401
  total_length = len(concatenated_examples[list(examples.keys())[0]])
402
  # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
403
  # customize this part to your needs.
404
- total_length = (total_length // block_size) * block_size
 
405
  # Split by chunks of max_len.
406
  result = {
407
  k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
@@ -492,17 +490,24 @@ def main():
492
  return traverse_util.unflatten_dict(flat_mask)
493
 
494
  # create adam optimizer
495
- adamw = optax.adamw(
496
- learning_rate=linear_decay_lr_schedule_fn,
497
- b1=training_args.adam_beta1,
498
- b2=training_args.adam_beta2,
499
- eps=training_args.adam_epsilon,
500
- weight_decay=training_args.weight_decay,
501
- mask=decay_mask_fn,
502
- )
 
 
 
 
 
 
 
503
 
504
  # Setup train state
505
- state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
506
 
507
  def loss_fn(logits, labels):
508
  shift_logits = logits[..., :-1, :]
 
 
 
 
1
  #!/usr/bin/env python
2
  # coding=utf-8
3
  # Copyright 2021 The HuggingFace Team All rights reserved.
 
398
  total_length = len(concatenated_examples[list(examples.keys())[0]])
399
  # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
400
  # customize this part to your needs.
401
+ if total_length >= block_size:
402
+ total_length = (total_length // block_size) * block_size
403
  # Split by chunks of max_len.
404
  result = {
405
  k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
 
490
  return traverse_util.unflatten_dict(flat_mask)
491
 
492
  # create adam optimizer
493
+ if training_args.adafactor:
494
+ # We use the default parameters here to initialize adafactor,
495
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
496
+ optimizer = optax.adafactor(
497
+ learning_rate=linear_decay_lr_schedule_fn,
498
+ )
499
+ else:
500
+ optimizer = optax.adamw(
501
+ learning_rate=linear_decay_lr_schedule_fn,
502
+ b1=training_args.adam_beta1,
503
+ b2=training_args.adam_beta2,
504
+ eps=training_args.adam_epsilon,
505
+ weight_decay=training_args.weight_decay,
506
+ mask=decay_mask_fn,
507
+ )
508
 
509
  # Setup train state
510
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
511
 
512
  def loss_fn(logits, labels):
513
  shift_logits = logits[..., :-1, :]
train.py CHANGED
@@ -401,7 +401,8 @@ def main():
401
  total_length = len(concatenated_examples[list(examples.keys())[0]])
402
  # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
403
  # customize this part to your needs.
404
- total_length = (total_length // block_size) * block_size
 
405
  # Split by chunks of max_len.
406
  result = {
407
  k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
@@ -492,17 +493,24 @@ def main():
492
  return traverse_util.unflatten_dict(flat_mask)
493
 
494
  # create adam optimizer
495
- adamw = optax.adamw(
496
- learning_rate=linear_decay_lr_schedule_fn,
497
- b1=training_args.adam_beta1,
498
- b2=training_args.adam_beta2,
499
- eps=training_args.adam_epsilon,
500
- weight_decay=training_args.weight_decay,
501
- mask=decay_mask_fn,
502
- )
 
 
 
 
 
 
 
503
 
504
  # Setup train state
505
- state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
506
 
507
  def compute_kernel(x, y):
508
  x_size = x.shape[0]
 
401
  total_length = len(concatenated_examples[list(examples.keys())[0]])
402
  # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
403
  # customize this part to your needs.
404
+ if total_length >= block_size:
405
+ total_length = (total_length // block_size) * block_size
406
  # Split by chunks of max_len.
407
  result = {
408
  k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
 
493
  return traverse_util.unflatten_dict(flat_mask)
494
 
495
  # create adam optimizer
496
+ if training_args.adafactor:
497
+ # We use the default parameters here to initialize adafactor,
498
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
499
+ optimizer = optax.adafactor(
500
+ learning_rate=linear_decay_lr_schedule_fn,
501
+ )
502
+ else:
503
+ optimizer = optax.adamw(
504
+ learning_rate=linear_decay_lr_schedule_fn,
505
+ b1=training_args.adam_beta1,
506
+ b2=training_args.adam_beta2,
507
+ eps=training_args.adam_epsilon,
508
+ weight_decay=training_args.weight_decay,
509
+ mask=decay_mask_fn,
510
+ )
511
 
512
  # Setup train state
513
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
514
 
515
  def compute_kernel(x, y):
516
  x_size = x.shape[0]