import os
import time
import traceback
from typing import List, Optional

import numpy as np
import pandas as pd
import torch
from scipy.stats import spearmanr

from .scorer import ProteinGymScorer
from .data_loader import load_proteingym_dms
from .dms_ids import ALL_SUBSTITUTION_DMS_IDS


try:
    from base_models.get_base_models import get_base_model
except ImportError:
    from ...base_models.get_base_models import get_base_model
|
|
def _compute_spearman(x: np.ndarray, y: np.ndarray, method: str, dms_id: str) -> float:
    """Spearman correlation between two score arrays, tolerant of NaNs.

    Args:
        x: Model-derived scores (may contain NaN).
        y: Experimental DMS scores (may contain NaN).
        method: Scoring method name (used only for log messages).
        dms_id: DMS assay identifier (used only for log messages).

    Returns:
        Spearman rho over the pairwise-valid entries, or NaN when fewer
        than two valid pairs exist.
    """
    if np.all(np.isnan(x)) or np.all(np.isnan(y)):
        print(f"No valid scores for {method} scoring for {dms_id}")
        return np.nan

    mask = ~(np.isnan(x) | np.isnan(y))
    if mask.sum() < 2:
        # spearmanr needs at least two paired observations to be meaningful.
        print(f"Not enough valid scores for {method} scoring for {dms_id}")
        return np.nan

    rho, _ = spearmanr(x[mask], y[mask])
    print(f"Spearman correlation for {method} on {dms_id}: {rho:.4f}")
    return rho


def _summarize_model(
    model_name: str,
    methods: List[str],
    spearman_df: pd.DataFrame,
    timing_df: pd.DataFrame,
) -> pd.DataFrame:
    """Aggregate per-assay results into one summary row per scoring method.

    Args:
        model_name: Name of the evaluated model.
        methods: Scoring methods to summarize.
        spearman_df: Frame with columns ['dms_id', 'method', 'spearman_rho'].
        timing_df: Frame with columns ['dms_id', 'method', 'duration_seconds'].

    Returns:
        One row per method with columns: model_name, scoring_method,
        Average_Spearman, Average_Time_Seconds, Total_Time_Seconds, n_assays.
    """
    rows = []
    for method in methods:
        # NaN correlations (failed/degenerate assays) are excluded from the
        # average and from the assay count.
        rhos = spearman_df.loc[spearman_df['method'] == method, 'spearman_rho'].dropna()
        durations = timing_df.loc[timing_df['method'] == method, 'duration_seconds']
        rows.append({
            'model_name': model_name,
            'scoring_method': method,
            'Average_Spearman': rhos.mean() if len(rhos) > 0 else np.nan,
            'Average_Time_Seconds': durations.mean() if len(durations) > 0 else np.nan,
            # Explicit NaN (not Series.sum()'s 0.0) when nothing was timed.
            'Total_Time_Seconds': durations.sum() if len(durations) > 0 else np.nan,
            'n_assays': len(rhos),
        })
    return pd.DataFrame(rows)


def compare_scoring_methods(
    model_names: List[str],
    device: Optional[str] = None,
    methods: Optional[List[str]] = None,
    dms_ids: Optional[List[str]] = None,
    progress: bool = True,
    output_csv: Optional[str] = None,
    batch_size: int = 32,
    scoring_window: str = "optimal"
) -> pd.DataFrame:
    """
    Compare scoring methods across one or more models and DMS assays.

    Each model is loaded once, scored against every requested DMS assay with
    every requested scoring method, and summarized (mean Spearman rho plus
    wall-clock timing per method). A model that fails to load or score is
    recorded as NaN rows instead of aborting the whole comparison.

    Args:
        model_names: List of model names to evaluate
        device: Device string like 'cuda' or 'cpu'; defaults to CUDA when available
        methods: List of scoring methods to compare (defaults to the four marginals)
        dms_ids: List of DMS IDs to evaluate (defaults to all substitution assays)
        progress: Whether to show progress bars (currently unused; kept for
            backward compatibility of the signature)
        output_csv: Optional path to save results CSV
        batch_size: Batch size for inference (default: 32)
        scoring_window: Windowing strategy ('optimal' or 'sliding')

    Returns:
        DataFrame with model_name, scoring_method, Average_Spearman,
        Average_Time_Seconds, Total_Time_Seconds, and n_assays columns
    """
    if methods is None:
        methods = ["masked_marginal", "mutant_marginal", "wildtype_marginal", "global_log_prob"]

    if dms_ids is None:
        dms_ids = ALL_SUBSTITUTION_DMS_IDS

    all_summary_results = []
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

    for model_name in model_names:
        print(f"\n{'='*80}")
        print(f"PROCESSING MODEL: {model_name}")
        print(f"{'='*80}")

        model = None
        scorer = None
        try:
            spearman_results = []
            timing_results = []

            model, tokenizer = get_base_model(model_name, masked_lm=True)
            model = model.to(device).eval()
            scorer = ProteinGymScorer(
                model_name=model_name,
                model=model,
                tokenizer=tokenizer,
                device=device,
                batch_size=batch_size,
            )

            for dms_id in dms_ids:
                print(f"\nProcessing DMS ID: {dms_id}")
                df = load_proteingym_dms(
                    dms_id=dms_id,
                    mode="benchmark",
                    repo_id="GleghornLab/ProteinGym_DMS",
                )

                for method in methods:
                    print(f"Running {method} scoring for {dms_id}...")

                    start_time = time.time()
                    scored_df = scorer.score_substitutions(
                        df=df,
                        scoring_method=method,
                        scoring_window=scoring_window,
                    )
                    method_duration = time.time() - start_time

                    print(f"  {method} scoring completed in {method_duration:.2f} seconds")

                    spearman_rho = _compute_spearman(
                        scored_df["delta_log_prob"].to_numpy(),
                        scored_df["DMS_score"].to_numpy(),
                        method,
                        dms_id,
                    )

                    spearman_results.append({
                        'dms_id': dms_id,
                        'method': method,
                        'spearman_rho': spearman_rho,
                    })
                    timing_results.append({
                        'dms_id': dms_id,
                        'method': method,
                        'duration_seconds': method_duration,
                    })

            # Explicit columns so an empty dms_ids list still yields frames
            # that _summarize_model can index without a KeyError.
            spearman_df = pd.DataFrame(
                spearman_results, columns=['dms_id', 'method', 'spearman_rho']
            )
            timing_df = pd.DataFrame(
                timing_results, columns=['dms_id', 'method', 'duration_seconds']
            )
            model_summary_df = _summarize_model(model_name, methods, spearman_df, timing_df)
            all_summary_results.append(model_summary_df)

            print(f"\n{'='*60}")
            print(f"SUMMARY FOR MODEL: {model_name}")
            print(f"{'='*60}")
            print(model_summary_df.to_string(index=False))

        except Exception as e:
            # Keep going with the remaining models, but surface the full
            # traceback so failures are diagnosable, and record NaN rows.
            print(f"Error processing model {model_name}: {e}")
            traceback.print_exc()
            all_summary_results.append(pd.DataFrame([{
                'model_name': model_name,
                'scoring_method': method,
                'Average_Spearman': np.nan,
                'Average_Time_Seconds': np.nan,
                'Total_Time_Seconds': np.nan,
                'n_assays': 0,
            } for method in methods]))
        finally:
            # Drop references to the model so its (possibly GPU-resident)
            # weights can be reclaimed before the next model is loaded.
            scorer = None
            model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    if all_summary_results:
        combined_summary = pd.concat(all_summary_results, ignore_index=True)
    else:
        combined_summary = pd.DataFrame()

    if output_csv:
        # os.makedirs("") raises FileNotFoundError, so only create a directory
        # when the path actually contains one (e.g. plain "results.csv" does not).
        output_dir = os.path.dirname(output_csv)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        combined_summary.to_csv(output_csv, index=False)
        print(f"\nResults saved to {output_csv}")

    print(f"\n{'='*80}")
    print("FINAL SUMMARY - AVERAGE SPEARMAN CORRELATIONS")
    print(f"{'='*80}")
    print(combined_summary.to_string(index=False))

    return combined_summary