| |
| import fuson_plm.benchmarking.caid.config as config |
| import os |
| os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES |
|
|
| |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, precision_recall_curve, average_precision_score |
|
|
| from sklearn.model_selection import ParameterGrid |
| from tqdm import tqdm |
| import pandas as pd |
| import numpy as np |
| import sys |
| from datetime import datetime |
| import logging |
|
|
| from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark |
| from fuson_plm.benchmarking.caid.model import DisorderPredictor |
| from fuson_plm.benchmarking.caid.utils import DisorderDataset, get_dataloader, check_dataloaders |
| from fuson_plm.benchmarking.caid.plot import make_auroc_curve, make_benchmark_auroc_curve |
| from fuson_plm.utils.logging import get_local_time, open_logfile, log_update, print_configpy |
|
|
| |
| logging.getLogger("transformers").setLevel(logging.ERROR) |
|
|
def check_env_variables():
    """
    Logs the CUDA-related environment: the CUDA_VISIBLE_DEVICES variable, the number
    of devices torch can see, and the name of each of those devices.
    """
    visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    n_devices = torch.cuda.device_count()

    log_update("\nChecking on environment variables...")
    log_update(f"\tCUDA_VISIBLE_DEVICES: {visible_devices}")
    log_update(f"\ttorch.cuda.device_count(): {n_devices}")

    device_index = 0
    while device_index < n_devices:
        device_name = torch.cuda.get_device_name(device_index)
        log_update(f"\t\tDevice {device_index}: {device_name}")
        device_index += 1
| |
def check_splits(df):
    """
    Validates the split assignments of a benchmarking dataframe.
    Args:
        df (pd.DataFrame): dataframe with a 'split' column and a 'Sequence' column
    Raises:
        Exception: if any row has no split, if the set of split values is not exactly
            {'train', 'test'}, or if any sequence appears more than once
    """
    # Every sequence must be assigned to a split
    if df['split'].isna().any():
        raise Exception("Error: not every benchmarking sequence has been allocated to a split (train or test)")

    # Compare against the expected set in BOTH directions: a one-sided difference
    # ({'train','test'} - values) only catches a missing split and lets unexpected
    # values such as 'val' through, contradicting the error message.
    if set(df['split'].unique()) != {'train', 'test'}:
        raise Exception("Error: splits column should only have 'train' and 'test'.")

    # Each sequence may appear only once
    if df['Sequence'].duplicated().any():
        raise Exception("Error: duplicate sequences provided")
|
|
| |
def train(model, train_loader, optimizer, n_epochs, criterion, device):
    """
    Trains the model for n_epochs passes over the training data.
    Args:
        model (nn.Module): model that will be trained
        train_loader (DataLoader): PyTorch DataLoader yielding (sequence, embeddings, labels) batches
        optimizer (torch.optim): optimizer
        n_epochs (int): number of passes over train_loader
        criterion (nn.Module): loss function
        device (torch.device): device (GPU or CPU) to train the model on
    Returns:
        avg_train_losses (list): mean per-batch training loss of each epoch, in epoch order
            (nan for an epoch when train_loader yields no batches)
    """
    model.train()

    avg_train_losses = []

    total_steps = len(train_loader)
    # Refresh the progress bar about 20 times per epoch. max() keeps the interval
    # at >= 1 for short loaders and avoids dividing by zero for an empty one.
    update_interval = max(1, total_steps // 20)

    for epoch in range(1, 1 + n_epochs):
        log_update(f"EPOCH {epoch}/{n_epochs}")

        total_train_loss = 0
        steps_reported = 0

        # The context manager closes the bar even if a batch raises
        with tqdm(total=total_steps, leave=True, file=sys.stdout) as prog_bar:
            for batch_idx, (_, embeddings, labels) in enumerate(train_loader, start=1):
                embeddings, labels = embeddings.to(device), labels.to(device)

                # Forward pass
                optimizer.zero_grad()
                outputs = model(embeddings)

                # Backward pass and parameter update
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                total_train_loss += loss.item()

                if batch_idx % update_interval == 0 or batch_idx == total_steps:
                    # Advance by the number of batches actually processed since the
                    # last refresh, so the final update cannot overshoot total_steps.
                    prog_bar.update(batch_idx - steps_reported)
                    steps_reported = batch_idx
                    sys.stdout.flush()

        avg_train_loss = total_train_loss / total_steps if total_steps else float('nan')
        avg_train_losses.append(avg_train_loss)

    return avg_train_losses
|
|
|
|
| |
def evaluate(model, test_loader, device):
    """
    Performs inference on a trained model, one sequence per batch.
    Args:
        model (nn.Module): the trained model
        test_loader (DataLoader): PyTorch DataLoader yielding (sequences, embeddings, labels);
            every batch must contain exactly one sequence
        device (torch.device): device (GPU or CPU) to be used for inference
    Returns:
        test_sequences (list): the sequence of each batch
        test_preds (list): model output of each batch, as a numpy array
        true_labels (list): labels of each batch, as a numpy array
    """
    model.eval()
    test_sequences, test_preds, true_labels = [], [], []

    total_steps = len(test_loader)
    # Refresh the progress bar about 20 times. max() keeps the interval at >= 1
    # for short loaders and avoids dividing by zero for an empty one.
    update_interval = max(1, total_steps // 20)
    steps_reported = 0

    # The tqdm context manager closes the bar even if a batch raises
    with torch.no_grad(), tqdm(total=total_steps, leave=True, file=sys.stdout) as prog_bar:
        for batch_idx, (sequences, embeddings, labels) in enumerate(test_loader, start=1):
            embeddings, labels = embeddings.to(device), labels.to(device)

            outputs = model(embeddings)

            # Results are stored per sequence, so batches must hold a single sequence
            assert len(sequences)==1
            test_sequences.append(sequences[0])
            test_preds.append(outputs.cpu().numpy())
            true_labels.append(labels.cpu().numpy())

            if batch_idx % update_interval == 0 or batch_idx == total_steps:
                # Advance by the number of batches actually processed since the
                # last refresh, so the final update cannot overshoot total_steps.
                prog_bar.update(batch_idx - steps_reported)
                steps_reported = batch_idx
                sys.stdout.flush()

    return test_sequences, test_preds, true_labels
|
|
| |
def benchmark(model, bench_loader, device):
    """
    Performs inference on a trained model over the benchmarking data.
    Args:
        model (nn.Module): the trained model
        bench_loader (DataLoader): PyTorch DataLoader with benchmarking data;
            every batch must contain exactly one sequence
        device (torch.device): device (GPU or CPU) to be used for inference
    Returns:
        bench_sequences (list): the sequence of each batch
        bench_preds (list): model output of each batch, as a numpy array
        true_labels (list): labels of each batch, as a numpy array
    """
    # The inference loop is the same as for the test set, so share the single
    # implementation in evaluate() rather than keeping a line-for-line copy.
    return evaluate(model, bench_loader, device)
|
|
def grid_search_caid_predictor(embedding_path, details, output_dir, param_grid, overwrite_saved_model=True):
    """
    Trains and evaluates one disorder predictor for every hyperparameter combination in param_grid.
    Args:
        embedding_path: path to the embeddings, passed through to train_and_evaluate_caid_predictor
        details (dict): description of the embedding model, passed through to train_and_evaluate_caid_predictor
        output_dir (str): directory where combined results are written
        param_grid (dict or list of dicts): grid specification accepted by sklearn's ParameterGrid
        overwrite_saved_model (bool): passed through to train_and_evaluate_caid_predictor
    """
    # Hyperparameters the training routine reads; None unless the grid sets them
    default_hyperparams = {
        "learning_rate": None,
        "num_epochs": None,
        "num_layers": None,
        "num_heads": None,
        "dropout": None
    }

    for params in ParameterGrid(param_grid):
        # Build a fresh dict for every grid point. Updating one shared dict would
        # let a value set by an earlier point leak into a later point that omits
        # that key (possible when param_grid is a list of sub-grids).
        training_hyperparams = {**default_hyperparams, **params}
        log_update(f"\nHyperparams:{training_hyperparams}")
        train_and_evaluate_caid_predictor(embedding_path, details, output_dir, training_hyperparams, overwrite_saved_model=overwrite_saved_model)
| |
| |
def find_best_hyperparams(output_dir, param_grid):
    """
    Picks the best hyperparameter setting for each benchmarked embedding model and
    writes the result to {output_dir}/best_caid_model_results.csv.
    Args:
        output_dir (str): directory holding the three caid_hyperparam_screen_*.csv files
        param_grid (dict): hyperparameter grid; its keys name the caid_model_* columns
    The best row per (Model Type, Model Name, Model Epoch) is the one ranked first by
    test AUROC, then F1 Score, Accuracy, Precision, Recall (all descending). It is
    joined with the train loss of the last epoch and with the fusion benchmark metrics.
    """
    model_id_cols = ['Model Type', 'Model Name', 'Model Epoch']
    stat_priority = ['AUROC', 'F1 Score', 'Accuracy', 'Precision', 'Recall']
    hyperparam_cols = [f"caid_model_{name}" for name in param_grid.keys()]

    def read_screen_csv(filename):
        # Blank out missing epochs so rows without an epoch survive the groupby below
        table = pd.read_csv(f'{output_dir}/{filename}')
        table['Model Epoch'] = table['Model Epoch'].fillna('')
        return table

    test_metrics = read_screen_csv('caid_hyperparam_screen_test_metrics.csv')
    train_losses = read_screen_csv('caid_hyperparam_screen_train_losses.csv')
    bench_metrics = read_screen_csv('caid_hyperparam_screen_fusion_benchmark_metrics.csv')

    # Best test row per benchmarked model: sort stats descending, keep the first row
    best_test = (
        test_metrics
        .sort_values(
            model_id_cols + stat_priority,
            ascending=[True] * len(model_id_cols) + [False] * len(stat_priority),
        )
        .groupby(model_id_cols)
        .head(1)
        .reset_index(drop=True)
    )

    # Loss of the last epoch for every (model, hyperparameter setting)
    per_config_cols = model_id_cols + hyperparam_cols
    final_epoch_losses = (
        train_losses
        .sort_values(
            by=per_config_cols + ["caid_model_epoch"],
            ascending=[True] * len(per_config_cols) + [False],
        )
        .groupby(per_config_cols)
        .head(1)
        .reset_index(drop=True)
    )

    # Prefix benchmark stats so they do not collide with the test stats
    fusion_metrics = bench_metrics.rename(
        columns={stat: f"Fusion {stat}" for stat in stat_priority}
    )

    join_cols = per_config_cols + ['path_to_model']
    combined_results = best_test.merge(final_epoch_losses, on=join_cols, how='left')
    combined_results = combined_results.merge(fusion_metrics, on=join_cols, how='left')

    output_columns = [
        'Model Type', 'Model Name', 'Model Epoch',
        'Accuracy', 'Precision', 'Recall', 'F1 Score', 'AUROC',
        'Fusion Accuracy', 'Fusion Precision', 'Fusion Recall', 'Fusion F1 Score', 'Fusion AUROC',
        'caid_model_learning_rate', 'caid_model_num_epochs', 'caid_model_num_layers',
        'caid_model_num_heads', 'caid_model_dropout', 'caid_model_epoch', 'caid_model_loss',
        'path_to_model',
    ]
    combined_results = combined_results[output_columns]
    combined_results.to_csv(f"{output_dir}/best_caid_model_results.csv", index=False)
|
|
def get_fresh_model(training_hyperparams, device):
    """
    Builds a new DisorderPredictor and moves it to the requested device.
    Args:
        training_hyperparams (dict): must provide 'num_layers', 'num_heads' and 'dropout'
        device (torch.device): device the model is moved to
    Returns:
        model (DisorderPredictor): newly constructed model
    """
    # Input and hidden sizes are both fixed at 1280
    embedding_dim = 1280
    architecture = {
        "input_dim": embedding_dim,
        "hidden_dim": embedding_dim,
        "num_layers": training_hyperparams["num_layers"],
        "num_heads": training_hyperparams["num_heads"],
        "dropout": training_hyperparams['dropout'],
    }

    model = DisorderPredictor(**architecture)
    model.to(device)
    return model
|
|
def predict_from_best_thresh(prob_and_label_df, seq_label_dict=None):
    """
    Finds the best prediction threshold for disorder by maximizing F1 Score over all
    residues in the dataset, then uses it to make per-residue predictions.
    Args:
        prob_and_label_df: DataFrame with columns: sequence, prob_1 (comma-separated
            per-residue probabilities as a string). Modified in place.
        seq_label_dict: dictionary of sequences to true labels. e.g. 'MKLP': '1100'.
            Must cover every sequence in the dataframe.
    Returns:
        prob_and_label_df: the same dataframe with added columns: labels, threshold, pred_labels
    """
    # Attach the true label string to every sequence; all sequences must be covered
    prob_and_label_df['labels'] = prob_and_label_df['sequence'].map(seq_label_dict)
    assert prob_and_label_df['labels'].notna().all()

    # Pool the per-residue probabilities and labels of the whole dataset
    probs = [float(p) for p in ','.join(prob_and_label_df['prob_1'].tolist()).split(",")]
    true_labels = [int(c) for c in ''.join(prob_and_label_df['labels'].tolist())]
    total_aas = sum(prob_and_label_df['sequence'].str.len())
    log_update(f"\tLength of dataframe (number of seqs in dataset): {len(prob_and_label_df)}")
    log_update(f"\tTotal AAs in dataset: {total_aas}\ttotal probabilities: {len(probs)}\ttotal labels: {len(true_labels)}")

    y_true = np.array(true_labels)
    y_probs = np.array(probs)

    # Drop the final (precision=1, recall=0) point, which has no matching threshold
    precision, recall, thresholds = precision_recall_curve(y_true, y_probs)
    precision = precision[:-1]
    recall = recall[:-1]

    # F1 is defined as 0 where precision + recall == 0. A plain division yields NaN
    # there, and np.argmax treats NaN as the maximum, so it would select that threshold.
    pr_sum = precision + recall
    f1_scores = np.divide(
        2 * precision * recall,
        pr_sum,
        out=np.zeros_like(pr_sum, dtype=float),
        where=pr_sum > 0,
    )

    best_threshold_index = np.argmax(f1_scores)
    best_threshold = thresholds[best_threshold_index]

    auprc = average_precision_score(y_true, y_probs)

    log_update(f"\tBest Threshold: {best_threshold}")
    log_update(f"\tBest F1 Score: {f1_scores[best_threshold_index]:.2f}")
    log_update(f"\tAUPRC: {auprc:.2f}")

    prob_and_label_df['threshold'] = [best_threshold]*len(prob_and_label_df)

    # precision_recall_curve counts score >= threshold as positive, so use >= here
    # too; with > the residues scored exactly at the threshold would be labelled 0
    # and the predictions would not reproduce the reported best F1.
    prob_and_label_df['pred_labels'] = prob_and_label_df['prob_1'].apply(
        lambda prob_str: ''.join('1' if float(p) >= best_threshold else '0' for p in prob_str.split(","))
    )
    log_update("\tUsed calculated threshold to construct predicted labels for dataset")
    return prob_and_label_df
| |
| |
def _append_to_combined_csv(df, combined_csv_path):
    """Appends df to a combined results CSV, writing the header only if the file does not exist yet."""
    if os.path.exists(combined_csv_path):
        df.to_csv(combined_csv_path, mode='a', index=False, header=False)
    else:
        df.to_csv(combined_csv_path, index=False)


def _save_individual_and_combined(df, individual_csv_path, combined_csv_path, model_full_path):
    """Writes df to the model's own CSV, then appends it (tagged with path_to_model) to the combined CSV."""
    df.to_csv(individual_csv_path, mode='w', index=False)
    # path_to_model is only stored in the combined file, so add it after the individual write
    df['path_to_model'] = model_full_path
    _append_to_combined_csv(df, combined_csv_path)


def _evaluate_and_record(model, dataloader, device, model_identity, formatted_hyperparams,
                         output_dir, model_full_folder, model_full_path,
                         split_csv_path, file_tag, threshold_log_name, performance_log_name):
    """Runs inference on one dataset, saves its metrics and thresholded per-sequence predictions, and logs the metrics."""
    sequences, preds, labels = evaluate(model, dataloader, device)
    metrics = calculate_metrics(preds, labels)

    # One-row table: metrics, embedding-model identity, then the hyperparameters
    results_df = pd.DataFrame.from_dict(metrics, orient='index').T
    results_df[['Model Type', 'Model Name', 'Model Epoch']] = [model_identity]
    hyperparams_df = pd.DataFrame.from_dict(formatted_hyperparams, orient='index').T
    results_df = pd.concat([results_df, hyperparams_df], axis=1)

    # Per-sequence probabilities as comma-separated strings with 3 decimals
    prob_and_label_df = pd.DataFrame(data={
        'sequence': sequences,
        'prob_1': [",".join(f"{round(x, 3):.3f}" for x in arr.flatten()) for arr in preds],
    })

    _save_individual_and_combined(
        results_df,
        f'{model_full_folder}/caid_hyperparam_screen_{file_tag}_metrics.csv',
        f'{output_dir}/caid_hyperparam_screen_{file_tag}_metrics.csv',
        model_full_path,
    )

    # Read labels as text so digit-only label strings are never parsed as numbers
    # (NOTE(review): assumes Label holds strings of 0/1 characters — confirm)
    split_df = pd.read_csv(split_csv_path, dtype={'Label': str})
    seq_label_dict = dict(zip(split_df['Sequence'], split_df['Label']))
    log_update(f"Finding best threshold for {threshold_log_name} predictions based on maximizing F1 Score...")
    prob_and_label_df = predict_from_best_thresh(prob_and_label_df, seq_label_dict=seq_label_dict)
    prob_and_label_df[['sequence', 'prob_1', 'threshold', 'pred_labels']].to_csv(
        f'{model_full_folder}/caid_hyperparam_screen_{file_tag}_probs.csv', mode='w', index=False
    )

    log_update(f"{performance_log_name} performance: {metrics}")


def train_and_evaluate_caid_predictor(embedding_path, details, output_dir, training_hyperparams, overwrite_saved_model=True):
    """
    Trains a disorder predictor on one set of embeddings with one hyperparameter setting,
    evaluates it on the test and fusion benchmark splits, and records all results.
    Args:
        embedding_path: path to the embeddings, passed to get_dataloader
        details (dict): must provide 'model_type', 'model' and 'epoch' for the embedding model
        output_dir (str): directory holding the combined caid_hyperparam_screen_*.csv files
        training_hyperparams (dict): must provide 'learning_rate', 'num_epochs', 'num_layers',
            'num_heads' and 'dropout'
        overwrite_saved_model (bool): if False and a model already exists for these settings,
            skip training and copy its previously saved results into the combined CSVs
    """
    benchmark_model_type = details['model_type']
    benchmark_model_name = details['model']
    benchmark_model_epoch = details['epoch']
    model_identity = [benchmark_model_type, benchmark_model_name, benchmark_model_epoch]

    # Folder layout: trained_models/<type>[/<name>/epoch<N>]/<hyperparameter string>
    model_outer_folder = f"trained_models/{benchmark_model_type}"
    if not np.isnan(benchmark_model_epoch):
        model_outer_folder += f"/{benchmark_model_name}/epoch{benchmark_model_epoch}"
    model_full_folder = (
        f"{model_outer_folder}/lr{training_hyperparams['learning_rate']}_bs{1}_hd{1280}"
        f"_epochs{training_hyperparams['num_epochs']}_layers{training_hyperparams['num_layers']}"
        f"_heads{training_hyperparams['num_heads']}_drpt{training_hyperparams['dropout']}"
    )
    # makedirs creates every missing parent, so no per-level loop is needed
    os.makedirs(model_full_folder, exist_ok=True)

    model_full_path = f"{model_full_folder}/model.pth"

    if os.path.exists(model_full_path):
        if not overwrite_saved_model:
            log_update(f"\nWARNING: this model may already be trained at {model_full_path}. Skipping")
            # Reuse the results saved next to the existing model: (individual file, combined file)
            previously_saved = [
                ('caid_train_losses.csv', 'caid_hyperparam_screen_train_losses.csv'),
                ('caid_hyperparam_screen_test_metrics.csv', 'caid_hyperparam_screen_test_metrics.csv'),
                ('caid_hyperparam_screen_fusion_benchmark_metrics.csv', 'caid_hyperparam_screen_fusion_benchmark_metrics.csv'),
            ]
            for individual_name, combined_name in previously_saved:
                saved_results = pd.read_csv(f'{model_full_folder}/{individual_name}')
                saved_results['path_to_model'] = model_full_path
                _append_to_combined_csv(saved_results, f'{output_dir}/{combined_name}')
            return
        log_update(f"\nOverwriting previously trained model with same hyperparams at {model_full_path}")

    # Data: one sequence per batch, as required by evaluate()
    max_length = 4500 + 2
    train_dataloader = get_dataloader('splits/train_df.csv', embedding_path, max_length=max_length, batch_size=1, shuffle=True)
    test_dataloader = get_dataloader('splits/test_df.csv', embedding_path, max_length=max_length, batch_size=1, shuffle=False)
    benchmark_dataloader = get_dataloader('splits/fusion_bench_df.csv', embedding_path, max_length=max_length, batch_size=1, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = get_fresh_model(training_hyperparams, device)
    optimizer = optim.Adam(model.parameters(), lr=training_hyperparams["learning_rate"])
    criterion = nn.BCELoss()
    num_epochs = training_hyperparams['num_epochs']

    # Train
    avg_train_losses = train(model, train_dataloader, optimizer, num_epochs, criterion, device)

    # One row per epoch: hyperparameters, epoch number, average loss, model identity
    formatted_hyperparams = {f"caid_model_{k}": v for k, v in training_hyperparams.items()}
    train_loss_df = pd.DataFrame.from_dict(formatted_hyperparams, orient='index').T
    train_loss_df['caid_model_epoch'] = [list(range(1, 1 + num_epochs))]
    train_loss_df['caid_model_loss'] = [avg_train_losses]
    train_loss_df[['Model Type', 'Model Name', 'Model Epoch']] = [model_identity]
    train_loss_df = train_loss_df.explode(['caid_model_epoch', 'caid_model_loss'])

    _save_individual_and_combined(
        train_loss_df,
        f'{model_full_folder}/caid_train_losses.csv',
        f'{output_dir}/caid_hyperparam_screen_train_losses.csv',
        model_full_path,
    )
    log_update(f"Final train loss: {avg_train_losses[-1]:.4f}")

    # Evaluate on the test split
    _evaluate_and_record(
        model, test_dataloader, device, model_identity, formatted_hyperparams,
        output_dir, model_full_folder, model_full_path,
        split_csv_path='splits/test_df.csv',
        file_tag='test',
        threshold_log_name='CAID test set',
        performance_log_name='Test',
    )

    # Evaluate on the fusion benchmark split
    _evaluate_and_record(
        model, benchmark_dataloader, device, model_identity, formatted_hyperparams,
        output_dir, model_full_folder, model_full_path,
        split_csv_path='splits/fusion_bench_df.csv',
        file_tag='fusion_benchmark',
        threshold_log_name='fusion benchmark set',
        performance_log_name='benchmark',
    )

    # Save the trained weights
    torch.save(model.state_dict(), model_full_path)
| |
| |
def calculate_metrics(preds, labels, threshold=0.5):
    """
    Computes classification metrics over all values pooled from every prediction array.
    Args:
        preds (list): numpy arrays of predicted probabilities
        labels (list): numpy arrays of ground truth labels, paired with preds by position
        threshold (float): a probability must be strictly greater than this to count as positive
    Returns:
        metrics_dict (dict): 'Accuracy', 'Precision', 'Recall', 'F1 Score' and 'AUROC'
    """
    # Pair predictions with labels up front; zip stops at the shorter of the two lists
    pairs = list(zip(preds, labels))

    # Pool every array into one flat vector each
    flat_binary_preds = np.array(
        [flag for pred, _ in pairs for flag in (pred > threshold).astype(int).flatten()]
    )
    flat_prob_preds = np.array([prob for pred, _ in pairs for prob in pred.flatten()])
    flat_labels = np.array([value for _, label in pairs for value in label.flatten()])

    # Thresholded predictions feed the class metrics; raw probabilities feed AUROC
    return {
        'Accuracy': accuracy_score(flat_labels, flat_binary_preds),
        'Precision': precision_score(flat_labels, flat_binary_preds),
        'Recall': recall_score(flat_labels, flat_binary_preds),
        'F1 Score': f1_score(flat_labels, flat_binary_preds),
        'AUROC': roc_auc_score(flat_labels, flat_prob_preds)
    }
|
|
def main():
    """
    Runs the full benchmark: creates a timestamped results folder, embeds the dataset,
    grid-searches a disorder predictor on every set of embeddings, and writes the best
    hyperparameter setting per embedding model.
    """
    os.makedirs('results',exist_ok=True)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir,exist_ok=True)

    with open_logfile(f'{output_dir}/caid_benchmark_log.txt'):
        # Record the configuration and the visible hardware
        print_configpy(config)
        check_env_variables()

        # Embed the sequences with every model selected in the config
        embedding_options = dict(
            fuson_ckpts=config.FUSONPLM_CKPTS,
            input_data_path='splits/splits.csv',
            input_fname='CAID2_competition_sequences',
            average=False,
            seq_col='Sequence',
            benchmark_fusonplm=config.BENCHMARK_FUSONPLM,
            benchmark_esm=config.BENCHMARK_ESM,
            benchmark_fo_puncta_ml=False,
            overwrite=config.PERMISSION_TO_OVERWRITE_EMBEDDINGS,
        )
        all_embedding_paths = embed_dataset_for_benchmark(**embedding_options)

        # Report how many sequences fall in each split
        split_column = pd.read_csv('splits/splits.csv')['Split']
        n_train = int((split_column == 'Train').sum())
        n_test = int((split_column == 'Test').sum())
        log_update(f"\nSplit breakdown...\n\t{n_train} train seqs\n\t{n_test} test seqs")

        log_update("\nTraining and evaluating models")

        # Hyperparameter grid screened for every set of embeddings
        param_grid = {
            'learning_rate': [5e-5],
            'num_heads': [5, 8, 10],
            'num_layers': [2, 4, 6],
            'dropout': [0.2, 0.5],
            'num_epochs': [2]
        }

        for embedding_path in all_embedding_paths:
            details = all_embedding_paths[embedding_path]
            log_update(f"\nBenchmarking embeddings at: {embedding_path}")
            grid_search_caid_predictor(
                embedding_path, details, output_dir, param_grid,
                overwrite_saved_model=config.PERMISSION_TO_OVERWRITE_MODELS
            )

        # Summarize the screen into best_caid_model_results.csv
        find_best_hyperparams(output_dir, param_grid)

        # Re-load the summary and build a view where the fusion benchmark stats take
        # the plain stat names. The view is not used further within this function.
        stat_names = ['AUROC','F1 Score','Accuracy','Precision','Recall']
        best_caid_model_results = pd.read_csv(f"{output_dir}/best_caid_model_results.csv")
        without_test_stats = best_caid_model_results.drop(columns=stat_names)
        best_caid_model_results_benchmark = without_test_stats.rename(
            columns={f'Fusion {stat}': stat for stat in stat_names}
        )


if __name__ == "__main__":
    main()