| | import jax |
| | print(jax.local_device_count()) |
| | import jax.numpy as jnp |
| |
|
| | import flax |
| | import flax.linen as nn |
| | from flax.training.common_utils import get_metrics,onehot,shard,shard_prng_key |
| | from flax.training import train_state |
| | from flax.metrics.tensorboard import SummaryWriter |
| | from flax.training import checkpoints |
| |
|
| |
|
| | import logging |
| | import optax |
| | import math |
| | from tqdm import tqdm |
| |
|
| | from pathlib import Path |
| | from typing import Callable |
| | from itertools import chain |
| | from flax.metrics import tensorboard |
| |
|
| | from datasets import load_dataset,load_metric |
| | from transformers import GPT2Config,GPT2Tokenizer |
| |
|
| | from model_file import FlaxGPT2ForMultipleChoice |
| |
|
| | logger = logging.getLogger() |
| | logger.setLevel(logging.INFO) |
| |
|
| | def main(): |
| |
|
| |
|
| | tokenizer=GPT2Tokenizer.from_pretrained('gpt2',pad_token='<|endoftext|>') |
| |
|
| | dataset=load_dataset('cosmos_qa') |
| |
|
| | def preprocess(example): |
| | example['context&question']=example['context']+example['question'] |
| | example['first_sentence']=[example['context&question'],example['context&question'],example['context&question'],example['context&question']] |
| | example['second_sentence']=example['answer0'],example['answer1'],example['answer2'],example['answer3'] |
| | return example |
| |
|
| | train_dataset=dataset['train'].map(preprocess) |
| | validation_dataset=dataset['validation'].map(preprocess) |
| | test_dataset=dataset['test'].map(preprocess) |
| |
|
| | |
| | len_train_dataset=25262 |
| | len_validation_dataset=2985 |
| | len_test_dataset=6963 |
| |
|
| | train_dataset=train_dataset.select(range(len_train_dataset)) |
| | test_dataset=test_dataset.select(range(len_test_dataset)) |
| | validation_dataset=validation_dataset.select(range(len_validation_dataset)) |
| |
|
| | |
| |
|
| | def tokenize(examples): |
| | a=tokenizer(examples['first_sentence'],examples['second_sentence'],padding='max_length',truncation=True,max_length=256,return_tensors='jax') |
| | a['labels']=examples['label'] |
| | return a |
| |
|
| | train_dataset=train_dataset.map(tokenize) |
| | validation_dataset=validation_dataset.map(tokenize) |
| | test_dataset=test_dataset.map(tokenize) |
| |
|
| | remov_col=['id', 'context', 'question', 'answer0', 'answer1', 'answer2', 'answer3', 'labels', 'context&question', 'first_sentence', 'second_sentence'] |
| |
|
| | train_dataset=train_dataset.remove_columns(remov_col) |
| | validation_dataset=validation_dataset.remove_columns(remov_col) |
| | test_dataset=test_dataset.remove_columns(remov_col) |
| |
|
| | per_device_batch_size=4 |
| | seed=0 |
| | num_train_epochs=5 |
| | learning_rate=2e-5 |
| | |
| |
|
| | total_batch_size = per_device_batch_size * jax.local_device_count() |
| | print('The overall batch size (both for training and eval) is', total_batch_size) |
| | num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs |
| | num_validation_steps=len(validation_dataset)//total_batch_size*num_train_epochs |
| |
|
| | learning_rate_function = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps) |
| |
|
| | class TrainState(train_state.TrainState): |
| | logits_function:Callable=flax.struct.field(pytree_node=False) |
| | loss_function:Callable=flax.struct.field(pytree_node=False) |
| |
|
| | def adamw(weight_decay): |
| | return optax.adamw(learning_rate=learning_rate_function,b1=0.9,b2=0.99,eps=1e-6,weight_decay=weight_decay) |
| |
|
| | decay_path=lambda p:not any(x in p for x in ['bias','LayerNorm.weight']) |
| |
|
| | def traverse(function): |
| | def mask(data): |
| | flat=flax.traverse_util.flatten_dict(data) |
| | return flax.traverse_util.unflatten_dict({k:function(k,v) for k,v in flat.items()}) |
| | return mask |
| |
|
| | gradient_transformation=optax.chain( |
| | optax.masked(adamw(0.0),mask=traverse(lambda path,_:decay_path(path))), |
| | optax.masked(adamw(0.01),mask=traverse(lambda path,_:not decay_path(path)))) |
| |
|
| | def loss_function(logits,labels): |
| | logits=flax.linen.log_softmax(logits) |
| | xentropy=optax.softmax_cross_entropy(logits,onehot(labels,num_classes=4)) |
| | return jnp.mean(xentropy) |
| |
|
| | def eval_function(logits): |
| | return logits.argmax(-1) |
| |
|
| | model = FlaxGPT2ForMultipleChoice.from_pretrained('gpt2-medium',input_shape=(1,4,1)) |
| |
|
| | state=TrainState.create(apply_fn=model.__call__, |
| | params=model.params, |
| | tx=gradient_transformation, |
| | logits_function=eval_function, |
| | loss_function=loss_function) |
| | |
| | def train_step(state,batch,dropout_rng): |
| | targets=batch.pop("label") |
| | dropout_rng,new_dropout_rng=jax.random.split(dropout_rng) |
| | def loss_function(params): |
| | logits=state.apply_fn(**batch,params=params,dropout_rng=dropout_rng,train=True)[0] |
| | loss=state.loss_function(logits,targets) |
| | return loss |
| | grad_function=jax.value_and_grad(loss_function) |
| | loss,grad=grad_function(state.params) |
| | grad=jax.lax.pmean(grad,"batch") |
| | new_state=state.apply_gradients(grads=grad) |
| | |
| | logits=new_state.apply_fn(**batch,params=new_state.params,dropout_rng=dropout_rng,train=True)[0] |
| | accuracy=jnp.equal(jnp.argmax(logits,axis=-1),targets) |
| | metrics=jax.lax.pmean({"loss":loss,"learning_rate":learning_rate_function(state.step),'accuracy':accuracy},axis_name="batch") |
| | return new_state,metrics,new_dropout_rng |
| |
|
| | parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,)) |
| |
|
| | def eval_step(state, batch): |
| | targets=batch.pop('label') |
| | logits = state.apply_fn(**batch, params=state.params, train=False) |
| | loss=state.loss_function(logits,targets) |
| | predictions=state.logits_function(logits) |
| | eval_accuracy=jnp.equal(predictions,targets) |
| | |
| | metrics=jax.lax.pmean({"loss":loss,'accuracy':eval_accuracy},axis_name="batch") |
| | |
| | return targets,predictions,metrics |
| |
|
| | parallel_eval_step = jax.pmap(eval_step, axis_name="batch") |
| |
|
| | def glue_train_data_loader(rng,dataset,batch_size): |
| | steps_per_epoch=len_train_dataset//batch_size |
| | perms=jax.random.permutation(rng,len(dataset)) |
| | perms=perms[:steps_per_epoch*batch_size] |
| | perms=perms.reshape((steps_per_epoch,batch_size)) |
| | for perm in perms: |
| | batch=dataset[perm] |
| | batch={k:jnp.array(v) for k,v in batch.items()} |
| | batch=shard(batch) |
| | yield batch |
| |
|
| | rng=jax.random.PRNGKey(seed) |
| | dropout_rngs=jax.random.split(rng,jax.local_device_count()) |
| |
|
| | def glue_eval_data_loader(dataset, batch_size): |
| | for i in range(len_validation_dataset // batch_size): |
| | batch = dataset[i * batch_size : (i + 1) * batch_size] |
| | batch = {k: jnp.array(v) for k, v in batch.items()} |
| | batch = shard(batch) |
| |
|
| | yield batch |
| |
|
| | state = flax.jax_utils.replicate(state) |
| | |
| |
|
| | actual_task = "mnli" |
| | metric = load_metric('glue', "mnli") |
| | actual_taskmetric = load_metric('glue', actual_task) |
| |
|
| | workdir='../results_tensorboard' |
| | summary_writer = tensorboard.SummaryWriter(workdir) |
| | |
| |
|
| | logger.info(f"***** Running training *****") |
| | logger.info(f" Num examples = {len_train_dataset}") |
| | logger.info(f" Num Epochs = {num_train_epochs}") |
| | logger.info(f" Instantaneous batch size per device = {per_device_batch_size}") |
| | logger.info(f" Total train batch size = {total_batch_size}") |
| | logger.info(f" Total optimization steps = {num_train_steps}") |
| |
|
| | for i, epoch in enumerate(tqdm(range(1, num_train_epochs+1), desc=f"Epoch ...", position=0, leave=True)): |
| | rng, input_rng = jax.random.split(rng) |
| | train_acc_metrics=[] |
| | train_loss_metrics=[] |
| | eval_acc_metrics=[] |
| | eval_loss_metrics=[] |
| | |
| | with tqdm(total=len_train_dataset // total_batch_size, desc="Training...", leave=False) as progress_bar_train: |
| | for idx,batch in enumerate(glue_train_data_loader(input_rng, train_dataset, total_batch_size)): |
| | state, train_metric, dropout_rngs = parallel_train_step(state, batch, dropout_rngs) |
| | train_acc_metrics.append(jax.device_get(train_metric['accuracy']).mean().item()) |
| | train_loss_metrics.append(flax.jax_utils.unreplicate(train_metric)['loss'].item()) |
| | if idx%5==0: |
| | summary_writer.scalar('train_loss',flax.jax_utils.unreplicate(train_metric)['loss'].item(),idx) |
| | summary_writer.scalar('train_accuracy', jax.device_get(train_metric['accuracy']).mean().item(),idx) |
| | if idx%50==0: |
| | logger.info(f"train_step_loss{idx}: {flax.jax_utils.unreplicate(train_metric)['loss'].item()} train_step_acc{idx}: {jax.device_get(train_metric['accuracy']).mean().item()}") |
| | |
| | progress_bar_train.update(1) |
| |
|
| | |
| | with tqdm(total=len_validation_dataset // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval: |
| | for idx,batch in enumerate(glue_eval_data_loader(validation_dataset, total_batch_size)): |
| | labels,predictions,eval_metric=parallel_eval_step(state, batch) |
| | eval_acc_metrics.append(jax.device_get(eval_metric['accuracy']).mean().item()) |
| | eval_loss_metrics.append(flax.jax_utils.unreplicate(eval_metric)['loss'].item()) |
| | progress_bar_eval.update(1) |
| | if idx%5==0: |
| | logger.info(f"eval_step_loss {idx} : {flax.jax_utils.unreplicate(eval_metric)['loss'].item()} eval_step_acc {idx} : {jax.device_get(eval_metric['accuracy']).mean().item()}") |
| | summary_writer.scalar('eval_loss : ', flax.jax_utils.unreplicate(eval_metric)['loss'].item(),idx) |
| | summary_writer.scalar('eval_accuracy : ', jax.device_get(eval_metric['accuracy']).mean().item(),idx) |
| |
|
| | if jax.process_index() == 0: |
| | params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) |
| | |
| | model.save_pretrained( |
| | '../', |
| | params=params, |
| | push_to_hub=True, |
| | commit_message=f"Saving weights of epoch {epoch} at step {idx}",) |
| | |
| | logger.info(f"---------------------Epoch {epoch} done-----------------") |
| | logger.info(f"Train loss: {jax.device_get(jnp.array(train_loss_metrics)).mean().item()} Train accuracy: {jax.device_get(jnp.array(train_acc_metrics)).mean().item()}") |
| | logger.info(f"Eval loss: {jax.device_get(jnp.array(eval_loss_metrics)).mean().item()} Eval accuracy: {jax.device_get(jnp.array(eval_acc_metrics)).mean().item()}") |
| | summary_writer.flush() |
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|