'''
Training script for finetuning ESM: freeze the parameters in the MLM head and
unfreeze the last N encoder layers of the model.
'''
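# Fields the fuson_plm.training.config module is expected to define
# (inferred from the config attributes referenced in this script):
#   WANDB_API_KEY, WANDB_PROJECT, WANDB_ENTITY, CUDA_VISIBLE_DEVICES,
#   FINETUNE_FROM_SCRATCH, PATH_TO_STARTING_CKPT,
#   TRAIN_PATH, VAL_PATH, TEST_PATH, MAX_LENGTH, BATCH_SIZE, LEARNING_RATE, EPOCHS,
#   N_UNFROZEN_LAYERS, UNFREEZE_QUERY, UNFREEZE_KEY, UNFREEZE_VALUE,
#   MASK_PERCENTAGE, VAR_MASK_RATE, MASK_SCHEDULER, MASK_LOW, MASK_HIGH, MASK_STEPS,
#   PROBABILITY_TYPE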

import os
import fuson_plm.training.config as config

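# Set environment variables from the config before torch and wandb are imported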
os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES

import torch
import numpy as np
import pandas as pd
import tqdm
from datetime import datetime
import wandb
import pytz
import sys

# Note: AdamW from transformers is deprecated in newer releases; torch.optim.AdamW can be used instead
from transformers import AdamW

from fuson_plm.utils.logging import print_configpy, get_local_time, open_logfile, open_errfile, log_update
from fuson_plm.training.model import FusOnpLM
from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders, get_mask_rate_scheduler
from fuson_plm.training.plot import make_train_val_test_bd_plot

def prepare_model(model, n_unfrozen_layers, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
    """Freeze the entire model, then unfreeze the last n_unfrozen_layers encoder layers."""
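    # Log the initial layer and parameter counts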
    n_layers = model.count_encoder_layers()
    total_params = sum(p.numel() for p in model.parameters())
    total_head_params = sum(p.numel() for p in model.lm_head.parameters())
    log_update(f'\nInitial state:\n\tTotal number of layers in the model: {n_layers}')
    log_update(f'\tTotal parameters in the AutoModelForMaskedLM model: {total_params}')
    log_update(f'\tTotal parameters in the MLM Head ONLY: {total_head_params}')
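    # Freeze every parameter in the model, including the MLM head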
    model.freeze_model()
    n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log_update(f'Froze all {n_layers} model layers')
    log_update(f'\tTrainable params: {n_trainable_params}')
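    # Unfreeze the last N encoder layers (optionally restricted to the query/key/value projections)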
    model.unfreeze_last_n_layers(n_unfrozen_layers, unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)
    n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    trainable_params = '\n\t\t'.join([name for name, param in model.named_parameters() if param.requires_grad])
    num_trainable_params_lm_head = sum(p.numel() for p in model.lm_head.parameters() if p.requires_grad)
    num_trainable_params_esm = sum(p.numel() for p in model.esm.parameters() if p.requires_grad)
    log_update(f'Unfroze final {n_unfrozen_layers} layers')
    log_update(f'\tTrainable params: {n_trainable_params}\n\t\t{trainable_params}')
    log_update(f"\tTrainable parameters in the lm_head: {num_trainable_params_lm_head}")
    log_update(f"\tTrainable params in the ESM part: {num_trainable_params_esm}")

def train(model, tokenizer, optimizer, train_loader, val_loader, n_epochs=10, start_epoch=1, mask_percentage=0.15, mask_rate_scheduler=None, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Train the model with masked-language modeling. After every epoch, save a checkpoint,
    evaluate on the validation set, and log statistics to wandb and per-epoch CSV curves.
    """
    log_update("\n")

    for epoch in range(start_epoch, start_epoch + n_epochs):
        if mask_rate_scheduler is not None:
            mask_rate_scheduler.reset()

        model.train()
        total_train_loss = 0
        total_weighted_train_loss = 0
        total_train_masked_tokens = 0

        log_update(f"Epoch {epoch}")

        with tqdm.tqdm(enumerate(train_loader), total=len(train_loader), desc='Training Batch', leave=True, position=0) as pbar:
            for batch_idx, (inputs, prob) in pbar:
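                # Set the masking rate for this batch: fixed, or stepped by the variable-rate scheduler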
                masking_rate = mask_percentage
                if mask_rate_scheduler is not None:
                    mask_rate_scheduler.step()
                    masking_rate = mask_rate_scheduler.get_masking_rate()
                    log_update(f"\tBatch index: {batch_idx}\tMasking rate: {masking_rate:.5f}")

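                # Move inputs and masking probabilities to the device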
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)

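                # Mask tokens according to the per-residue masking probabilities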
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=masking_rate)

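                # Forward pass, backpropagation, and optimizer step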
                optimizer.zero_grad()
                outputs = model(**masked_inputs)
                loss = outputs.loss
                loss.backward()
                optimizer.step()

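                # Count the masked tokens so losses can be token-weighted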
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

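                # Accumulate loss statistics and log the batch loss to wandb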
                total_train_loss += loss.item()
                total_weighted_train_loss += loss.item() * num_masked_tokens
                total_train_masked_tokens += num_masked_tokens
                wandb.log({"batch_loss": loss.item()})

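        # Save a checkpoint of the model and optimizer at the end of the epoch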
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}')
        model.save_model(checkpoint_path, optimizer=optimizer)
        log_update(f'\nSaved checkpoint to {checkpoint_path}')

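        # Compute epoch-level training statistics: batch-averaged loss, masked-token-weighted loss, and perplexity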
        n_train_batches = len(train_loader)
        avg_train_loss = total_train_loss / n_train_batches
        avg_weighted_train_loss = total_weighted_train_loss / total_train_masked_tokens
        train_perplexity = np.exp(avg_weighted_train_loss)
        wandb.log({"epoch": epoch,
                   "total_train_loss": total_train_loss, "weighted_train_loss": total_weighted_train_loss,
                   "avg_train_loss": avg_train_loss, "avg_weighted_train_loss": avg_weighted_train_loss,
                   "train_perplexity": train_perplexity})

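        # Append this epoch's training stats to the training curve CSV (write the header only once)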
        train_stats_df = pd.DataFrame(data={
            "epoch": [epoch],
            "total_train_loss": [total_train_loss], "weighted_train_loss": [total_weighted_train_loss],
            "avg_train_loss": [avg_train_loss], "avg_weighted_train_loss": [avg_weighted_train_loss],
            "train_perplexity": [train_perplexity]
        })
        if os.path.exists(f"{checkpoint_dir}/train_curve.csv"):
            train_stats_df.to_csv(f"{checkpoint_dir}/train_curve.csv", index=False, header=False, mode='a')
        else:
            train_stats_df.to_csv(f"{checkpoint_dir}/train_curve.csv", index=False)

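        # Evaluate on the validation set after each training epoch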
        model.eval()
        total_val_loss = 0
        total_weighted_val_loss = 0
        total_val_masked_tokens = 0

        with torch.no_grad():
            with tqdm.tqdm(enumerate(val_loader), total=len(val_loader), desc='Validation Batch', leave=True, position=0) as vbar:
                for batch_idx, (inputs, prob) in vbar:
                    inputs = {k: v.to(device) for k, v in inputs.items()}
                    prob = prob.to(device)

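                    # Validation always masks at a fixed rate of 0.15, independent of the training masking rate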
                    masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=0.15)

                    outputs = model(**masked_inputs)
                    val_loss = outputs.loss

                    num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                    total_val_loss += val_loss.item()
                    total_weighted_val_loss += val_loss.item() * num_masked_tokens
                    total_val_masked_tokens += num_masked_tokens

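        # Compute epoch-level validation statistics and log them to wandb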
        n_val_batches = len(val_loader)
        avg_val_loss = total_val_loss / n_val_batches
        avg_weighted_val_loss = total_weighted_val_loss / total_val_masked_tokens
        val_perplexity = np.exp(avg_weighted_val_loss)
        wandb.log({"epoch": epoch,
                   "total_val_loss": total_val_loss, "weighted_val_loss": total_weighted_val_loss,
                   "avg_val_loss": avg_val_loss, "avg_weighted_val_loss": avg_weighted_val_loss,
                   "val_perplexity": val_perplexity})

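        # Append this epoch's validation stats to the validation curve CSV, then log an epoch summary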
        val_stats_df = pd.DataFrame(data={
            "epoch": [epoch],
            "total_val_loss": [total_val_loss], "weighted_val_loss": [total_weighted_val_loss],
            "avg_val_loss": [avg_val_loss], "avg_weighted_val_loss": [avg_weighted_val_loss],
            "val_perplexity": [val_perplexity]
        })
        if os.path.exists(f"{checkpoint_dir}/val_curve.csv"):
            val_stats_df.to_csv(f"{checkpoint_dir}/val_curve.csv", index=False, header=False, mode='a')
        else:
            val_stats_df.to_csv(f"{checkpoint_dir}/val_curve.csv", index=False)

        log_update(f"Epoch: {epoch}")
        log_update(f"\tTrain set: Total batches = {n_train_batches}, Total masked tokens = {total_train_masked_tokens}, Total Loss = {total_train_loss:.4f}, Avg Batch Loss = {avg_train_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_train_loss:.4f}, Perplexity = {train_perplexity:.4f}")
        log_update(f"\tValidation set: Total batches = {n_val_batches}, Total masked tokens = {total_val_masked_tokens}, Total Loss = {total_val_loss:.4f}, Avg Batch Loss = {avg_val_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_val_loss:.4f}, Perplexity = {val_perplexity:.4f}")

def test(model, tokenizer, test_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Evaluate the trained model on the test set and write the results to test_results.csv.
    """
    model.to(device)
    model.eval()
    total_test_loss = 0
    total_weighted_test_loss = 0
    total_test_masked_tokens = 0

    with torch.no_grad():
        with tqdm.tqdm(enumerate(test_loader), total=len(test_loader), desc='Test Batch', leave=True, position=0) as tbar:
            for batch_idx, (inputs, prob) in tbar:
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)

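                # Mask test tokens at the requested mask_percentage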
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=mask_percentage)

                outputs = model(**masked_inputs)
                test_loss = outputs.loss

                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                total_test_loss += test_loss.item()
                total_weighted_test_loss += test_loss.item() * num_masked_tokens
                total_test_masked_tokens += num_masked_tokens

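    # Compute test-set statistics: batch-averaged loss, masked-token-weighted loss, and perplexity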
    n_test_batches = len(test_loader)
    avg_test_loss = total_test_loss / n_test_batches
    avg_weighted_test_loss = total_weighted_test_loss / total_test_masked_tokens
    test_perplexity = np.exp(avg_weighted_test_loss)

    log_update(f"\nTest results:\nTotal batches = {n_test_batches}, Total masked tokens = {total_test_masked_tokens}, Total Loss = {total_test_loss:.4f}, Avg Batch Loss = {avg_test_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_test_loss:.4f}, Perplexity = {test_perplexity:.4f}")

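    # Write the final test metrics to CSV in the checkpoint directory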
    test_stats_df = pd.DataFrame(data={
        "total_test_loss": [total_test_loss], "weighted_test_loss": [total_weighted_test_loss],
        "avg_test_loss": [avg_test_loss], "avg_weighted_test_loss": [avg_weighted_test_loss],
        "test_perplexity": [test_perplexity]
    })
    test_stats_df.to_csv(f"{checkpoint_dir}/test_results.csv", index=False)

def check_env_variables():
    log_update("\nChecking on environment variables...")
    # Report whether the wandb API key is set, without echoing the secret into the log
    log_update(f"\tWANDB_API_KEY: {'set' if os.environ.get('WANDB_API_KEY') else 'not set'}")
    log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
    log_update(f"\ttorch.cuda.device_count(): {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}")

def initialize_model_and_optimizer(finetune_from_scratch, device, path_to_starting_ckpt=None, learning_rate=1e-4, n_unfrozen_layers=0, unfreeze_query=False, unfreeze_key=False, unfreeze_value=False):
    """
    Initializes the model, either from ESM-2-650M if finetuning from scratch, or from a prior
    checkpoint if not finetuning from scratch. Also prepares the model for training by freezing
    all parameters and unfreezing the last n_unfrozen_layers, then builds an AdamW optimizer
    over the trainable parameters.

    Args:
        finetune_from_scratch (bool): True if finetuning from scratch. False if finetuning from a previous ckpt
        path_to_starting_ckpt (str): path to starting ckpt for finetuning (required when finetune_from_scratch is False)
    """
    if not finetune_from_scratch and (path_to_starting_ckpt is None or not os.path.exists(path_to_starting_ckpt)):
        raise Exception(f"Error: could not find {path_to_starting_ckpt}. When finetuning from a prior checkpoint, you must provide a valid path to that checkpoint.")

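    # Option 1: finetune from scratch, starting from the base model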
    if finetune_from_scratch:
        log_update(f"\nInitializing FusOn-pLM model to be finetuned from scratch")
        model = FusOnpLM()
        model.to(device)
        prepare_model(model, n_unfrozen_layers,
                      unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)

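        # Optimize only the unfrozen (trainable) parameters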
        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)

        return model, optimizer

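    # Option 2: resume finetuning from a prior checkpoint and restore the optimizer state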
    else:
        log_update(f"\nInitializing FusOn-pLM model to be finetuned from previous checkpoint: {path_to_starting_ckpt}")
        model = FusOnpLM(ckpt_path=path_to_starting_ckpt, mlm_head=True)
        model.to(device)
        prepare_model(model, n_unfrozen_layers,
                      unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)

        # The learning rate is restored from the checkpoint's optimizer state
        log_update("Loading optimizer state_dict from previous checkpoint")
        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()))
        optimizer.load_state_dict(torch.load(os.path.join(path_to_starting_ckpt, "optimizer.pt"), map_location=device))

        return model, optimizer

def main():
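    # This script always trains with uniform masking probabilities, overriding the config value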
    config.PROBABILITY_TYPE = "uniform"

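    # Build a run name that encodes the key training settings plus a timestamp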
    kqv_tag = f"{'Q' if config.UNFREEZE_QUERY else ''}{'K' if config.UNFREEZE_KEY else ''}{'V' if config.UNFREEZE_VALUE else ''}"
    timestamp = get_local_time()

    mask_tag = f"mask{config.MASK_PERCENTAGE}"
    if config.VAR_MASK_RATE:
        mask_tag = f"maskvar_{config.MASK_SCHEDULER}_low{config.MASK_LOW}_high{config.MASK_HIGH}"

    TRAIN_SETTINGS_STRING = f"{config.PROBABILITY_TYPE}_{config.MAX_LENGTH}_ft_{config.N_UNFROZEN_LAYERS}layers_{kqv_tag}_b{config.BATCH_SIZE}_lr{config.LEARNING_RATE}_{mask_tag}"
    WANDB_NAME = f'{TRAIN_SETTINGS_STRING}-{timestamp}'

    checkpoint_dir = f'checkpoints/{WANDB_NAME}'
    start_epoch = 1

    logmode = 'w'

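    # When resuming, append to the original run's log and checkpoint directory, and pick up at the next epoch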
    if not config.FINETUNE_FROM_SCRATCH:
        logmode = 'a'
        path_to_starting_ckpt = config.PATH_TO_STARTING_CKPT
        checkpoint_dir = path_to_starting_ckpt[0:path_to_starting_ckpt.rindex('/')]
        START_MODEL_TRAIN_SETTINGS_STRING = checkpoint_dir[checkpoint_dir.index('checkpoints/') + len('checkpoints/'):checkpoint_dir.index('-')]
        start_epoch = int(path_to_starting_ckpt.split('/checkpoint_epoch_')[1]) + 1

    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    LOG_PATH = f'{checkpoint_dir}/training_log.txt'
    ERR_PATH = f'{checkpoint_dir}/training_errors.txt'
    with open_logfile(LOG_PATH, mode=logmode), open_errfile(ERR_PATH, mode=logmode):
        if not config.FINETUNE_FROM_SCRATCH:
            log_update(f"\n{'-'*200}\nResuming finetuning from checkpoint {start_epoch-1} (first new checkpoint: {start_epoch})\n")
            log_update(f"Settings tag for original model (starting point for finetuning) = {START_MODEL_TRAIN_SETTINGS_STRING}\nSettings tag for new model based on configs = {TRAIN_SETTINGS_STRING}\nSame: {START_MODEL_TRAIN_SETTINGS_STRING==TRAIN_SETTINGS_STRING}\n")

            # Refuse to resume if the current configs do not match the original run's settings
            assert START_MODEL_TRAIN_SETTINGS_STRING == TRAIN_SETTINGS_STRING, \
                "Config settings do not match the settings of the checkpoint being resumed"

        print_configpy(config)

        check_env_variables()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log_update(f"\nUsing device: {device}")

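        # Initialize the wandb run for experiment tracking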
        wandb.init(project=config.WANDB_PROJECT, entity=config.WANDB_ENTITY, name=WANDB_NAME, config={
            "batch_size": config.BATCH_SIZE,
            "epochs": config.EPOCHS,
            "learning_rate": config.LEARNING_RATE,
        })

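        # Build the model and optimizer, either fresh or restored from the starting checkpoint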
        model, optimizer = initialize_model_and_optimizer(config.FINETUNE_FROM_SCRATCH, device,
                                                          path_to_starting_ckpt=config.PATH_TO_STARTING_CKPT,
                                                          learning_rate=config.LEARNING_RATE,
                                                          n_unfrozen_layers=config.N_UNFROZEN_LAYERS,
                                                          unfreeze_query=config.UNFREEZE_QUERY,
                                                          unfreeze_key=config.UNFREEZE_KEY,
                                                          unfreeze_value=config.UNFREEZE_VALUE)

        tokenizer = model.tokenizer

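        # Build train, validation, and test dataloaders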
        train_loader = get_dataloader(config.TRAIN_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=True)
        val_loader = get_dataloader(config.VAL_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)
        test_loader = get_dataloader(config.TEST_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)

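        # Sanity-check the dataloaders before training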
        check_dataloaders(train_loader, val_loader, test_loader, max_length=config.MAX_LENGTH, checkpoint_dir=checkpoint_dir)

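        # Optionally build a variable masking-rate scheduler over the training batches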
        mask_rate_scheduler = None
        if config.VAR_MASK_RATE:
            mask_rate_scheduler = get_mask_rate_scheduler(scheduler_type=config.MASK_SCHEDULER,
                                                          min_masking_rate=config.MASK_LOW,
                                                          max_masking_rate=config.MASK_HIGH,
                                                          total_batches=len(train_loader),
                                                          total_steps=config.MASK_STEPS)

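        # Train with per-epoch validation and checkpointing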
        train(model, tokenizer, optimizer, train_loader, val_loader,
              n_epochs=config.EPOCHS,
              start_epoch=start_epoch,
              device=device,
              mask_rate_scheduler=mask_rate_scheduler,
              mask_percentage=config.MASK_PERCENTAGE,
              checkpoint_dir=checkpoint_dir)

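        # Evaluate the final model on the test set at a fixed 15% masking rate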
        test(model, tokenizer, test_loader, mask_percentage=0.15, device=device, checkpoint_dir=checkpoint_dir)

if __name__ == "__main__":
    main()