| | |
| | import os |
| | import fuson_plm.training.config as config |
| | |
| | os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY |
| | os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES |
| |
|
| | import torch |
| | import tqdm |
| | import numpy as np |
| | import pandas as pd |
| | import logging |
| | from transformers import AutoModelForMaskedLM, AutoTokenizer |
| | from fuson_plm.utils.logging import log_update, open_logfile, print_configpy |
| | from fuson_plm.benchmarking.caid.utils import DisorderDataset, get_dataloader, check_dataloaders |
| | from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders |
| | from fuson_plm.training.train import test |
| |
|
def load_esm2_maskedlm(esm_type, device=None):
    """
    Load a pretrained ESM-2 masked-LM checkpoint and its tokenizer.

    Args:
        esm_type: checkpoint name on the HuggingFace hub, e.g. "esm2_t33_650M_UR50D".
        device: torch device to place the model on; auto-selects CUDA/CPU when None.

    Returns:
        (model, tokenizer, device) with the model moved to `device` and in eval mode.
    """
    # Silence the "weights not used" warnings emitted while loading the checkpoint.
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    # Pick a device automatically when the caller did not supply one.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

    hub_id = f"facebook/{esm_type}"
    model = AutoModelForMaskedLM.from_pretrained(hub_id)
    tokenizer = AutoTokenizer.from_pretrained(hub_id)

    # Move to the target device and freeze into inference mode.
    model = model.to(device)
    model.eval()

    return model, tokenizer, device
| |
|
| |
|
def val(model, tokenizer, val_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Run masked-LM evaluation on the validation split.

    Mirrors fuson_plm.training.train.test, but consumes the val loader and
    writes results to {checkpoint_dir}/val_results.csv.

    Args:
        model: HuggingFace masked-LM model to evaluate.
        tokenizer: matching tokenizer (supplies mask_token_id for counting masks).
        val_loader: DataLoader yielding (inputs dict, per-residue masking-probability tensor).
        mask_percentage (float): fraction of tokens to mask in each sequence.
        device: torch device used for the forward passes.
        checkpoint_dir (str): existing directory where val_results.csv is written.
    """
    model.to(device)
    model.eval()
    total_val_loss = 0
    total_weighted_val_loss = 0
    total_val_masked_tokens = 0

    with torch.no_grad():
        with tqdm.tqdm(enumerate(val_loader), total=len(val_loader), desc='Val Batch', leave=True, position=0) as tbar:
            for batch_idx, (inputs, prob) in tbar:
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)

                # Sample a fresh mask for this batch using the per-residue probabilities.
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=mask_percentage)

                outputs = model(**masked_inputs)
                val_loss = outputs.loss

                # Count masked positions so the loss can be token-weighted:
                # batches may mask different numbers of tokens.
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                total_val_loss += val_loss.item()
                total_weighted_val_loss += val_loss.item() * num_masked_tokens
                total_val_masked_tokens += num_masked_tokens

    # Guard the averages: an empty loader or zero masked tokens yields NaN
    # instead of a ZeroDivisionError.
    n_val_batches = len(val_loader)
    avg_val_loss = total_val_loss / n_val_batches if n_val_batches else float("nan")
    avg_weighted_val_loss = (
        total_weighted_val_loss / total_val_masked_tokens if total_val_masked_tokens else float("nan")
    )
    val_perplexity = np.exp(avg_weighted_val_loss)

    log_update(f"\nval results:\nTotal batches = {n_val_batches}, Total masked tokens = {total_val_masked_tokens}, Total Loss = {total_val_loss:.4f}, Avg Batch Loss = {avg_val_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_val_loss:.4f}, Perplexity = {val_perplexity:.4f}")

    # Persist the summary metrics next to the checkpoints.
    val_stats_df = pd.DataFrame(data={
        "total_val_loss": [total_val_loss], "weighted_val_loss": [total_weighted_val_loss],
        "avg_val_loss": [avg_val_loss], "avg_weighted_val_loss": [avg_weighted_val_loss],
        "val_perplexity": [val_perplexity]
    })
    val_stats_df.to_csv(f"{checkpoint_dir}/val_results.csv", index=False)
| |
|
def main():
    """Evaluate pretrained ESM-2 (650M) on the configured val and test splits."""
    model, tokenizer, device = load_esm2_maskedlm("esm2_t33_650M_UR50D")

    checkpoint_dir = f"checkpoints/esm2_t33_650M_UR50D_{config.PROBABILITY_TYPE}_mask{config.MASK_PERCENTAGE}"
    os.makedirs(checkpoint_dir, exist_ok=True)

    with open_logfile(f"{checkpoint_dir}/evaluate_val_test_esm.txt"):
        print_configpy(config)

        # Both splits share identical dataloader settings.
        loader_kwargs = dict(
            probability_type=config.PROBABILITY_TYPE,
            batch_size=config.BATCH_SIZE,
            max_length=config.MAX_LENGTH,
            shuffle=False,
        )

        # Validation split.
        val_loader = get_dataloader(config.VAL_PATH, tokenizer, **loader_kwargs)
        val(model, tokenizer, val_loader, config.MASK_PERCENTAGE, device=device, checkpoint_dir=checkpoint_dir)

        # Test split.
        test_loader = get_dataloader(config.TEST_PATH, tokenizer, **loader_kwargs)
        test(model, tokenizer, test_loader, config.MASK_PERCENTAGE, device=device, checkpoint_dir=checkpoint_dir)
| |
|
| | if __name__ == "__main__": |
| | main() |