"""Self-consistency evaluation using ProteinMPNN and ESMFold.""" from __future__ import annotations import csv import json import logging import os import shutil import subprocess import sys from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import Dict, List, Optional, Tuple import numpy as np from model_loader import REPO_ROOT LOGGER = logging.getLogger(__name__) SKIP_REFERENCE_SEQUENCE = True class SelfConsistencyError(RuntimeError): """Raised when self-consistency evaluation fails.""" @dataclass class SelfConsistencyResult: sample_path: str run_dir: str parsed_pdbs_jsonl: str mpnn_fasta_path: str run_folding: bool metrics_csv_path: Optional[str] folded_pdb_paths: List[str] per_sequence_metrics: List[Dict[str, object]] summary: Dict[str, object] def _as_path(path_value: str) -> Path: raw = Path(path_value).expanduser() return raw if raw.is_absolute() else (REPO_ROOT / raw).resolve() def _safe_mean(values: List[float]) -> Optional[float]: return float(np.mean(values)) if values else None def _mean_bfactor_from_pdb(pdb_path: Path) -> Optional[float]: values: List[float] = [] for line in pdb_path.read_text(encoding="utf-8", errors="replace").splitlines(): if not (line.startswith("ATOM") or line.startswith("HETATM")): continue field = line[60:66].strip() if not field: continue try: values.append(float(field)) except ValueError: continue return _safe_mean(values) def _read_fasta(path: Path) -> List[Tuple[str, str]]: if not path.exists(): raise SelfConsistencyError(f"ProteinMPNN FASTA output not found: {path}") entries: List[Tuple[str, str]] = [] header: Optional[str] = None seq_lines: List[str] = [] for raw_line in path.read_text(encoding="utf-8", errors="replace").splitlines(): line = raw_line.strip() if not line: continue if line.startswith(">"): if header is not None: entries.append((header, "".join(seq_lines))) header = line[1:].strip() seq_lines = [] continue seq_lines.append(line) if header is not None: entries.append((header, "".join(seq_lines))) if not entries: raise SelfConsistencyError(f"ProteinMPNN FASTA output had no sequences: {path}") return entries def _calc_tm_score(pos_1, pos_2, seq_1: str, seq_2: str) -> Tuple[float, float]: try: from tmtools import tm_align except ImportError as exc: # pragma: no cover raise SelfConsistencyError( "tmtools is required for scTM calculation. Install `tmtools`." ) from exc tm_results = tm_align(pos_1, pos_2, seq_1, seq_2) return float(tm_results.tm_norm_chain1), float(tm_results.tm_norm_chain2) def _calc_aligned_rmsd(pos_1, pos_2) -> float: try: from utils import new_pdbUtils as du except ModuleNotFoundError as exc: # pragma: no cover raise SelfConsistencyError( "Missing dependency for structure parsing/alignment. " "Install `biopython` (module `Bio`) and re-run self-consistency." ) from exc aligned = du.rigid_transform_3D(pos_1, pos_2)[0] return float(np.mean(np.linalg.norm(aligned - pos_2, axis=-1))) class FlowProtSelfConsistencyService: """Runs self-consistency metrics for generated structures.""" def __init__(self) -> None: self._pmpnn_dir = _as_path(os.getenv("FLOWPROT_PMPNN_DIR", "model/ProteinMPNN")) self._pmpnn_weights_dir = _as_path( os.getenv( "FLOWPROT_PMPNN_WEIGHTS_DIR", str(self._pmpnn_dir / "vanilla_model_weights"), ) ) self._pmpnn_model_name = os.getenv("FLOWPROT_PMPNN_MODEL_NAME", "v_48_020") self._esmfold_model_id = os.getenv("FLOWPROT_ESMFOLD_MODEL_ID", "facebook/esmfold_v1") self._requested_device = os.getenv("FLOWPROT_SC_DEVICE", "auto").strip().lower() self._seed = int(os.getenv("FLOWPROT_SC_SEED", "123")) self._folding_model = None self._tokenizer = None self._folding_device = None def health_check(self) -> Dict[str, object]: return { "pmpnn_dir": str(self._pmpnn_dir), "pmpnn_available": self._pmpnn_dir.exists(), "pmpnn_weights_dir": str(self._pmpnn_weights_dir), "pmpnn_weights_available": self._pmpnn_weights_dir.exists(), "pmpnn_model_name": self._pmpnn_model_name, "esmfold_model_id": self._esmfold_model_id, "sc_device": self._requested_device, "folding_model_loaded": self._folding_model is not None, "folding_device": self._folding_device, } def run( self, sample_path: str, num_seq_per_target: int = 4, run_folding: bool = True, progress_callback=None, ) -> SelfConsistencyResult: sample_file = Path(sample_path).expanduser().resolve() if not sample_file.exists() or not sample_file.is_file(): raise SelfConsistencyError(f"Sample PDB does not exist: {sample_file}") if num_seq_per_target < 1 or num_seq_per_target > 64: raise SelfConsistencyError("num_seq_per_target must be in [1, 64].") parse_script = self._pmpnn_dir / "helper_scripts" / "parse_multiple_chains.py" pmpnn_script = self._pmpnn_dir / "protein_mpnn_run.py" if not parse_script.exists() or not pmpnn_script.exists(): raise SelfConsistencyError( "ProteinMPNN scripts not found. Set FLOWPROT_PMPNN_DIR to a valid directory." ) if not self._pmpnn_weights_dir.exists(): raise SelfConsistencyError( "ProteinMPNN weights directory not found. " "Set FLOWPROT_PMPNN_WEIGHTS_DIR to a valid model weights directory." ) run_stamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S") run_dir = sample_file.parent / "self_consistency" / run_stamp run_dir.mkdir(parents=True, exist_ok=True) (run_dir / "seqs").mkdir(parents=True, exist_ok=True) copied_sample = run_dir / sample_file.name shutil.copy2(sample_file, copied_sample) parsed_pdbs_jsonl = run_dir / "parsed_pdbs.jsonl" self._run_command( [ sys.executable, str(parse_script), f"--input_path={run_dir.as_posix()}", f"--output_path={parsed_pdbs_jsonl.as_posix()}", ] ) self._sanitize_parsed_jsonl(parsed_pdbs_jsonl) pmpnn_args = [ sys.executable, str(pmpnn_script), "--out_folder", run_dir.as_posix(), "--jsonl_path", parsed_pdbs_jsonl.as_posix(), "--num_seq_per_target", str(num_seq_per_target), "--sampling_temp", "0.1", "--seed", str(self._seed), "--batch_size", "1", "--path_to_model_weights", self._pmpnn_weights_dir.as_posix(), "--model_name", self._pmpnn_model_name, ] device_for_mpnn = self._resolve_device() if device_for_mpnn.startswith("cuda"): gpu_id = device_for_mpnn.split(":")[1] if ":" in device_for_mpnn else "0" pmpnn_args.extend(["--device", gpu_id]) self._run_command(pmpnn_args) mpnn_fasta_path = run_dir / "seqs" / sample_file.name.replace(".pdb", ".fa") sequences = _read_fasta(mpnn_fasta_path) folded_pdb_paths: List[str] = [] per_sequence_metrics: List[Dict[str, object]] = [] metrics_csv_path: Optional[str] = None if run_folding: try: from utils import new_pdbUtils as du except ModuleNotFoundError as exc: raise SelfConsistencyError( "Missing dependency for ESMFold metric computation: module `Bio` not found. " "Install requirements in your runtime environment (e.g. `pip install -r requirements.txt`)." ) from exc esmf_dir = run_dir / "esmf" esmf_dir.mkdir(parents=True, exist_ok=True) reference_feats = du.parse_pdb_feats("sample", str(copied_sample)) reference_bb = reference_feats["bb_positions"] reference_seq = du.aatype_to_seq(reference_feats["aatype"]) fold_total = max( len(sequences) - (1 if SKIP_REFERENCE_SEQUENCE and sequences else 0), 0 ) fold_done = 0 for idx, (header, sequence) in enumerate(sequences): if SKIP_REFERENCE_SEQUENCE and idx == 0: continue if progress_callback is not None: progress_callback(fold_done, fold_total) folded_path = esmf_dir / f"sample_{idx}.pdb" self._run_folding(sequence=sequence, save_path=folded_path) folded_feats = du.parse_pdb_feats("folded_sample", str(folded_path)) _, tm_score = _calc_tm_score( reference_bb, folded_feats["bb_positions"], reference_seq, reference_seq, ) rmsd = _calc_aligned_rmsd(reference_bb, folded_feats["bb_positions"]) mean_plddt = _mean_bfactor_from_pdb(folded_path) folded_pdb_paths.append(str(folded_path)) per_sequence_metrics.append( { "index": idx, "header": header, "sequence": sequence, "scTM": float(tm_score), "scRMSD": float(rmsd), "esmfold_mean_plddt": mean_plddt, "folded_sample_path": str(folded_path), } ) fold_done += 1 if progress_callback is not None: progress_callback(fold_done, fold_total) metrics_csv_path = str(run_dir / "sc_results.csv") with Path(metrics_csv_path).open("w", encoding="utf-8", newline="") as handle: writer = csv.DictWriter( handle, fieldnames=[ "index", "header", "sequence", "scTM", "scRMSD", "esmfold_mean_plddt", "folded_sample_path", ], ) writer.writeheader() writer.writerows(per_sequence_metrics) skipped_sequences = 1 if (SKIP_REFERENCE_SEQUENCE and len(sequences) > 0) else 0 target_sequences = max(len(sequences) - skipped_sequences, 0) summary: Dict[str, object] = { "sample_path": str(sample_file), "run_dir": str(run_dir), "num_sequences_total": len(sequences), "num_sequences_targeted": target_sequences, "num_sequences_evaluated": len(per_sequence_metrics) if run_folding else target_sequences, "skipped_reference_sequence": bool(SKIP_REFERENCE_SEQUENCE), "run_folding": bool(run_folding), } if per_sequence_metrics: tm_values = [float(item["scTM"]) for item in per_sequence_metrics] rmsd_values = [float(item["scRMSD"]) for item in per_sequence_metrics] plddt_values = [ float(item["esmfold_mean_plddt"]) for item in per_sequence_metrics if item.get("esmfold_mean_plddt") is not None ] summary.update( { "mean_scTM": _safe_mean(tm_values), "max_scTM": max(tm_values), "min_scTM": min(tm_values), "mean_scRMSD": _safe_mean(rmsd_values), "min_scRMSD": min(rmsd_values), "max_scRMSD": max(rmsd_values), "mean_esmfold_plddt": _safe_mean(plddt_values), } ) return SelfConsistencyResult( sample_path=str(sample_file), run_dir=str(run_dir), parsed_pdbs_jsonl=str(parsed_pdbs_jsonl), mpnn_fasta_path=str(mpnn_fasta_path), run_folding=bool(run_folding), metrics_csv_path=metrics_csv_path, folded_pdb_paths=folded_pdb_paths, per_sequence_metrics=per_sequence_metrics, summary=summary, ) def _run_command(self, args: List[str]) -> None: LOGGER.info("Running command: %s", " ".join(args)) process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) stdout, _ = process.communicate() if process.returncode != 0: LOGGER.error("Command failed (%s): %s", process.returncode, " ".join(args)) raise SelfConsistencyError( f"Command failed ({process.returncode}): {' '.join(args)}\n{stdout}" ) def _sanitize_parsed_jsonl(self, jsonl_path: Path) -> None: if not jsonl_path.exists(): raise SelfConsistencyError(f"Parsed PDB JSONL not found: {jsonl_path}") sanitized_lines: List[str] = [] for raw_line in jsonl_path.read_text(encoding="utf-8", errors="replace").splitlines(): if not raw_line.strip(): continue payload = json.loads(raw_line) raw_name = str(payload.get("name", "sample")) payload["name"] = Path(raw_name.replace("\\", "/")).name sanitized_lines.append(json.dumps(payload)) if not sanitized_lines: raise SelfConsistencyError(f"No parsed structures found in {jsonl_path}") jsonl_path.write_text("\n".join(sanitized_lines) + "\n", encoding="utf-8") def _resolve_device(self) -> str: import torch if self._requested_device in {"", "auto"}: return "cuda:0" if torch.cuda.is_available() else "cpu" if self._requested_device.startswith("cuda") and not torch.cuda.is_available(): raise SelfConsistencyError( f"FLOWPROT_SC_DEVICE={self._requested_device} requested, but CUDA is unavailable." ) return self._requested_device def _ensure_folding_model(self) -> None: if self._folding_model is not None and self._tokenizer is not None: return try: from transformers import AutoTokenizer, EsmForProteinFolding except ImportError as exc: # pragma: no cover raise SelfConsistencyError( "transformers is required for ESMFold. Install `transformers`." ) from exc import torch device = self._resolve_device() LOGGER.info("Loading ESMFold model '%s' on %s", self._esmfold_model_id, device) tokenizer = AutoTokenizer.from_pretrained(self._esmfold_model_id) folding_model = EsmForProteinFolding.from_pretrained( self._esmfold_model_id, low_cpu_mem_usage=True, ) if device.startswith("cuda"): folding_model.esm = folding_model.esm.half() folding_model = folding_model.to(device) folding_model.eval() self._tokenizer = tokenizer self._folding_model = folding_model self._folding_device = device def _run_folding(self, sequence: str, save_path: Path) -> None: self._ensure_folding_model() import torch from transformers.models.esm.openfold_utils.feats import atom14_to_atom37 from transformers.models.esm.openfold_utils.protein import Protein as OFProtein from transformers.models.esm.openfold_utils.protein import to_pdb assert self._tokenizer is not None assert self._folding_model is not None assert self._folding_device is not None tokenized_input = self._tokenizer( [sequence], return_tensors="pt", add_special_tokens=False, )["input_ids"].to(self._folding_device) with torch.no_grad(): output = self._folding_model(tokenized_input) final_atom_positions = atom14_to_atom37(output["positions"][-1], output) output_np = {k: v.to("cpu").numpy() for k, v in output.items()} final_atom_positions_np = final_atom_positions.cpu().numpy() final_atom_mask = output_np["atom37_atom_exists"] pdb_chunks: List[str] = [] for i in range(output_np["aatype"].shape[0]): pred = OFProtein( aatype=output_np["aatype"][i], atom_positions=final_atom_positions_np[i], atom_mask=final_atom_mask[i], residue_index=output_np["residue_index"][i] + 1, b_factors=output_np["plddt"][i], chain_index=output_np["chain_index"][i] if "chain_index" in output_np else None, ) pdb_chunks.append(to_pdb(pred)) save_path.parent.mkdir(parents=True, exist_ok=True) save_path.write_text("".join(pdb_chunks), encoding="utf-8")