| """Self-consistency evaluation using ProteinMPNN and ESMFold."""
|
|
|
from __future__ import annotations

import csv
import json
import logging
import os
import shutil
import subprocess
import sys
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np

from model_loader import REPO_ROOT
|
|
|
# Module-level logger shared by the service below.
LOGGER: logging.Logger = logging.getLogger(__name__)

# When True, the first record of the ProteinMPNN FASTA output is treated as the
# reference sequence: it is not folded and not counted as a target sequence.
# NOTE(review): presumably that record is the sample's native sequence as
# written by ProteinMPNN -- confirm against the ProteinMPNN version in use.
SKIP_REFERENCE_SEQUENCE: bool = True
|
|
|
|
|
class SelfConsistencyError(RuntimeError):
    """Error raised by the self-consistency pipeline.

    Used for invalid arguments, missing tools or dependencies, failed
    subprocesses, and missing or empty intermediate outputs.
    """
|
|
|
|
|
@dataclass
class SelfConsistencyResult:
    """Artifacts and metrics produced by one self-consistency run.

    Every filesystem location is stored as a plain string.
    """

    # Path of the evaluated sample structure.
    sample_path: str
    # Directory that holds every artifact of this run.
    run_dir: str
    # Output of ProteinMPNN's parse script (one JSON object per line).
    parsed_pdbs_jsonl: str
    # FASTA file with the ProteinMPNN-designed sequences.
    mpnn_fasta_path: str
    # Whether ESMFold folding and scoring were requested.
    run_folding: bool
    # Per-sequence metrics CSV; None when folding was not run.
    metrics_csv_path: Optional[str]
    # ESMFold predictions, one per scored sequence.
    folded_pdb_paths: List[str]
    # One dict per scored sequence (index, header, sequence, scTM, scRMSD,
    # esmfold_mean_plddt, folded_sample_path).
    per_sequence_metrics: List[Dict[str, object]]
    # Aggregate counts and statistics for the run.
    summary: Dict[str, object]
|
|
|
|
|
| def _as_path(path_value: str) -> Path:
|
| raw = Path(path_value).expanduser()
|
| return raw if raw.is_absolute() else (REPO_ROOT / raw).resolve()
|
|
|
|
|
| def _safe_mean(values: List[float]) -> Optional[float]:
|
| return float(np.mean(values)) if values else None
|
|
|
|
|
def _mean_bfactor_from_pdb(pdb_path: Path) -> Optional[float]:
    """Average the B-factor field over all ATOM/HETATM records of a PDB file.

    The value is read from the fixed-width columns 61-66 (1-based) of each
    coordinate record; blank or non-numeric fields are ignored. Returns
    ``None`` when no usable value is found.
    """
    text = pdb_path.read_text(encoding="utf-8", errors="replace")
    bfactors: List[float] = []
    for record in text.splitlines():
        if not record.startswith(("ATOM", "HETATM")):
            continue
        raw_value = record[60:66].strip()
        if not raw_value:
            continue
        try:
            bfactors.append(float(raw_value))
        except ValueError:
            pass
    return _safe_mean(bfactors)
|
|
|
|
|
| def _read_fasta(path: Path) -> List[Tuple[str, str]]:
|
| if not path.exists():
|
| raise SelfConsistencyError(f"ProteinMPNN FASTA output not found: {path}")
|
|
|
| entries: List[Tuple[str, str]] = []
|
| header: Optional[str] = None
|
| seq_lines: List[str] = []
|
| for raw_line in path.read_text(encoding="utf-8", errors="replace").splitlines():
|
| line = raw_line.strip()
|
| if not line:
|
| continue
|
| if line.startswith(">"):
|
| if header is not None:
|
| entries.append((header, "".join(seq_lines)))
|
| header = line[1:].strip()
|
| seq_lines = []
|
| continue
|
| seq_lines.append(line)
|
|
|
| if header is not None:
|
| entries.append((header, "".join(seq_lines)))
|
|
|
| if not entries:
|
| raise SelfConsistencyError(f"ProteinMPNN FASTA output had no sequences: {path}")
|
| return entries
|
|
|
|
|
def _calc_tm_score(pos_1, pos_2, seq_1: str, seq_2: str) -> Tuple[float, float]:
    """Align two structures with ``tmtools.tm_align`` and return both TM-scores.

    Returns:
        ``(tm_norm_chain1, tm_norm_chain2)`` as Python floats, i.e. the scores
        normalised by the first and the second structure respectively.

    Raises:
        SelfConsistencyError: if ``tmtools`` is not installed.
    """
    try:
        from tmtools import tm_align
    except ImportError as exc:
        raise SelfConsistencyError(
            "tmtools is required for scTM calculation. Install `tmtools`."
        ) from exc

    alignment = tm_align(pos_1, pos_2, seq_1, seq_2)
    score_by_first = float(alignment.tm_norm_chain1)
    score_by_second = float(alignment.tm_norm_chain2)
    return score_by_first, score_by_second
|
|
|
|
|
def _calc_aligned_rmsd(pos_1, pos_2) -> float:
    """Rigidly align ``pos_1`` to ``pos_2`` and return the mean per-point distance.

    NOTE(review): despite the name, this is the *mean* Euclidean distance
    between corresponding points after alignment, not the square root of the
    mean squared distance. Switching to a true RMSD would change every
    reported scRMSD value.

    Raises:
        SelfConsistencyError: if the project's PDB utilities cannot be imported.
    """
    try:
        from utils import new_pdbUtils as du
    except ModuleNotFoundError as exc:
        raise SelfConsistencyError(
            "Missing dependency for structure parsing/alignment. "
            "Install `biopython` (module `Bio`) and re-run self-consistency."
        ) from exc

    # Presumably element 0 is pos_1 superimposed onto pos_2 -- confirm in new_pdbUtils.
    superimposed = du.rigid_transform_3D(pos_1, pos_2)[0]
    per_point_distance = np.linalg.norm(superimposed - pos_2, axis=-1)
    return float(np.mean(per_point_distance))
|
|
|
|
|
class FlowProtSelfConsistencyService:
    """Runs self-consistency metrics for generated structures.

    Pipeline for one sample PDB (see ``run``):

    1. Parse the sample with ProteinMPNN's ``parse_multiple_chains.py`` helper.
    2. Design ``num_seq_per_target`` sequences with ``protein_mpnn_run.py``.
    3. Optionally fold each designed sequence with ESMFold and compare the
       prediction with the sample (scTM, scRMSD, mean pLDDT read back from the
       B-factor column).

    Configuration is read from environment variables at construction time:

    * ``FLOWPROT_PMPNN_DIR`` (default ``model/ProteinMPNN``; relative paths are
      anchored at ``REPO_ROOT``)
    * ``FLOWPROT_PMPNN_WEIGHTS_DIR`` (default ``<pmpnn_dir>/vanilla_model_weights``)
    * ``FLOWPROT_PMPNN_MODEL_NAME`` (default ``v_48_020``)
    * ``FLOWPROT_ESMFOLD_MODEL_ID`` (default ``facebook/esmfold_v1``)
    * ``FLOWPROT_SC_DEVICE`` (default ``auto``: ``cuda:0`` if CUDA is available,
      else ``cpu``)
    * ``FLOWPROT_SC_SEED`` (default ``123``; passed to ProteinMPNN as ``--seed``)

    The ESMFold tokenizer and model are loaded lazily on first use and cached
    on the instance.
    """

    # Column order of the per-sequence metrics CSV (and keys of each metrics row).
    _METRIC_FIELDS: Tuple[str, ...] = (
        "index",
        "header",
        "sequence",
        "scTM",
        "scRMSD",
        "esmfold_mean_plddt",
        "folded_sample_path",
    )

    def __init__(self) -> None:
        self._pmpnn_dir = _as_path(os.getenv("FLOWPROT_PMPNN_DIR", "model/ProteinMPNN"))
        self._pmpnn_weights_dir = _as_path(
            os.getenv(
                "FLOWPROT_PMPNN_WEIGHTS_DIR",
                str(self._pmpnn_dir / "vanilla_model_weights"),
            )
        )
        self._pmpnn_model_name = os.getenv("FLOWPROT_PMPNN_MODEL_NAME", "v_48_020")
        self._esmfold_model_id = os.getenv("FLOWPROT_ESMFOLD_MODEL_ID", "facebook/esmfold_v1")
        self._requested_device = os.getenv("FLOWPROT_SC_DEVICE", "auto").strip().lower()
        self._seed = int(os.getenv("FLOWPROT_SC_SEED", "123"))
        # ESMFold state, populated lazily by _ensure_folding_model().
        self._folding_model = None
        self._tokenizer = None
        self._folding_device: Optional[str] = None

    def health_check(self) -> Dict[str, object]:
        """Report configuration and model-loading state without loading anything."""
        return {
            "pmpnn_dir": str(self._pmpnn_dir),
            "pmpnn_available": self._pmpnn_dir.exists(),
            "pmpnn_weights_dir": str(self._pmpnn_weights_dir),
            "pmpnn_weights_available": self._pmpnn_weights_dir.exists(),
            "pmpnn_model_name": self._pmpnn_model_name,
            "esmfold_model_id": self._esmfold_model_id,
            "sc_device": self._requested_device,
            "folding_model_loaded": self._folding_model is not None,
            "folding_device": self._folding_device,
        }

    def run(
        self,
        sample_path: str,
        num_seq_per_target: int = 4,
        run_folding: bool = True,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> SelfConsistencyResult:
        """Run ProteinMPNN (and optionally ESMFold) self-consistency for one PDB.

        Args:
            sample_path: Path to the generated structure. ``~`` is expanded.
            num_seq_per_target: Number of ProteinMPNN sequences to design, in [1, 64].
            run_folding: When False, stop after sequence design: nothing is
                folded or scored and no metrics CSV is written.
            progress_callback: Optional ``callback(done, total)`` invoked before
                each ESMFold prediction and once more after the last one.

        Returns:
            A ``SelfConsistencyResult`` describing the produced artifacts and metrics.

        Raises:
            SelfConsistencyError: on invalid arguments, a missing ProteinMPNN
                installation, a failing subprocess, missing or empty outputs,
                or missing optional dependencies.
        """
        sample_file = Path(sample_path).expanduser().resolve()
        if not sample_file.is_file():
            raise SelfConsistencyError(f"Sample PDB does not exist: {sample_file}")
        if not 1 <= num_seq_per_target <= 64:
            raise SelfConsistencyError("num_seq_per_target must be in [1, 64].")

        parse_script, pmpnn_script = self._require_pmpnn_installation()

        # Each run gets a directory of its own: the parse script below receives the
        # whole directory as its input path, so sharing it between runs could mix
        # samples (presumably it parses every PDB found there).
        run_dir = self._create_run_dir(sample_file.parent / "self_consistency")
        (run_dir / "seqs").mkdir(parents=True, exist_ok=True)
        copied_sample = run_dir / sample_file.name
        shutil.copy2(sample_file, copied_sample)

        parsed_pdbs_jsonl = run_dir / "parsed_pdbs.jsonl"
        self._run_command(
            [
                sys.executable,
                str(parse_script),
                f"--input_path={run_dir.as_posix()}",
                f"--output_path={parsed_pdbs_jsonl.as_posix()}",
            ]
        )
        self._sanitize_parsed_jsonl(parsed_pdbs_jsonl)
        self._run_proteinmpnn(pmpnn_script, run_dir, parsed_pdbs_jsonl, num_seq_per_target)

        mpnn_fasta_path = run_dir / "seqs" / self._mpnn_fasta_name(sample_file)
        sequences = _read_fasta(mpnn_fasta_path)

        per_sequence_metrics: List[Dict[str, object]] = []
        metrics_csv_path: Optional[str] = None
        if run_folding:
            per_sequence_metrics = self._fold_and_score(
                copied_sample, sequences, run_dir / "esmf", progress_callback
            )
            metrics_csv_path = str(
                self._write_metrics_csv(run_dir / "sc_results.csv", per_sequence_metrics)
            )
        folded_pdb_paths = [str(row["folded_sample_path"]) for row in per_sequence_metrics]

        return SelfConsistencyResult(
            sample_path=str(sample_file),
            run_dir=str(run_dir),
            parsed_pdbs_jsonl=str(parsed_pdbs_jsonl),
            mpnn_fasta_path=str(mpnn_fasta_path),
            run_folding=bool(run_folding),
            metrics_csv_path=metrics_csv_path,
            folded_pdb_paths=folded_pdb_paths,
            per_sequence_metrics=per_sequence_metrics,
            summary=self._build_summary(
                sample_file, run_dir, len(sequences), per_sequence_metrics, run_folding
            ),
        )

    def _require_pmpnn_installation(self) -> Tuple[Path, Path]:
        """Return ``(parse_script, pmpnn_script)``, or raise if ProteinMPNN is unusable."""
        parse_script = self._pmpnn_dir / "helper_scripts" / "parse_multiple_chains.py"
        pmpnn_script = self._pmpnn_dir / "protein_mpnn_run.py"
        if not parse_script.exists() or not pmpnn_script.exists():
            raise SelfConsistencyError(
                "ProteinMPNN scripts not found. Set FLOWPROT_PMPNN_DIR to a valid directory."
            )
        if not self._pmpnn_weights_dir.exists():
            raise SelfConsistencyError(
                "ProteinMPNN weights directory not found. "
                "Set FLOWPROT_PMPNN_WEIGHTS_DIR to a valid model weights directory."
            )
        return parse_script, pmpnn_script

    @staticmethod
    def _create_run_dir(base_dir: Path) -> Path:
        """Create and return a new, uniquely named run directory under *base_dir*.

        The name is the UTC timestamp ``YYYYmmdd_HHMMSS``. If that directory
        already exists (another run started in the same second), ``_1``, ``_2``,
        ... is appended until an unused name is found.
        """
        base_dir.mkdir(parents=True, exist_ok=True)
        stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        candidate = base_dir / stamp
        attempt = 0
        while True:
            try:
                candidate.mkdir()
                return candidate
            except FileExistsError:
                attempt += 1
                candidate = base_dir / f"{stamp}_{attempt}"

    @staticmethod
    def _mpnn_fasta_name(sample_file: Path) -> str:
        """Return the FASTA file name expected in ``seqs/`` for *sample_file*.

        Only a trailing ``.pdb`` is replaced by ``.fa``. ``str.replace`` would
        also rewrite any ``.pdb`` appearing earlier in the file name.
        """
        name = sample_file.name
        if name.endswith(".pdb"):
            name = name[: -len(".pdb")]
        return f"{name}.fa"

    def _run_proteinmpnn(
        self,
        pmpnn_script: Path,
        run_dir: Path,
        parsed_pdbs_jsonl: Path,
        num_seq_per_target: int,
    ) -> None:
        """Design sequences with ProteinMPNN; the FASTA output is read from ``run_dir/seqs``."""
        pmpnn_args = [
            sys.executable,
            str(pmpnn_script),
            "--out_folder",
            run_dir.as_posix(),
            "--jsonl_path",
            parsed_pdbs_jsonl.as_posix(),
            "--num_seq_per_target",
            str(num_seq_per_target),
            "--sampling_temp",
            "0.1",
            "--seed",
            str(self._seed),
            "--batch_size",
            "1",
            "--path_to_model_weights",
            self._pmpnn_weights_dir.as_posix(),
            "--model_name",
            self._pmpnn_model_name,
        ]
        device = self._resolve_device()
        if device.startswith("cuda"):
            # ProteinMPNN is given a bare GPU index; plain "cuda" means GPU 0.
            gpu_id = device.split(":")[1] if ":" in device else "0"
            pmpnn_args.extend(["--device", gpu_id])
        self._run_command(pmpnn_args)

    def _fold_and_score(
        self,
        reference_pdb: Path,
        sequences: List[Tuple[str, str]],
        esmf_dir: Path,
        progress_callback: Optional[Callable[[int, int], None]],
    ) -> List[Dict[str, object]]:
        """Fold designed sequences with ESMFold and score them against *reference_pdb*.

        Returns one metrics row per folded sequence, keyed by ``_METRIC_FIELDS``.
        When ``SKIP_REFERENCE_SEQUENCE`` is set, the first FASTA record is not folded.
        """
        try:
            from utils import new_pdbUtils as du
        except ModuleNotFoundError as exc:
            raise SelfConsistencyError(
                "Missing dependency for ESMFold metric computation: module `Bio` not found. "
                "Install requirements in your runtime environment (e.g. `pip install -r requirements.txt`)."
            ) from exc

        esmf_dir.mkdir(parents=True, exist_ok=True)
        reference_feats = du.parse_pdb_feats("sample", str(reference_pdb))
        reference_bb = reference_feats["bb_positions"]
        reference_seq = du.aatype_to_seq(reference_feats["aatype"])

        first_index = 1 if SKIP_REFERENCE_SEQUENCE else 0
        targets = list(enumerate(sequences))[first_index:]
        fold_total = len(targets)

        rows: List[Dict[str, object]] = []
        for fold_done, (idx, (header, sequence)) in enumerate(targets):
            if progress_callback is not None:
                progress_callback(fold_done, fold_total)
            folded_path = esmf_dir / f"sample_{idx}.pdb"
            self._run_folding(sequence=sequence, save_path=folded_path)
            folded_bb = du.parse_pdb_feats("folded_sample", str(folded_path))["bb_positions"]
            # The reference sequence is passed for both structures, and the score
            # normalised by the second (folded) structure is the one reported.
            _, tm_score = _calc_tm_score(reference_bb, folded_bb, reference_seq, reference_seq)
            rmsd = _calc_aligned_rmsd(reference_bb, folded_bb)
            rows.append(
                {
                    "index": idx,
                    "header": header,
                    "sequence": sequence,
                    "scTM": float(tm_score),
                    "scRMSD": float(rmsd),
                    "esmfold_mean_plddt": _mean_bfactor_from_pdb(folded_path),
                    "folded_sample_path": str(folded_path),
                }
            )
        if progress_callback is not None:
            progress_callback(len(rows), fold_total)
        return rows

    @classmethod
    def _write_metrics_csv(cls, csv_path: Path, rows: List[Dict[str, object]]) -> Path:
        """Write *rows* (with a header line) to *csv_path* and return the path."""
        with csv_path.open("w", encoding="utf-8", newline="") as handle:
            writer = csv.DictWriter(handle, fieldnames=list(cls._METRIC_FIELDS))
            writer.writeheader()
            writer.writerows(rows)
        return csv_path

    @staticmethod
    def _build_summary(
        sample_file: Path,
        run_dir: Path,
        num_sequences: int,
        per_sequence_metrics: List[Dict[str, object]],
        run_folding: bool,
    ) -> Dict[str, object]:
        """Aggregate per-sequence metrics into the run summary dictionary."""
        skipped_sequences = 1 if (SKIP_REFERENCE_SEQUENCE and num_sequences > 0) else 0
        target_sequences = max(num_sequences - skipped_sequences, 0)
        summary: Dict[str, object] = {
            "sample_path": str(sample_file),
            "run_dir": str(run_dir),
            "num_sequences_total": num_sequences,
            "num_sequences_targeted": target_sequences,
            # Without folding nothing is scored, so the designed count is reported.
            "num_sequences_evaluated": len(per_sequence_metrics) if run_folding else target_sequences,
            "skipped_reference_sequence": bool(SKIP_REFERENCE_SEQUENCE),
            "run_folding": bool(run_folding),
        }
        if not per_sequence_metrics:
            return summary

        tm_values = [float(row["scTM"]) for row in per_sequence_metrics]
        rmsd_values = [float(row["scRMSD"]) for row in per_sequence_metrics]
        plddt_values = [
            float(row["esmfold_mean_plddt"])
            for row in per_sequence_metrics
            if row.get("esmfold_mean_plddt") is not None
        ]
        summary.update(
            {
                "mean_scTM": _safe_mean(tm_values),
                "max_scTM": max(tm_values),
                "min_scTM": min(tm_values),
                "mean_scRMSD": _safe_mean(rmsd_values),
                "min_scRMSD": min(rmsd_values),
                "max_scRMSD": max(rmsd_values),
                "mean_esmfold_plddt": _safe_mean(plddt_values),
            }
        )
        return summary

    def _run_command(self, args: List[str]) -> None:
        """Run *args* as a subprocess; raise with its combined stdout/stderr on failure."""
        command_line = " ".join(args)
        LOGGER.info("Running command: %s", command_line)
        completed = subprocess.run(
            args,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            check=False,
        )
        if completed.returncode != 0:
            LOGGER.error("Command failed (%s): %s", completed.returncode, command_line)
            raise SelfConsistencyError(
                f"Command failed ({completed.returncode}): {command_line}\n{completed.stdout}"
            )

    def _sanitize_parsed_jsonl(self, jsonl_path: Path) -> None:
        """Rewrite each record's ``name`` to a bare file name, in place.

        Both ``/`` and ``\\`` separators are stripped, so a Windows-style path
        left by the parse script is reduced to its last component. Blank lines
        are dropped.

        NOTE(review): the ``seqs/<name>.fa`` lookup in ``run`` presumably relies
        on ProteinMPNN deriving its output name from this field.

        Raises:
            SelfConsistencyError: if the file is missing, malformed, or empty.
        """
        if not jsonl_path.exists():
            raise SelfConsistencyError(f"Parsed PDB JSONL not found: {jsonl_path}")

        sanitized_lines: List[str] = []
        for raw_line in jsonl_path.read_text(encoding="utf-8", errors="replace").splitlines():
            if not raw_line.strip():
                continue
            try:
                payload = json.loads(raw_line)
            except json.JSONDecodeError as exc:
                raise SelfConsistencyError(
                    f"Malformed record in parsed PDB JSONL {jsonl_path}: {exc}"
                ) from exc
            raw_name = str(payload.get("name", "sample"))
            payload["name"] = Path(raw_name.replace("\\", "/")).name
            sanitized_lines.append(json.dumps(payload))

        if not sanitized_lines:
            raise SelfConsistencyError(f"No parsed structures found in {jsonl_path}")
        jsonl_path.write_text("\n".join(sanitized_lines) + "\n", encoding="utf-8")

    def _resolve_device(self) -> str:
        """Map ``FLOWPROT_SC_DEVICE`` to a torch device string.

        ``auto`` (or empty) picks ``cuda:0`` when CUDA is available, else ``cpu``.
        Any ``cuda*`` request raises if CUDA is unavailable. Other values are
        returned unchanged.
        """
        import torch

        if self._requested_device in {"", "auto"}:
            return "cuda:0" if torch.cuda.is_available() else "cpu"
        if self._requested_device.startswith("cuda") and not torch.cuda.is_available():
            raise SelfConsistencyError(
                f"FLOWPROT_SC_DEVICE={self._requested_device} requested, but CUDA is unavailable."
            )
        return self._requested_device

    def _ensure_folding_model(self) -> None:
        """Load and cache the ESMFold tokenizer and model (no-op once loaded).

        On CUDA, the model's ``esm`` submodule is cast to half precision before
        the whole model is moved to the device.
        """
        if self._folding_model is not None and self._tokenizer is not None:
            return

        try:
            from transformers import AutoTokenizer, EsmForProteinFolding
        except ImportError as exc:
            raise SelfConsistencyError(
                "transformers is required for ESMFold. Install `transformers`."
            ) from exc

        device = self._resolve_device()
        LOGGER.info("Loading ESMFold model '%s' on %s", self._esmfold_model_id, device)
        tokenizer = AutoTokenizer.from_pretrained(self._esmfold_model_id)
        folding_model = EsmForProteinFolding.from_pretrained(
            self._esmfold_model_id,
            low_cpu_mem_usage=True,
        )
        if device.startswith("cuda"):
            folding_model.esm = folding_model.esm.half()
        folding_model = folding_model.to(device)
        folding_model.eval()

        self._tokenizer = tokenizer
        self._folding_model = folding_model
        self._folding_device = device

    def _run_folding(self, sequence: str, save_path: Path) -> None:
        """Predict a structure for *sequence* with ESMFold and write it as PDB.

        The model's pLDDT output is written into the B-factor column of
        *save_path*; parent directories are created as needed.
        """
        self._ensure_folding_model()

        import torch
        from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
        from transformers.models.esm.openfold_utils.protein import Protein as OFProtein
        from transformers.models.esm.openfold_utils.protein import to_pdb

        assert self._tokenizer is not None
        assert self._folding_model is not None
        assert self._folding_device is not None

        tokenized_input = self._tokenizer(
            [sequence],
            return_tensors="pt",
            add_special_tokens=False,
        )["input_ids"].to(self._folding_device)

        with torch.no_grad():
            output = self._folding_model(tokenized_input)

        # Convert the final structure-module positions to the atom37 layout used by to_pdb.
        final_atom_positions = atom14_to_atom37(output["positions"][-1], output)
        output_np = {k: v.to("cpu").numpy() for k, v in output.items()}
        final_atom_positions_np = final_atom_positions.cpu().numpy()
        final_atom_mask = output_np["atom37_atom_exists"]

        pdb_chunks: List[str] = []
        for i in range(output_np["aatype"].shape[0]):
            pred = OFProtein(
                aatype=output_np["aatype"][i],
                atom_positions=final_atom_positions_np[i],
                atom_mask=final_atom_mask[i],
                residue_index=output_np["residue_index"][i] + 1,
                b_factors=output_np["plddt"][i],
                chain_index=output_np["chain_index"][i] if "chain_index" in output_np else None,
            )
            pdb_chunks.append(to_pdb(pred))

        save_path.parent.mkdir(parents=True, exist_ok=True)
        save_path.write_text("".join(pdb_chunks), encoding="utf-8")
|
|
|