FlowProt / self_consistency.py
alibtsd's picture
Deploy FlowProt Docker Space
f34af6f verified
Raw
History Blame Contribute Delete
17.9 kB
"""Self-consistency evaluation using ProteinMPNN and ESMFold."""
from __future__ import annotations

import csv
import json
import logging
import os
import shutil
import subprocess
import sys
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np

from model_loader import REPO_ROOT
# Module logger; every subprocess launched by the service is logged here.
LOGGER = logging.getLogger(__name__)
# When True, the first record (index 0) of the ProteinMPNN FASTA is neither
# folded nor scored, and summaries report one fewer targeted sequence.
# NOTE(review): presumably that record is the input structure's own sequence
# rather than a design -- confirm against the ProteinMPNN output format.
SKIP_REFERENCE_SEQUENCE = True
class SelfConsistencyError(RuntimeError):
    """Signal that a step of the self-consistency pipeline could not complete.

    Used for invalid arguments, missing tools/dependencies, failed
    subprocesses and missing intermediate outputs.
    """
@dataclass
class SelfConsistencyResult:
    """Outputs and metrics produced by one self-consistency run on a sample."""

    # Resolved path of the input sample PDB.
    sample_path: str
    # Per-run output directory, created under ``<sample dir>/self_consistency/``.
    run_dir: str
    # JSONL written by ProteinMPNN's ``parse_multiple_chains.py`` helper.
    parsed_pdbs_jsonl: str
    # FASTA file of sequences written by ProteinMPNN.
    mpnn_fasta_path: str
    # Whether ESMFold refolding and metric computation were performed.
    run_folding: bool
    # Path of ``sc_results.csv``; ``None`` when folding was not run.
    metrics_csv_path: Optional[str]
    # One refolded PDB path per evaluated sequence (empty without folding).
    folded_pdb_paths: List[str]
    # One row per evaluated sequence with keys: index, header, sequence,
    # scTM, scRMSD, esmfold_mean_plddt, folded_sample_path.
    per_sequence_metrics: List[Dict[str, object]]
    # Run-level counts, plus mean/min/max scTM and scRMSD and mean ESMFold
    # pLDDT when at least one sequence was evaluated.
    summary: Dict[str, object]
def _as_path(path_value: str) -> Path:
    """Convert a configured path string into a ``Path``.

    ``~`` is expanded first. An absolute result is returned as-is (not
    resolved); a relative one is anchored at ``REPO_ROOT`` and resolved.
    """
    expanded = Path(path_value).expanduser()
    if expanded.is_absolute():
        return expanded
    return (REPO_ROOT / expanded).resolve()
def _safe_mean(values: List[float]) -> Optional[float]:
    """Return the arithmetic mean of ``values`` as a float, or ``None`` if empty."""
    if not values:
        return None
    return float(np.mean(values))
def _mean_bfactor_from_pdb(pdb_path: Path) -> Optional[float]:
    """Average the B-factor field over all ATOM/HETATM records of ``pdb_path``.

    The B-factor is taken from the fixed-width slice ``[60:66]`` of each
    record; blank or non-numeric fields are ignored. Every atom record
    contributes one value. For ESMFold outputs written by this module the
    B-factor column carries the model's pLDDT.

    Returns:
        The mean as a float, or ``None`` when no record has a usable value.
    """
    bfactors: List[float] = []
    text = pdb_path.read_text(encoding="utf-8", errors="replace")
    for record in text.splitlines():
        if not record.startswith(("ATOM", "HETATM")):
            continue
        raw_value = record[60:66].strip()
        if not raw_value:
            continue
        try:
            bfactors.append(float(raw_value))
        except ValueError:
            # Malformed column: skip the record instead of failing the run.
            pass
    return float(np.mean(bfactors)) if bfactors else None
def _read_fasta(path: Path) -> List[Tuple[str, str]]:
    """Parse a FASTA file into ``(header, sequence)`` pairs, in file order.

    Headers lose their leading ``>`` and surrounding whitespace. Sequence
    lines are stripped and concatenated; blank lines are ignored, and any
    lines appearing before the first header are discarded.

    Raises:
        SelfConsistencyError: if ``path`` is missing or contains no records.
    """
    if not path.exists():
        raise SelfConsistencyError(f"ProteinMPNN FASTA output not found: {path}")
    records: List[Tuple[str, List[str]]] = []
    content = path.read_text(encoding="utf-8", errors="replace")
    for line in (raw.strip() for raw in content.splitlines()):
        if not line:
            continue
        if line.startswith(">"):
            records.append((line[1:].strip(), []))
        elif records:
            records[-1][1].append(line)
    entries = [(header, "".join(chunks)) for header, chunks in records]
    if not entries:
        raise SelfConsistencyError(f"ProteinMPNN FASTA output had no sequences: {path}")
    return entries
def _calc_tm_score(pos_1, pos_2, seq_1: str, seq_2: str) -> Tuple[float, float]:
    """TM-align two coordinate sets with ``tmtools``.

    Returns:
        ``(tm_norm_chain1, tm_norm_chain2)``: the TM-score normalised by the
        first and by the second structure, respectively.

    Raises:
        SelfConsistencyError: if ``tmtools`` is not installed.
    """
    try:
        from tmtools import tm_align
    except ImportError as exc:  # pragma: no cover
        raise SelfConsistencyError(
            "tmtools is required for scTM calculation. Install `tmtools`."
        ) from exc
    alignment = tm_align(pos_1, pos_2, seq_1, seq_2)
    score_by_first = float(alignment.tm_norm_chain1)
    score_by_second = float(alignment.tm_norm_chain2)
    return score_by_first, score_by_second
def _calc_aligned_rmsd(pos_1, pos_2) -> float:
    """Rigidly align ``pos_1`` to ``pos_2`` and return the mean deviation.

    NOTE: despite the name, the value is the *mean* per-position Euclidean
    distance after alignment, not the square root of the mean squared
    distance. The behaviour is kept as-is so reported scRMSD values stay
    comparable across runs.

    Raises:
        SelfConsistencyError: if the structure utilities (which need
            biopython) cannot be imported.
    """
    try:
        from utils import new_pdbUtils as du
    except ModuleNotFoundError as exc:  # pragma: no cover
        raise SelfConsistencyError(
            "Missing dependency for structure parsing/alignment. "
            "Install `biopython` (module `Bio`) and re-run self-consistency."
        ) from exc
    # First output of rigid_transform_3D: presumably pos_1 superimposed onto
    # pos_2 (it is compared against pos_2 below) -- confirm in new_pdbUtils.
    superimposed = du.rigid_transform_3D(pos_1, pos_2)[0]
    per_position = np.linalg.norm(superimposed - pos_2, axis=-1)
    return float(np.mean(per_position))
class FlowProtSelfConsistencyService:
    """Runs self-consistency metrics (scTM, scRMSD, ESMFold pLDDT) for generated structures.

    For one sample PDB the pipeline is:

    1. Copy the file into a fresh run directory.
    2. Parse it with ProteinMPNN's ``parse_multiple_chains.py``.
    3. Design sequences with ``protein_mpnn_run.py``.
    4. Optionally, refold each design with ESMFold and compare the fold
       against the sample backbone.

    Configuration is read from environment variables in ``__init__``:

    - ``FLOWPROT_PMPNN_DIR``: ProteinMPNN checkout (default ``model/ProteinMPNN``).
    - ``FLOWPROT_PMPNN_WEIGHTS_DIR``: weights directory
      (default ``<pmpnn dir>/vanilla_model_weights``).
    - ``FLOWPROT_PMPNN_MODEL_NAME``: ProteinMPNN model name (default ``v_48_020``).
    - ``FLOWPROT_ESMFOLD_MODEL_ID``: ESMFold checkpoint (default ``facebook/esmfold_v1``).
    - ``FLOWPROT_SC_DEVICE``: ``auto``, ``cpu`` or ``cuda[:N]`` (default ``auto``).
    - ``FLOWPROT_SC_SEED``: ProteinMPNN sampling seed (default ``123``).

    Relative directories are resolved against the repository root. The ESMFold
    model is loaded lazily on the first fold and then cached on the instance.
    """

    # Column order of the per-sequence metrics CSV (``sc_results.csv``).
    _METRIC_FIELDS = (
        "index",
        "header",
        "sequence",
        "scTM",
        "scRMSD",
        "esmfold_mean_plddt",
        "folded_sample_path",
    )

    def __init__(self) -> None:
        self._pmpnn_dir = _as_path(os.getenv("FLOWPROT_PMPNN_DIR", "model/ProteinMPNN"))
        self._pmpnn_weights_dir = _as_path(
            os.getenv(
                "FLOWPROT_PMPNN_WEIGHTS_DIR",
                str(self._pmpnn_dir / "vanilla_model_weights"),
            )
        )
        self._pmpnn_model_name = os.getenv("FLOWPROT_PMPNN_MODEL_NAME", "v_48_020")
        self._esmfold_model_id = os.getenv("FLOWPROT_ESMFOLD_MODEL_ID", "facebook/esmfold_v1")
        # Normalised so e.g. " AUTO " behaves like "auto".
        self._requested_device = os.getenv("FLOWPROT_SC_DEVICE", "auto").strip().lower()
        self._seed = int(os.getenv("FLOWPROT_SC_SEED", "123"))
        # ESMFold state, populated lazily by _ensure_folding_model().
        self._folding_model = None
        self._tokenizer = None
        self._folding_device = None

    def health_check(self) -> Dict[str, object]:
        """Report the configuration and whether each resource is present/loaded.

        Only reads attributes and checks paths; never loads a model.
        """
        return {
            "pmpnn_dir": str(self._pmpnn_dir),
            "pmpnn_available": self._pmpnn_dir.exists(),
            "pmpnn_weights_dir": str(self._pmpnn_weights_dir),
            "pmpnn_weights_available": self._pmpnn_weights_dir.exists(),
            "pmpnn_model_name": self._pmpnn_model_name,
            "esmfold_model_id": self._esmfold_model_id,
            "sc_device": self._requested_device,
            "folding_model_loaded": self._folding_model is not None,
            "folding_device": self._folding_device,
        }

    def run(
        self,
        sample_path: str,
        num_seq_per_target: int = 4,
        run_folding: bool = True,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> SelfConsistencyResult:
        """Evaluate one generated structure.

        Args:
            sample_path: Path to the sample PDB file (``~`` is expanded).
            num_seq_per_target: Number of ProteinMPNN designs, in ``[1, 64]``.
            run_folding: When False, stop after sequence design; no ESMFold
                folding, no metrics and no CSV are produced.
            progress_callback: Optional ``callback(done, total)`` called before
                each fold and once more after the last one.

        Returns:
            A ``SelfConsistencyResult`` with all output paths, per-sequence
            metrics and a run summary.

        Raises:
            SelfConsistencyError: on invalid arguments, a missing ProteinMPNN
                install/weights or Python dependency, a failing subprocess, or
                missing ProteinMPNN outputs.
        """
        sample_file = Path(sample_path).expanduser().resolve()
        if not sample_file.is_file():
            raise SelfConsistencyError(f"Sample PDB does not exist: {sample_file}")
        if not 1 <= num_seq_per_target <= 64:
            raise SelfConsistencyError("num_seq_per_target must be in [1, 64].")
        parse_script, pmpnn_script = self._pmpnn_scripts()

        run_dir = self._create_run_dir(sample_file)
        # The parser scans the whole run directory, so the sample is copied
        # there to make it the only structure ProteinMPNN sees.
        copied_sample = run_dir / sample_file.name
        shutil.copy2(sample_file, copied_sample)

        parsed_pdbs_jsonl = run_dir / "parsed_pdbs.jsonl"
        self._parse_structures(parse_script, run_dir, parsed_pdbs_jsonl)
        self._design_sequences(pmpnn_script, run_dir, parsed_pdbs_jsonl, num_seq_per_target)

        mpnn_fasta_path = self._mpnn_fasta_path(run_dir, sample_file.name)
        sequences = _read_fasta(mpnn_fasta_path)

        folded_pdb_paths: List[str] = []
        per_sequence_metrics: List[Dict[str, object]] = []
        metrics_csv_path: Optional[str] = None
        if run_folding:
            folded_pdb_paths, per_sequence_metrics = self._fold_and_score(
                sequences, copied_sample, run_dir, progress_callback
            )
            metrics_csv_path = str(
                self._write_metrics_csv(run_dir / "sc_results.csv", per_sequence_metrics)
            )

        summary = self._build_summary(
            sample_file, run_dir, sequences, per_sequence_metrics, run_folding
        )
        return SelfConsistencyResult(
            sample_path=str(sample_file),
            run_dir=str(run_dir),
            parsed_pdbs_jsonl=str(parsed_pdbs_jsonl),
            mpnn_fasta_path=str(mpnn_fasta_path),
            run_folding=bool(run_folding),
            metrics_csv_path=metrics_csv_path,
            folded_pdb_paths=folded_pdb_paths,
            per_sequence_metrics=per_sequence_metrics,
            summary=summary,
        )

    def _pmpnn_scripts(self) -> Tuple[Path, Path]:
        """Return ``(parse_script, design_script)`` after validating the install."""
        parse_script = self._pmpnn_dir / "helper_scripts" / "parse_multiple_chains.py"
        pmpnn_script = self._pmpnn_dir / "protein_mpnn_run.py"
        if not parse_script.exists() or not pmpnn_script.exists():
            raise SelfConsistencyError(
                "ProteinMPNN scripts not found. Set FLOWPROT_PMPNN_DIR to a valid directory."
            )
        if not self._pmpnn_weights_dir.exists():
            raise SelfConsistencyError(
                "ProteinMPNN weights directory not found. "
                "Set FLOWPROT_PMPNN_WEIGHTS_DIR to a valid model weights directory."
            )
        return parse_script, pmpnn_script

    @staticmethod
    def _create_run_dir(sample_file: Path) -> Path:
        """Create ``<sample dir>/self_consistency/<UTC stamp>/seqs`` and return the run dir."""
        # Microsecond resolution: runs of the same sample started within one
        # second must not share (and overwrite) a single output directory.
        run_stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S_%f")
        run_dir = sample_file.parent / "self_consistency" / run_stamp
        run_dir.mkdir(parents=True, exist_ok=True)
        (run_dir / "seqs").mkdir(parents=True, exist_ok=True)
        return run_dir

    def _parse_structures(self, parse_script: Path, run_dir: Path, parsed_pdbs_jsonl: Path) -> None:
        """Run ProteinMPNN's chain parser over ``run_dir`` and sanitise its JSONL."""
        self._run_command(
            [
                sys.executable,
                str(parse_script),
                f"--input_path={run_dir.as_posix()}",
                f"--output_path={parsed_pdbs_jsonl.as_posix()}",
            ]
        )
        self._sanitize_parsed_jsonl(parsed_pdbs_jsonl)

    def _design_sequences(
        self,
        pmpnn_script: Path,
        run_dir: Path,
        parsed_pdbs_jsonl: Path,
        num_seq_per_target: int,
    ) -> None:
        """Run ProteinMPNN sequence design; FASTA output goes to ``run_dir/seqs``."""
        pmpnn_args = [
            sys.executable,
            str(pmpnn_script),
            "--out_folder",
            run_dir.as_posix(),
            "--jsonl_path",
            parsed_pdbs_jsonl.as_posix(),
            "--num_seq_per_target",
            str(num_seq_per_target),
            "--sampling_temp",
            "0.1",
            "--seed",
            str(self._seed),
            "--batch_size",
            "1",
            "--path_to_model_weights",
            self._pmpnn_weights_dir.as_posix(),
            "--model_name",
            self._pmpnn_model_name,
        ]
        device_for_mpnn = self._resolve_device()
        if device_for_mpnn.startswith("cuda"):
            # Only the GPU index is forwarded ("cuda" alone means GPU 0).
            # NOTE(review): requires a protein_mpnn_run.py that accepts
            # --device; confirm the vendored copy supports it.
            gpu_id = device_for_mpnn.split(":")[1] if ":" in device_for_mpnn else "0"
            pmpnn_args.extend(["--device", gpu_id])
        self._run_command(pmpnn_args)

    @staticmethod
    def _mpnn_fasta_path(run_dir: Path, sample_name: str) -> Path:
        """Return the FASTA ProteinMPNN writes for the sample file ``sample_name``.

        The file is named after the structure stem, so only the final suffix
        is replaced: ``a.pdb.b.pdb`` -> ``seqs/a.pdb.b.fa``.
        """
        return run_dir / "seqs" / f"{Path(sample_name).stem}.fa"

    @staticmethod
    def _num_skipped(sequences: List[Tuple[str, str]]) -> int:
        """Number of leading FASTA records excluded from folding (0 or 1)."""
        return 1 if (SKIP_REFERENCE_SEQUENCE and len(sequences) > 0) else 0

    def _fold_and_score(
        self,
        sequences: List[Tuple[str, str]],
        reference_pdb: Path,
        run_dir: Path,
        progress_callback: Optional[Callable[[int, int], None]],
    ) -> Tuple[List[str], List[Dict[str, object]]]:
        """Refold each designed sequence with ESMFold and score it against the sample.

        Returns:
            ``(folded_pdb_paths, per_sequence_metrics)``, one entry per folded
            sequence, in FASTA order.
        """
        try:
            from utils import new_pdbUtils as du
        except ModuleNotFoundError as exc:
            raise SelfConsistencyError(
                "Missing dependency for ESMFold metric computation: module `Bio` not found. "
                "Install requirements in your runtime environment (e.g. `pip install -r requirements.txt`)."
            ) from exc
        esmf_dir = run_dir / "esmf"
        esmf_dir.mkdir(parents=True, exist_ok=True)
        reference_feats = du.parse_pdb_feats("sample", str(reference_pdb))
        reference_bb = reference_feats["bb_positions"]
        reference_seq = du.aatype_to_seq(reference_feats["aatype"])

        fold_total = max(len(sequences) - self._num_skipped(sequences), 0)
        fold_done = 0
        folded_pdb_paths: List[str] = []
        rows: List[Dict[str, object]] = []
        for idx, (header, sequence) in enumerate(sequences):
            if SKIP_REFERENCE_SEQUENCE and idx == 0:
                continue
            if progress_callback is not None:
                progress_callback(fold_done, fold_total)
            folded_path = esmf_dir / f"sample_{idx}.pdb"
            self._run_folding(sequence=sequence, save_path=folded_path)
            folded_bb = du.parse_pdb_feats("folded_sample", str(folded_path))["bb_positions"]
            # The reference sequence is passed for both chains; the second
            # score (normalised by the folded structure) is reported as scTM.
            _, tm_score = _calc_tm_score(reference_bb, folded_bb, reference_seq, reference_seq)
            rmsd = _calc_aligned_rmsd(reference_bb, folded_bb)
            folded_pdb_paths.append(str(folded_path))
            rows.append(
                {
                    "index": idx,
                    "header": header,
                    "sequence": sequence,
                    "scTM": float(tm_score),
                    "scRMSD": float(rmsd),
                    "esmfold_mean_plddt": _mean_bfactor_from_pdb(folded_path),
                    "folded_sample_path": str(folded_path),
                }
            )
            fold_done += 1
        if progress_callback is not None:
            progress_callback(fold_done, fold_total)
        return folded_pdb_paths, rows

    @classmethod
    def _write_metrics_csv(cls, csv_path: Path, rows: List[Dict[str, object]]) -> Path:
        """Write ``rows`` (header always included) to ``csv_path`` and return it."""
        with csv_path.open("w", encoding="utf-8", newline="") as handle:
            writer = csv.DictWriter(handle, fieldnames=list(cls._METRIC_FIELDS))
            writer.writeheader()
            writer.writerows(rows)
        return csv_path

    @classmethod
    def _build_summary(
        cls,
        sample_file: Path,
        run_dir: Path,
        sequences: List[Tuple[str, str]],
        per_sequence_metrics: List[Dict[str, object]],
        run_folding: bool,
    ) -> Dict[str, object]:
        """Aggregate run-level counts and, when available, metric statistics."""
        target_sequences = max(len(sequences) - cls._num_skipped(sequences), 0)
        summary: Dict[str, object] = {
            "sample_path": str(sample_file),
            "run_dir": str(run_dir),
            "num_sequences_total": len(sequences),
            "num_sequences_targeted": target_sequences,
            "num_sequences_evaluated": len(per_sequence_metrics) if run_folding else target_sequences,
            "skipped_reference_sequence": bool(SKIP_REFERENCE_SEQUENCE),
            "run_folding": bool(run_folding),
        }
        if not per_sequence_metrics:
            return summary
        tm_values = [float(item["scTM"]) for item in per_sequence_metrics]
        rmsd_values = [float(item["scRMSD"]) for item in per_sequence_metrics]
        plddt_values = [
            float(item["esmfold_mean_plddt"])
            for item in per_sequence_metrics
            if item.get("esmfold_mean_plddt") is not None
        ]
        summary.update(
            {
                "mean_scTM": _safe_mean(tm_values),
                "max_scTM": max(tm_values),
                "min_scTM": min(tm_values),
                "mean_scRMSD": _safe_mean(rmsd_values),
                "min_scRMSD": min(rmsd_values),
                "max_scRMSD": max(rmsd_values),
                "mean_esmfold_plddt": _safe_mean(plddt_values),
            }
        )
        return summary

    def _run_command(self, args: List[str]) -> None:
        """Run ``args`` to completion; on a non-zero exit raise with the combined output."""
        LOGGER.info("Running command: %s", " ".join(args))
        completed = subprocess.run(
            args,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            check=False,
        )
        if completed.returncode != 0:
            LOGGER.error("Command failed (%s): %s", completed.returncode, " ".join(args))
            raise SelfConsistencyError(
                f"Command failed ({completed.returncode}): {' '.join(args)}\n{completed.stdout}"
            )

    def _sanitize_parsed_jsonl(self, jsonl_path: Path) -> None:
        """Rewrite each record's ``name`` to a bare file name, in place.

        Backslashes are normalised first, so Windows-style paths are also
        reduced to their last component. Blank lines are dropped.

        Raises:
            SelfConsistencyError: if the file is missing or has no records.
        """
        if not jsonl_path.exists():
            raise SelfConsistencyError(f"Parsed PDB JSONL not found: {jsonl_path}")
        sanitized_lines: List[str] = []
        for raw_line in jsonl_path.read_text(encoding="utf-8", errors="replace").splitlines():
            if not raw_line.strip():
                continue
            payload = json.loads(raw_line)
            raw_name = str(payload.get("name", "sample"))
            payload["name"] = Path(raw_name.replace("\\", "/")).name
            sanitized_lines.append(json.dumps(payload))
        if not sanitized_lines:
            raise SelfConsistencyError(f"No parsed structures found in {jsonl_path}")
        jsonl_path.write_text("\n".join(sanitized_lines) + "\n", encoding="utf-8")

    def _resolve_device(self) -> str:
        """Map ``FLOWPROT_SC_DEVICE`` to a concrete torch device string.

        ``auto``/empty picks ``cuda:0`` when CUDA is available, else ``cpu``.
        Other values are returned unchanged after checking CUDA availability
        for ``cuda*`` requests.
        """
        import torch

        if self._requested_device in {"", "auto"}:
            return "cuda:0" if torch.cuda.is_available() else "cpu"
        if self._requested_device.startswith("cuda") and not torch.cuda.is_available():
            raise SelfConsistencyError(
                f"FLOWPROT_SC_DEVICE={self._requested_device} requested, but CUDA is unavailable."
            )
        return self._requested_device

    def _ensure_folding_model(self) -> None:
        """Load the ESMFold tokenizer and model once; later calls are no-ops."""
        if self._folding_model is not None and self._tokenizer is not None:
            return
        try:
            from transformers import AutoTokenizer, EsmForProteinFolding
        except ImportError as exc:  # pragma: no cover
            raise SelfConsistencyError(
                "transformers is required for ESMFold. Install `transformers`."
            ) from exc
        device = self._resolve_device()
        LOGGER.info("Loading ESMFold model '%s' on %s", self._esmfold_model_id, device)
        tokenizer = AutoTokenizer.from_pretrained(self._esmfold_model_id)
        folding_model = EsmForProteinFolding.from_pretrained(
            self._esmfold_model_id,
            low_cpu_mem_usage=True,
        )
        if device.startswith("cuda"):
            # Only the `esm` submodule is cast to fp16 on GPU (presumably to
            # save memory); the folding trunk stays in full precision.
            folding_model.esm = folding_model.esm.half()
        folding_model = folding_model.to(device)
        folding_model.eval()
        self._tokenizer = tokenizer
        self._folding_model = folding_model
        self._folding_device = device

    def _run_folding(self, sequence: str, save_path: Path) -> None:
        """Fold ``sequence`` with ESMFold and write the prediction as PDB to ``save_path``.

        The model's pLDDT is written into the B-factor column.
        """
        self._ensure_folding_model()
        import torch
        from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
        from transformers.models.esm.openfold_utils.protein import Protein as OFProtein
        from transformers.models.esm.openfold_utils.protein import to_pdb

        assert self._tokenizer is not None
        assert self._folding_model is not None
        assert self._folding_device is not None
        tokenized_input = self._tokenizer(
            [sequence],
            return_tensors="pt",
            add_special_tokens=False,
        )["input_ids"].to(self._folding_device)
        with torch.no_grad():
            output = self._folding_model(tokenized_input)
        # Last structure-module iteration, converted to the atom37 layout
        # expected by the PDB writer.
        final_atom_positions = atom14_to_atom37(output["positions"][-1], output)
        output_np = {k: v.to("cpu").numpy() for k, v in output.items()}
        final_atom_positions_np = final_atom_positions.cpu().numpy()
        final_atom_mask = output_np["atom37_atom_exists"]
        pdb_chunks: List[str] = []
        for i in range(output_np["aatype"].shape[0]):
            pred = OFProtein(
                aatype=output_np["aatype"][i],
                atom_positions=final_atom_positions_np[i],
                atom_mask=final_atom_mask[i],
                # +1 for 1-based PDB numbering (assumes residue_index is 0-based).
                residue_index=output_np["residue_index"][i] + 1,
                b_factors=output_np["plddt"][i],
                chain_index=output_np["chain_index"][i] if "chain_index" in output_np else None,
            )
            pdb_chunks.append(to_pdb(pred))
        save_path.parent.mkdir(parents=True, exist_ok=True)
        save_path.write_text("".join(pdb_chunks), encoding="utf-8")