| import json |
| from pathlib import Path |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import pandas as pd |
| from tqdm import tqdm |
|
|
# Polymer-level metrics read directly from each per-model evaluation JSON.
METRICS = ["lddt", "bb_lddt", "tm_score", "rmsd"]
|
|
|
|
def compute_af3_metrics(preds, evals, name, metric_names=None):
    """Aggregate AF3 evaluation metrics over the five predicted samples.

    Args:
        preds: Prediction directory containing one ``seed-1_sample-{i}``
            sub-folder per sample, each with ``summary_confidences.json``.
        evals: Directory with per-model evaluation files named
            ``{name}_model_{i}.json`` and ``{name}_model_{i}_ligand.json``.
        name: Target name used to locate the evaluation files.
        metric_names: Polymer metric keys to read from the eval JSON.
            Defaults to the module-level ``METRICS`` list.

    Returns:
        Dict mapping metric name to a dict with ``oracle`` (best over
        samples), ``average``, ``top1`` (confidence-ranked sample), and
        ``len`` (number of scored entities; 1 for global metrics).
    """
    if metric_names is None:
        metric_names = METRICS

    metrics = {}

    # Track the sample with the highest AF3 ranking score (top-1 model).
    top_model = None
    top_confidence = float("-inf")  # robust to arbitrarily low scores
    for model_id in range(5):
        # Model confidence.
        confidence_file = (
            Path(preds) / f"seed-1_sample-{model_id}" / "summary_confidences.json"
        )
        with confidence_file.open("r") as f:
            confidence_data = json.load(f)
        confidence = confidence_data["ranking_score"]
        if confidence > top_confidence:
            top_model = model_id
            top_confidence = confidence

        # Polymer metrics (lddt, bb_lddt, tm_score, rmsd, ...).
        eval_file = Path(evals) / f"{name}_model_{model_id}.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        for metric_name in metric_names:
            if metric_name in eval_data:
                metrics.setdefault(metric_name, []).append(eval_data[metric_name])

        # Interface metrics: fraction of interfaces above DockQ cutoffs.
        if "dockq" in eval_data and eval_data["dockq"] is not None:
            dockq_scores = [v for v in eval_data["dockq"] if v is not None]
            metrics.setdefault("dockq_>0.23", []).append(
                np.mean([float(v > 0.23) for v in dockq_scores])
            )
            metrics.setdefault("dockq_>0.49", []).append(
                np.mean([float(v > 0.49) for v in dockq_scores])
            )
            metrics.setdefault("len_dockq_", []).append(len(dockq_scores))

        # Ligand metrics from the separate *_ligand.json evaluation output.
        eval_file = Path(evals) / f"{name}_model_{model_id}_ligand.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        if "lddt_pli" in eval_data:
            lddt_plis = [
                x["score"] for x in eval_data["lddt_pli"]["assigned_scores"]
            ]
            # Unassigned ligands are penalized with a score of 0.
            lddt_plis.extend(
                0 for _ in eval_data["lddt_pli"]["model_ligand_unassigned_reason"]
            )
            if not lddt_plis:
                # NOTE: also skips the rmsd block below for this model,
                # matching the original control flow.
                continue
            metrics.setdefault("lddt_pli", []).append(np.mean(lddt_plis))
            metrics.setdefault("len_lddt_pli", []).append(len(lddt_plis))

        if "rmsd" in eval_data:
            rmsds = [x["score"] for x in eval_data["rmsd"]["assigned_scores"]]
            # Unassigned ligands are penalized with a large RMSD.
            rmsds.extend(
                100 for _ in eval_data["rmsd"]["model_ligand_unassigned_reason"]
            )
            if not rmsds:
                continue
            metrics.setdefault("rmsd<2", []).append(
                np.mean([x < 2.0 for x in rmsds])
            )
            metrics.setdefault("rmsd<5", []).append(
                np.mean([x < 5.0 for x in rmsds])
            )
            metrics.setdefault("len_rmsd", []).append(len(rmsds))

    # Aggregate over samples: oracle (best), average, and top-1 by confidence.
    # Only raw "rmsd" is lower-is-better; everything else is higher-is-better.
    oracle = {k: min(v) if k == "rmsd" else max(v) for k, v in metrics.items()}
    avg = {k: sum(v) / len(v) for k, v in metrics.items()}
    top1 = {k: v[top_model] for k, v in metrics.items()}

    results = {}
    for metric_name in metrics:
        if metric_name.startswith("len_"):
            continue
        # Number of scored entities this metric averages over (1 = global).
        if metric_name == "lddt_pli":
            num_scored = metrics["len_lddt_pli"][0]
        elif metric_name in ("rmsd<2", "rmsd<5"):
            num_scored = metrics["len_rmsd"][0]
        elif metric_name in ("dockq_>0.23", "dockq_>0.49"):
            num_scored = metrics["len_dockq_"][0]
        else:
            num_scored = 1
        results[metric_name] = {
            "oracle": oracle[metric_name],
            "average": avg[metric_name],
            "top1": top1[metric_name],
            "len": num_scored,
        }

    return results
|
|
|
|
def compute_chai_metrics(preds, evals, name, metric_names=None):
    """Aggregate Chai-1 evaluation metrics over the five predicted samples.

    Args:
        preds: Prediction directory containing one
            ``scores.model_idx_{i}.npz`` file per sample.
        evals: Directory with per-model evaluation files named
            ``{name}_model_{i}.json`` and ``{name}_model_{i}_ligand.json``.
        name: Target name used to locate the evaluation files.
        metric_names: Polymer metric keys to read from the eval JSON.
            Defaults to the module-level ``METRICS`` list.

    Returns:
        Dict mapping metric name to a dict with ``oracle`` (best over
        samples), ``average``, ``top1`` (confidence-ranked sample), and
        ``len`` (number of scored entities; 1 for global metrics).
    """
    if metric_names is None:
        metric_names = METRICS

    metrics = {}

    # Track the sample with the highest Chai aggregate score (top-1 model).
    top_model = None
    # Fix: start at -inf (was 0) so top_model is always set, even if every
    # aggregate score is <= 0; consistent with the AF3 variant.
    top_confidence = float("-inf")
    for model_id in range(5):
        # Model confidence.
        confidence_file = Path(preds) / f"scores.model_idx_{model_id}.npz"
        confidence_data = np.load(confidence_file)
        confidence = confidence_data["aggregate_score"].item()
        if confidence > top_confidence:
            top_model = model_id
            top_confidence = confidence

        # Polymer metrics (lddt, bb_lddt, tm_score, rmsd, ...).
        eval_file = Path(evals) / f"{name}_model_{model_id}.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        for metric_name in metric_names:
            if metric_name in eval_data:
                metrics.setdefault(metric_name, []).append(eval_data[metric_name])

        # Interface metrics: fraction of interfaces above DockQ cutoffs.
        if "dockq" in eval_data and eval_data["dockq"] is not None:
            dockq_scores = [v for v in eval_data["dockq"] if v is not None]
            metrics.setdefault("dockq_>0.23", []).append(
                np.mean([float(v > 0.23) for v in dockq_scores])
            )
            metrics.setdefault("dockq_>0.49", []).append(
                np.mean([float(v > 0.49) for v in dockq_scores])
            )
            metrics.setdefault("len_dockq_", []).append(len(dockq_scores))

        # Ligand metrics from the separate *_ligand.json evaluation output.
        eval_file = Path(evals) / f"{name}_model_{model_id}_ligand.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        if "lddt_pli" in eval_data:
            lddt_plis = [
                x["score"] for x in eval_data["lddt_pli"]["assigned_scores"]
            ]
            # Unassigned ligands are penalized with a score of 0.
            lddt_plis.extend(
                0 for _ in eval_data["lddt_pli"]["model_ligand_unassigned_reason"]
            )
            if not lddt_plis:
                # NOTE: also skips the rmsd block below for this model,
                # matching the original control flow.
                continue
            metrics.setdefault("lddt_pli", []).append(np.mean(lddt_plis))
            metrics.setdefault("len_lddt_pli", []).append(len(lddt_plis))

        if "rmsd" in eval_data:
            rmsds = [x["score"] for x in eval_data["rmsd"]["assigned_scores"]]
            # Unassigned ligands are penalized with a large RMSD.
            rmsds.extend(
                100 for _ in eval_data["rmsd"]["model_ligand_unassigned_reason"]
            )
            if not rmsds:
                continue
            metrics.setdefault("rmsd<2", []).append(
                np.mean([x < 2.0 for x in rmsds])
            )
            metrics.setdefault("rmsd<5", []).append(
                np.mean([x < 5.0 for x in rmsds])
            )
            metrics.setdefault("len_rmsd", []).append(len(rmsds))

    # Aggregate over samples: oracle (best), average, and top-1 by confidence.
    # Only raw "rmsd" is lower-is-better; everything else is higher-is-better.
    oracle = {k: min(v) if k == "rmsd" else max(v) for k, v in metrics.items()}
    avg = {k: sum(v) / len(v) for k, v in metrics.items()}
    top1 = {k: v[top_model] for k, v in metrics.items()}

    results = {}
    for metric_name in metrics:
        if metric_name.startswith("len_"):
            continue
        # Number of scored entities this metric averages over (1 = global).
        if metric_name == "lddt_pli":
            num_scored = metrics["len_lddt_pli"][0]
        elif metric_name in ("rmsd<2", "rmsd<5"):
            num_scored = metrics["len_rmsd"][0]
        elif metric_name in ("dockq_>0.23", "dockq_>0.49"):
            num_scored = metrics["len_dockq_"][0]
        else:
            num_scored = 1
        results[metric_name] = {
            "oracle": oracle[metric_name],
            "average": avg[metric_name],
            "top1": top1[metric_name],
            "len": num_scored,
        }

    return results
|
|
|
|
def compute_boltz_metrics(preds, evals, name, metric_names=None):
    """Aggregate Boltz evaluation metrics over the five predicted samples.

    Args:
        preds: Prediction directory (named after the target) containing one
            ``confidence_{target}_model_{i}.json`` file per sample.
        evals: Directory with per-model evaluation files named
            ``{name}_model_{i}.json`` and ``{name}_model_{i}_ligand.json``.
        name: Target name used to locate the evaluation files.
        metric_names: Polymer metric keys to read from the eval JSON.
            Defaults to the module-level ``METRICS`` list.

    Returns:
        Dict mapping metric name to a dict with ``oracle`` (best over
        samples), ``average``, ``top1`` (confidence-ranked sample), and
        ``len`` (number of scored entities; 1 for global metrics).
    """
    if metric_names is None:
        metric_names = METRICS

    metrics = {}

    # Track the sample with the highest Boltz confidence score (top-1 model).
    top_model = None
    # Fix: start at -inf (was 0) so top_model is always set, even if every
    # confidence score is <= 0; consistent with the AF3 variant.
    top_confidence = float("-inf")
    for model_id in range(5):
        # Model confidence; the file name embeds the prediction folder name.
        confidence_file = (
            Path(preds) / f"confidence_{Path(preds).name}_model_{model_id}.json"
        )
        with confidence_file.open("r") as f:
            confidence_data = json.load(f)
        confidence = confidence_data["confidence_score"]
        if confidence > top_confidence:
            top_model = model_id
            top_confidence = confidence

        # Polymer metrics (lddt, bb_lddt, tm_score, rmsd, ...).
        eval_file = Path(evals) / f"{name}_model_{model_id}.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        for metric_name in metric_names:
            if metric_name in eval_data:
                metrics.setdefault(metric_name, []).append(eval_data[metric_name])

        # Interface metrics: fraction of interfaces above DockQ cutoffs.
        if "dockq" in eval_data and eval_data["dockq"] is not None:
            dockq_scores = [v for v in eval_data["dockq"] if v is not None]
            metrics.setdefault("dockq_>0.23", []).append(
                np.mean([float(v > 0.23) for v in dockq_scores])
            )
            metrics.setdefault("dockq_>0.49", []).append(
                np.mean([float(v > 0.49) for v in dockq_scores])
            )
            metrics.setdefault("len_dockq_", []).append(len(dockq_scores))

        # Ligand metrics from the separate *_ligand.json evaluation output.
        eval_file = Path(evals) / f"{name}_model_{model_id}_ligand.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        if "lddt_pli" in eval_data:
            lddt_plis = [
                x["score"] for x in eval_data["lddt_pli"]["assigned_scores"]
            ]
            # Unassigned ligands are penalized with a score of 0.
            lddt_plis.extend(
                0 for _ in eval_data["lddt_pli"]["model_ligand_unassigned_reason"]
            )
            if not lddt_plis:
                # NOTE: also skips the rmsd block below for this model,
                # matching the original control flow.
                continue
            metrics.setdefault("lddt_pli", []).append(np.mean(lddt_plis))
            metrics.setdefault("len_lddt_pli", []).append(len(lddt_plis))

        if "rmsd" in eval_data:
            rmsds = [x["score"] for x in eval_data["rmsd"]["assigned_scores"]]
            # Unassigned ligands are penalized with a large RMSD.
            rmsds.extend(
                100 for _ in eval_data["rmsd"]["model_ligand_unassigned_reason"]
            )
            if not rmsds:
                continue
            metrics.setdefault("rmsd<2", []).append(
                np.mean([x < 2.0 for x in rmsds])
            )
            metrics.setdefault("rmsd<5", []).append(
                np.mean([x < 5.0 for x in rmsds])
            )
            metrics.setdefault("len_rmsd", []).append(len(rmsds))

    # Aggregate over samples: oracle (best), average, and top-1 by confidence.
    # Only raw "rmsd" is lower-is-better; everything else is higher-is-better.
    oracle = {k: min(v) if k == "rmsd" else max(v) for k, v in metrics.items()}
    avg = {k: sum(v) / len(v) for k, v in metrics.items()}
    top1 = {k: v[top_model] for k, v in metrics.items()}

    results = {}
    for metric_name in metrics:
        if metric_name.startswith("len_"):
            continue
        # Number of scored entities this metric averages over (1 = global).
        if metric_name == "lddt_pli":
            num_scored = metrics["len_lddt_pli"][0]
        elif metric_name in ("rmsd<2", "rmsd<5"):
            num_scored = metrics["len_rmsd"][0]
        elif metric_name in ("dockq_>0.23", "dockq_>0.49"):
            num_scored = metrics["len_dockq_"][0]
        else:
            num_scored = 1
        results[metric_name] = {
            "oracle": oracle[metric_name],
            "average": avg[metric_name],
            "top1": top1[metric_name],
            "len": num_scored,
        }

    return results
|
|
|
|
def eval_models(
    chai_preds,
    chai_evals,
    af3_preds,
    af3_evals,
    boltz_preds,
    boltz_evals,
    boltz_preds_x,
    boltz_evals_x,
):
    """Evaluate all four tools on their common targets.

    For every target predicted by all tools, compute per-metric oracle and
    top-1 values for each tool and collect them into long format. Metrics
    are only kept when all four tools scored the same number of entities,
    so the per-target averages are comparable.

    Args:
        *_preds: Root prediction directories (one sub-folder per target).
        *_evals: Directories of per-model evaluation JSON files.

    Returns:
        DataFrame with columns ``tool``, ``target``, ``metric``, ``value``.
    """
    import traceback  # hoisted from the four inline imports below

    def _list_targets(folder):
        # Map lowercase target name -> prediction folder, skipping hidden entries.
        return {
            x.name.lower(): x
            for x in Path(folder).iterdir()
            if not x.name.lower().startswith(".")
        }

    chai_preds_names = _list_targets(chai_preds)
    af3_preds_names = _list_targets(af3_preds)
    boltz_preds_names = _list_targets(boltz_preds)
    boltz_preds_names_x = _list_targets(boltz_preds_x)

    print("Chai preds", len(chai_preds_names))
    print("Af3 preds", len(af3_preds_names))
    print("Boltz preds", len(boltz_preds_names))
    print("Boltzx preds", len(boltz_preds_names_x))

    common = (
        set(chai_preds_names.keys())
        & set(af3_preds_names.keys())
        & set(boltz_preds_names.keys())
        & set(boltz_preds_names_x.keys())
    )

    # Targets excluded from the comparison (known problematic cases).
    keys_to_remove = ["t1133", "h1134", "r1134s1", "t1134s2", "t1121", "t1123", "t1159"]
    for key in keys_to_remove:
        common.discard(key)
    print("Common", len(common))

    results = []
    for name in tqdm(common):
        # Run each tool's metric computation; skip the target if any fails.
        tool_runs = [
            ("AF3", compute_af3_metrics, af3_preds_names, af3_evals),
            ("Chai", compute_chai_metrics, chai_preds_names, chai_evals),
            ("Boltz", compute_boltz_metrics, boltz_preds_names, boltz_evals),
            ("Boltzx", compute_boltz_metrics, boltz_preds_names_x, boltz_evals_x),
        ]
        per_tool = []
        for label, compute_fn, preds_names, evals_dir in tool_runs:
            try:
                per_tool.append(compute_fn(preds_names[name], evals_dir, name))
            except Exception as e:
                traceback.print_exc()
                print(f"Error evaluating {label} {name}: {e}")
                break
        if len(per_tool) != len(tool_runs):
            continue
        af3_results, chai_results, boltz_results, boltz_results_x = per_tool

        tool_results = [
            ("AF3", af3_results),
            ("Chai-1", chai_results),
            ("Boltz-1", boltz_results),
            ("Boltz-1x", boltz_results_x),
        ]
        for metric_name in af3_results:
            # Fix: also require the metric in boltz_results_x (previously
            # unchecked, causing a KeyError in the length comparison below).
            if not all(metric_name in res for _, res in tool_results):
                print(
                    "Missing metric",
                    name,
                    metric_name,
                    metric_name in chai_results,
                    metric_name in boltz_results,
                    metric_name in boltz_results_x,
                )
                continue

            # Only keep metrics where every tool scored the same entities.
            lengths = {res[metric_name]["len"] for _, res in tool_results}
            if len(lengths) != 1:
                print(
                    "Different lengths",
                    name,
                    metric_name,
                    af3_results[metric_name]["len"],
                    chai_results[metric_name]["len"],
                    boltz_results[metric_name]["len"],
                    boltz_results_x[metric_name]["len"],
                )
                continue

            for tool, res in tool_results:
                entry = res[metric_name]
                results.append(
                    {
                        "tool": f"{tool} oracle",
                        "target": name,
                        "metric": metric_name,
                        "value": entry["oracle"],
                    }
                )
                results.append(
                    {
                        "tool": f"{tool} top-1",
                        "target": name,
                        "metric": metric_name,
                        "value": entry["top1"],
                    }
                )

    df = pd.DataFrame(results)
    return df
|
|
|
|
def eval_validity_checks(df):
    """Convert raw physical-validity checks into long results format.

    Emits one "physical validity" row per (tool, target) for both the
    top-1 prediction (model_idx == 0) and the oracle (valid if any of the
    target's models is valid), with tool names mapped to the display
    labels used elsewhere in the results.
    """
    top1_names = {
        "af3": "AF3 top-1",
        "chai": "Chai-1 top-1",
        "boltz1": "Boltz-1 top-1",
        "boltz1x": "Boltz-1x top-1",
    }
    # Top-1: validity of the first (confidence-ranked) model only.
    top1 = df.loc[df["model_idx"] == 0, ["tool", "pdb_id", "valid"]].copy()
    top1["tool"] = top1["tool"].apply(lambda t: top1_names[t])
    top1 = top1.rename(columns={"pdb_id": "target", "valid": "value"})
    top1["metric"] = "physical validity"
    top1["target"] = top1["target"].str.lower()
    top1 = top1[["tool", "target", "metric", "value"]]

    oracle_names = {
        "af3": "AF3 oracle",
        "chai": "Chai-1 oracle",
        "boltz1": "Boltz-1 oracle",
        "boltz1x": "Boltz-1x oracle",
    }
    # Oracle: a target is valid if any of its models passes the checks.
    oracle = (
        df[["tool", "model_idx", "pdb_id", "valid"]]
        .groupby(["tool", "pdb_id"])["valid"]
        .max()
        .reset_index()
        .rename(columns={"pdb_id": "target", "valid": "value"})
    )
    oracle["tool"] = oracle["tool"].apply(lambda t: oracle_names[t])
    oracle["metric"] = "physical validity"
    oracle = oracle[["tool", "target", "metric", "value"]]
    oracle["target"] = oracle["target"].str.lower()

    return pd.concat([top1, oracle])
|
|
|
|
def bootstrap_ci(series, n_boot=1000, alpha=0.05):
    """Bootstrap a confidence interval for the mean of *series*.

    Draws ``n_boot`` resamples (with replacement, same size as the input)
    and returns ``(mean, lower, upper)`` where ``[lower, upper]`` is the
    ``1 - alpha`` percentile interval of the resampled means.
    """
    size = len(series)
    boot_means = np.array(
        [series.sample(size, replace=True).mean() for _ in range(n_boot)]
    )
    return (
        np.mean(series),
        np.percentile(boot_means, 100 * alpha / 2),
        np.percentile(boot_means, 100 * (1 - alpha / 2)),
    )
|
|
|
|
def plot_data(desired_tools, desired_metrics, df, dataset, filename):
    """Plot grouped bars of mean metric values with bootstrap 95% CIs.

    Args:
        desired_tools: Tool labels (values of ``df["tool"]``) to plot, in
            bar order within each metric group.
        desired_metrics: Metric names (values of ``df["metric"]``) to
            plot, in x-axis order.
        df: Long-format results with columns tool/target/metric/value.
        dataset: Dataset label used in the figure title.
        filename: Path where the figure is saved.
    """
    filtered_df = df[
        df["tool"].isin(desired_tools) & df["metric"].isin(desired_metrics)
    ]

    # Bootstrap mean and CI bounds per (tool, metric) group.
    boot_stats = filtered_df.groupby(["tool", "metric"])["value"].apply(bootstrap_ci)
    boot_stats = boot_stats.apply(pd.Series)
    boot_stats.columns = ["mean", "lower", "upper"]

    # Pivot to metric-rows x tool-columns, in the requested order.
    # Fix: order bars by the caller-provided tool list; the previous
    # hard-coded list raised a KeyError for any other set of tools and the
    # local variable shadowed this function's own name.
    tool_order = list(desired_tools)
    mean_data = boot_stats["mean"].unstack("tool").reindex(desired_metrics)[tool_order]
    lower_data = boot_stats["lower"].unstack("tool").reindex(desired_metrics)[tool_order]
    upper_data = boot_stats["upper"].unstack("tool").reindex(desired_metrics)[tool_order]

    # Pretty metric names for the x-axis labels.
    renaming = {
        "lddt_pli": "Mean LDDT-PLI",
        "rmsd<2": "L-RMSD < 2A",
        "lddt": "Mean LDDT",
        "dockq_>0.23": "DockQ > 0.23",
        "physical validity": "Physical Validity",
    }
    mean_data = mean_data.rename(index=renaming)
    lower_data = lower_data.rename(index=renaming)
    upper_data = upper_data.rename(index=renaming)
    mean_vals = mean_data.values

    # One fixed color per bar position (dark/light pairs per tool family).
    tool_colors = [
        "#994C00",
        "#FFB55A",
        "#931652",
        "#FC8AD9",
        "#188F52",
        "#86E935",
        "#004D80",
        "#55C2FF",
    ]

    fig, ax = plt.subplots(figsize=(10, 5))

    x = np.arange(len(mean_data.index))
    bar_spacing = 0.015
    total_width = 0.7
    # Width of one bar so the whole group spans total_width.
    width = (total_width - (len(tool_order) - 1) * bar_spacing) / len(tool_order)

    for i, tool in enumerate(tool_order):
        # Horizontal offset of this tool's bars within each metric group.
        offsets = x - (total_width - width) / 2 + i * (width + bar_spacing)
        tool_means = mean_data[tool].values
        # Asymmetric error bars from the bootstrap CI bounds.
        tool_yerr = np.vstack(
            [
                mean_vals[:, i] - lower_data.values[:, i],
                upper_data.values[:, i] - mean_vals[:, i],
            ]
        )
        ax.bar(
            offsets,
            tool_means,
            width=width,
            color=tool_colors[i % len(tool_colors)],
            label=tool,
            yerr=tool_yerr,
            capsize=2,
            error_kw={"elinewidth": 0.75},
        )

    ax.set_xticks(x)
    ax.set_xticklabels(mean_data.index, rotation=0)
    ax.set_ylabel("Value")
    ax.set_title(f"Performances on {dataset} with 95% CI (Bootstrap)")

    plt.tight_layout()
    ax.legend(loc="lower center", bbox_to_anchor=(0.5, 0.85), ncols=4, frameon=False)

    plt.savefig(filename)
    plt.show()
|
|
|
|
def main():
    """Run the full evaluation and plotting pipeline on both benchmarks."""
    eval_folder = "../../boltz_results_final/"
    output_folder = "../../boltz_results_final/"

    desired_tools = [
        "AF3 oracle",
        "AF3 top-1",
        "Chai-1 oracle",
        "Chai-1 top-1",
        "Boltz-1 oracle",
        "Boltz-1 top-1",
        "Boltz-1x oracle",
        "Boltz-1x top-1",
    ]
    desired_metrics = ["lddt", "dockq_>0.23", "lddt_pli", "rmsd<2", "physical validity"]

    # (split folder, plot title, checks CSV, results CSV, plot file)
    datasets = [
        (
            "test",
            "PDB Test",
            "physical_checks_test.csv",
            "results_test.csv",
            "plot_test.pdf",
        ),
        (
            "casp15",
            "CASP15",
            "physical_checks_casp.csv",
            "results_casp.csv",
            "plot_casp.pdf",
        ),
    ]

    for split, title, checks_csv, results_csv, plot_file in datasets:
        # Physical-validity checks are precomputed per dataset.
        df_validity_checks = pd.read_csv(eval_folder + checks_csv)
        df_validity_checks = eval_validity_checks(df_validity_checks)

        # Structural metrics from each tool's predictions and evaluations.
        df = eval_models(
            eval_folder + f"outputs/{split}/chai",
            eval_folder + f"evals/{split}/chai",
            eval_folder + f"outputs/{split}/af3",
            eval_folder + f"evals/{split}/af3",
            eval_folder + f"outputs/{split}/boltz/predictions",
            eval_folder + f"evals/{split}/boltz",
            eval_folder + f"outputs/{split}/boltzx/predictions",
            eval_folder + f"evals/{split}/boltzx",
        )

        # Merge metric rows with validity rows, save, and plot.
        df = pd.concat([df, df_validity_checks]).reset_index(drop=True)
        df.to_csv(output_folder + results_csv, index=False)

        plot_data(
            desired_tools, desired_metrics, df, title, output_folder + plot_file
        )
|
|
|
|
# Allow running this module directly as a standalone script.
if __name__ == "__main__":
    main()
|
|