| """Inference orchestration for the FlowProt Hugging Face Space MVP."""
|
|
|
| from __future__ import annotations
|
|
|
| import logging
|
| from dataclasses import dataclass
|
| from datetime import datetime
|
| from pathlib import Path
|
| from typing import Dict, List, Optional
|
|
|
| import numpy as np
|
| import torch
|
| from omegaconf import OmegaConf
|
|
|
| from model_loader import (
|
| REPO_ROOT,
|
| ArtifactResolutionError,
|
| FlowProtClassifierManager,
|
| FlowProtModelManager,
|
| ModelLoadError,
|
| ensure_model_pythonpath,
|
| load_runtime_config,
|
| )
|
|
|
| ensure_model_pythonpath()
|
|
|
| from utils.experiments import save_traj
|
| from utils.flows import Interpolant
|
| from utils.modelUtils import to_numpy
|
| from utils.pdbUtils import parse_pdb_feats
|
|
|
| LOGGER = logging.getLogger(__name__)
|
|
|
|
|
class InferenceError(RuntimeError):
    """Signal that a runtime inference request could not be completed.

    The service in this module raises it for rejected request parameters and
    uses it to wrap unexpected exceptions raised while sampling.
    """
|
|
|
|
|
@dataclass
class InferenceResult:
    """Summary of one completed generation request.

    The three file lists are filled in parallel by the service (one entry per
    generated sample). The optional fields default to ``None`` and are only
    populated by the modes that use them.
    """

    # Mode label set by the service: "unconditional", "classifier" or "conditional".
    mode: str
    # Directory created for this run.
    run_dir: str
    # Paths reported by the trajectory writer for each sample.
    sample_files: List[str]
    trajectory_files: List[str]
    x0_trajectory_files: List[str]
    # Seed that was applied to the NumPy / torch generators for this run.
    seed: int
    # Copied from the loaded model's artifact descriptor.
    artifacts_source: str
    # Classifier-guidance settings; None when no classifier was used.
    guidance_scale: Optional[float] = None
    target_class: Optional[int] = None
    # Number of residues held fixed (set by conditional runs).
    fixed_residue_count: Optional[int] = None
    # Number of sampling timesteps read from the interpolant config.
    num_timesteps: Optional[int] = None
|
|
|
|
|
def _cfg_get(cfg, key: str, default):
    """Look up ``key`` in ``cfg`` via ``OmegaConf.select``, with a fallback.

    Args:
        cfg: Config object handed to ``OmegaConf.select``.
        key: Selection key, e.g. ``"app.max_length"``.
        default: Value returned when the selection yields ``None``.

    Returns:
        The selected value, or ``default`` when the selection is ``None``.
    """
    selected = OmegaConf.select(cfg, key)
    if selected is None:
        return default
    return selected
|
|
|
|
|
class FlowProtInferenceService:
    """Service layer for Gradio UI and smoke checks.

    Owns a backbone model manager and a classifier manager and exposes three
    sampling entry points (unconditional, classifier-guided, conditional) plus
    :meth:`generate`, which dispatches on a mode string. The message of the
    most recent failure raised inside a sampling ``try`` block is kept and
    reported by :meth:`health_check`.
    """

    # Fallbacks used when neither the request nor the merged config provides
    # classifier-guidance settings.
    _DEFAULT_GUIDANCE_SCALE = 0.2
    _DEFAULT_TARGET_CLASS = 1

    def __init__(self, config_path: Optional[str] = None):
        """Read the runtime config and create the model/classifier managers.

        Args:
            config_path: Optional config path forwarded to the config loader
                and to both managers.
        """
        self._config_path = config_path
        self._runtime_cfg = load_runtime_config(config_path=config_path)
        self._model_manager = FlowProtModelManager(config_path=config_path)
        self._classifier_manager = FlowProtClassifierManager(config_path=config_path)
        self._mvp_mode = str(_cfg_get(self._runtime_cfg, "app.mvp_mode", "unconditional")).lower()
        self._conditional_enabled = bool(
            _cfg_get(self._runtime_cfg, "app.enable_conditional", False)
        )
        self._classifier_enabled = bool(
            _cfg_get(self._runtime_cfg, "app.enable_classifier", True)
        )
        self._last_inference_error: Optional[str] = None

        if bool(_cfg_get(self._runtime_cfg, "app.load_on_startup", False)):
            # Deliberately best-effort: a failed preload is logged and the
            # service still starts; loading is retried on the first request.
            try:
                self.preload_model()
            except Exception:
                LOGGER.exception("Startup model preload failed.")

    @property
    def mvp_mode(self) -> str:
        """Default mode used by :meth:`generate` when no mode is passed."""
        return self._mvp_mode

    @property
    def conditional_enabled(self) -> bool:
        """Whether ``app.enable_conditional`` is set in the runtime config."""
        return self._conditional_enabled

    @property
    def classifier_enabled(self) -> bool:
        """Whether ``app.enable_classifier`` is set in the runtime config."""
        return self._classifier_enabled

    def health_check(self) -> Dict[str, object]:
        """Return a status snapshot without triggering any model load."""
        loaded_ctx = self._model_manager.peek_loaded()
        loaded_clf = self._classifier_manager.peek_loaded()
        return {
            "status": "ok",
            "mvp_mode": self._mvp_mode,
            "conditional_enabled": self._conditional_enabled,
            "classifier_enabled": self._classifier_enabled,
            "model_loaded": self._model_manager.is_loaded,
            "classifier_loaded": self._classifier_manager.is_loaded,
            "device": str(loaded_ctx.device) if loaded_ctx else None,
            "artifacts_source": loaded_ctx.artifacts.source if loaded_ctx else None,
            "checkpoint_path": str(loaded_ctx.artifacts.ckpt_path) if loaded_ctx else None,
            "classifier_checkpoint_path": (
                str(loaded_clf.artifacts.ckpt_path) if loaded_clf else None
            ),
            "classifier_artifacts_source": (
                loaded_clf.artifacts.source if loaded_clf else None
            ),
            "loader_error": self._model_manager.last_error,
            "classifier_loader_error": self._classifier_manager.last_error,
            "inference_error": self._last_inference_error,
        }

    def preload_model(self) -> None:
        """Load the model and, when enabled, the classifier on the same device."""
        loaded = self._model_manager.load()
        if self._classifier_enabled:
            self._classifier_manager.load(device=loaded.device)

    def generate(
        self,
        length: int,
        num_samples: int,
        mode: Optional[str] = None,
        seed: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        target_class: Optional[int] = None,
        num_timesteps: Optional[int] = None,
        reference_pdb_path: Optional[str] = None,
        fixed_residues: Optional[List[int]] = None,
        use_classifier_guidance: bool = False,
    ) -> "InferenceResult":
        """Dispatch a request to the sampler selected by ``mode``.

        Args:
            length: Residue count (ignored by the conditional mode, which takes
                its length from the reference PDB).
            num_samples: Number of samples to generate.
            mode: ``"unconditional"``, ``"classifier"`` or ``"conditional"``
                (case-insensitive); defaults to the configured MVP mode.
            seed: Optional RNG seed.
            guidance_scale: Optional classifier-guidance scale.
            target_class: Optional classifier target class.
            num_timesteps: Optional override for the sampling step count.
            reference_pdb_path: Reference structure for the conditional mode.
            fixed_residues: Residue numbers to hold fixed (conditional mode).
            use_classifier_guidance: Enable guidance in the conditional mode.

        Raises:
            InferenceError: If the mode is unknown or disabled, or if the
                selected sampler fails.
        """
        selected_mode = (mode or self._mvp_mode).lower()
        if selected_mode == "conditional":
            if not self._conditional_enabled:
                raise InferenceError(
                    "Conditional mode is disabled for MVP. "
                    "Set app.enable_conditional=true in config.yaml to enable the UI toggle."
                )
            return self.generate_conditional(
                num_samples=num_samples,
                reference_pdb_path=reference_pdb_path,
                fixed_residues=fixed_residues,
                seed=seed,
                num_timesteps=num_timesteps,
                use_classifier_guidance=use_classifier_guidance,
                guidance_scale=guidance_scale,
                target_class=target_class,
            )
        if selected_mode == "classifier":
            if not self._classifier_enabled:
                raise InferenceError(
                    "Classifier-guided mode is disabled. "
                    "Set app.enable_classifier=true in config.yaml to enable the UI toggle."
                )
            return self.generate_classifier_guided(
                length=length,
                num_samples=num_samples,
                seed=seed,
                guidance_scale=guidance_scale,
                target_class=target_class,
                num_timesteps=num_timesteps,
            )
        if selected_mode != "unconditional":
            raise InferenceError(f"Unsupported mode: {selected_mode}")
        return self.generate_unconditional(
            length=length,
            num_samples=num_samples,
            seed=seed,
            num_timesteps=num_timesteps,
        )

    @staticmethod
    def _resolve_interpolant_cfg(merged_cfg, num_timesteps: Optional[int]):
        """Return ``inference.interpolant``, optionally overriding the step count.

        Raises:
            InferenceError: If the section is missing or ``num_timesteps`` is
                outside ``[1, 1000]``.
        """
        interpolant_cfg = OmegaConf.select(merged_cfg, "inference.interpolant")
        if interpolant_cfg is None:
            raise InferenceError(
                "Missing interpolant config in merged runtime config (inference.interpolant)."
            )
        if num_timesteps is not None:
            steps = int(num_timesteps)
            if steps < 1 or steps > 1000:
                raise InferenceError("num_timesteps must be in [1, 1000].")
            interpolant_cfg = OmegaConf.merge(
                interpolant_cfg, OmegaConf.create({"sampling": {"num_timesteps": steps}})
            )
        return interpolant_cfg

    # ------------------------------------------------------------------
    # Private helpers shared by the three sampling entry points.
    # ------------------------------------------------------------------

    def _length_bounds(self) -> tuple:
        """Return ``(min_length, max_length)`` from the runtime config."""
        min_length = int(_cfg_get(self._runtime_cfg, "app.min_length", 32))
        max_length = int(_cfg_get(self._runtime_cfg, "app.max_length", 1024))
        return min_length, max_length

    def _validate_length(self, length: int) -> None:
        """Raise ``InferenceError`` when ``length`` is outside the configured bounds."""
        min_length, max_length = self._length_bounds()
        if length < min_length or length > max_length:
            raise InferenceError(
                f"Length must be in [{min_length}, {max_length}] for this Space deployment."
            )

    def _validate_num_samples(self, num_samples: int) -> None:
        """Raise ``InferenceError`` when ``num_samples`` exceeds the per-request cap."""
        max_samples = int(_cfg_get(self._runtime_cfg, "app.max_samples_per_request", 4))
        if num_samples < 1 or num_samples > max_samples:
            raise InferenceError(
                f"num_samples must be in [1, {max_samples}] for this Space deployment."
            )

    @staticmethod
    def _seed_everything(loaded, seed: Optional[int]) -> int:
        """Seed NumPy and torch (plus CUDA when applicable); return the seed used."""
        effective_seed = (
            int(seed)
            if seed is not None
            else int(_cfg_get(loaded.merged_cfg, "inference.seed", 123))
        )
        np.random.seed(effective_seed)
        torch.manual_seed(effective_seed)
        if loaded.device.type == "cuda":
            torch.cuda.manual_seed_all(effective_seed)
        return effective_seed

    @classmethod
    def _resolve_classifier_settings(cls, merged_cfg, guidance_scale, target_class) -> tuple:
        """Resolve and validate ``(guidance_scale, target_class)``.

        Request values win over the merged config. Only ``None`` counts as
        "not provided", so ``0`` / ``0.0`` are kept as-is.

        Raises:
            InferenceError: If the class is not 0/1 or the scale is negative.
        """
        if guidance_scale is None:
            guidance_scale = _cfg_get(
                merged_cfg, "inference.classifier.guidance_scale", cls._DEFAULT_GUIDANCE_SCALE
            )
        if target_class is None:
            target_class = _cfg_get(
                merged_cfg, "inference.classifier.target_class", cls._DEFAULT_TARGET_CLASS
            )
        scale = float(guidance_scale)
        label = int(target_class)
        if label not in (0, 1):
            raise InferenceError("target_class must be 0 or 1 for the binary classifier.")
        if scale < 0:
            raise InferenceError("guidance_scale must be non-negative.")
        return scale, label

    @staticmethod
    def _create_run_dir(merged_cfg, label: str) -> Path:
        """Create and return a fresh ``space_<label>_<timestamp>`` run directory.

        The timestamp only has second resolution, so the directory is created
        exclusively and a numeric suffix is appended on collision; otherwise
        two runs started within the same second would share one directory.
        """
        from datetime import timezone  # local: only ``datetime`` is imported at module level

        output_root = Path(str(_cfg_get(merged_cfg, "inference.output_dir", "space_outputs")))
        if not output_root.is_absolute():
            output_root = (REPO_ROOT / output_root).resolve()
        # Same format as before; ``datetime.utcnow()`` is deprecated.
        run_id = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        base_name = f"space_{label}_{run_id}"
        candidate = output_root / base_name
        attempt = 0
        while True:
            try:
                candidate.mkdir(parents=True, exist_ok=False)
                return candidate
            except FileExistsError:
                attempt += 1
                candidate = output_root / f"{base_name}_{attempt}"

    @staticmethod
    def _build_fixed_mask(residue_indices, fixed_residues, num_res: int, device):
        """Return a boolean mask of length ``num_res`` marking fixed residues.

        With no ``fixed_residues`` every residue is marked fixed. Otherwise each
        requested residue number is matched against ``residue_indices`` and the
        first match is marked.

        Raises:
            InferenceError: If any requested residue number is absent.
        """
        if not fixed_residues:
            return torch.ones(num_res, dtype=torch.bool, device=device)

        fixed_mask = torch.zeros(num_res, dtype=torch.bool, device=device)
        missing: List[int] = []
        for resnum in fixed_residues:
            matches = np.where(residue_indices == int(resnum))[0]
            if len(matches) == 0:
                missing.append(int(resnum))
                continue
            fixed_mask[int(matches[0])] = True
        if missing:
            raise InferenceError(
                f"Fixed residue number(s) not found in reference PDB: {missing}."
            )
        if not bool(fixed_mask.any()):
            raise InferenceError("No valid fixed residues resolved from the request.")
        return fixed_mask

    @staticmethod
    def _sample_and_save(run_dir: Path, num_res: int, num_samples: int, diffuse_mask, sample_fn):
        """Draw ``num_samples`` samples and hand each trajectory to ``save_traj``.

        Args:
            run_dir: Run directory; samples go to ``length_<n>/sample_<i>`` below it.
            num_res: Residue count used in the sub-directory name.
            num_samples: Number of ``sample_fn`` invocations.
            diffuse_mask: Mask forwarded unchanged to ``save_traj``.
            sample_fn: Zero-argument callable returning a 3-tuple whose first
                two items are lists of tensors (trajectory, model prediction).

        Returns:
            ``(sample_files, trajectory_files, x0_trajectory_files)``.
        """
        sample_files: List[str] = []
        trajectory_files: List[str] = []
        x0_trajectory_files: List[str] = []

        for sample_id in range(num_samples):
            sample_dir = run_dir / f"length_{num_res}" / f"sample_{sample_id}"
            sample_dir.mkdir(parents=True, exist_ok=True)

            atom37_traj, model_traj, _ = sample_fn()
            bb_traj = to_numpy(torch.concat(atom37_traj, dim=0))
            # The model-prediction trajectory is stored in reversed order.
            model_x0_traj = np.flip(to_numpy(torch.concat(model_traj, dim=0)), axis=0)
            saved = save_traj(
                sample=bb_traj[-1],
                bb_prot_traj=bb_traj,
                x0_traj=model_x0_traj,
                diffuse_mask=diffuse_mask,
                output_dir=str(sample_dir),
            )
            sample_files.append(saved["sample_path"])
            trajectory_files.append(saved["traj_path"])
            x0_trajectory_files.append(saved["x0_traj_path"])

        return sample_files, trajectory_files, x0_trajectory_files

    # ------------------------------------------------------------------
    # Sampling entry points.
    # ------------------------------------------------------------------

    def generate_unconditional(
        self,
        length: int,
        num_samples: int,
        seed: Optional[int] = None,
        num_timesteps: Optional[int] = None,
    ) -> "InferenceResult":
        """Generate ``num_samples`` unconditional samples of ``length`` residues.

        Raises:
            InferenceError: On invalid arguments or any unexpected sampling
                failure (the original exception is chained).
            ArtifactResolutionError, ModelLoadError: Re-raised from model loading.
        """
        self._validate_length(length)
        self._validate_num_samples(num_samples)

        try:
            loaded = self._model_manager.load()
            self._last_inference_error = None

            effective_seed = self._seed_everything(loaded, seed)
            interpolant_cfg = self._resolve_interpolant_cfg(loaded.merged_cfg, num_timesteps)
            effective_timesteps = int(
                _cfg_get(interpolant_cfg, "sampling.num_timesteps", 100)
            )
            run_dir = self._create_run_dir(loaded.merged_cfg, "unconditional")

            interpolant = Interpolant(interpolant_cfg)
            interpolant.set_device(loaded.device)

            def draw():
                return interpolant.sample(
                    num_batch=1,
                    num_res=length,
                    model=loaded.model,
                )

            sample_files, trajectory_files, x0_trajectory_files = self._sample_and_save(
                run_dir=run_dir,
                num_res=length,
                num_samples=num_samples,
                diffuse_mask=np.ones(length, dtype=np.float32),
                sample_fn=draw,
            )

            return InferenceResult(
                mode="unconditional",
                run_dir=str(run_dir),
                sample_files=sample_files,
                trajectory_files=trajectory_files,
                x0_trajectory_files=x0_trajectory_files,
                seed=effective_seed,
                artifacts_source=loaded.artifacts.source,
                num_timesteps=effective_timesteps,
            )
        except (ArtifactResolutionError, ModelLoadError, InferenceError) as exc:
            self._last_inference_error = str(exc)
            raise
        except Exception as exc:
            self._last_inference_error = str(exc)
            LOGGER.exception("Unexpected failure during unconditional inference.")
            raise InferenceError(str(exc)) from exc

    def generate_classifier_guided(
        self,
        length: int,
        num_samples: int,
        seed: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        target_class: Optional[int] = None,
        num_timesteps: Optional[int] = None,
    ) -> "InferenceResult":
        """Generate samples with classifier guidance toward ``target_class``.

        ``guidance_scale`` and ``target_class`` fall back to
        ``inference.classifier.*`` in the merged config when ``None``.

        Raises:
            InferenceError: On invalid arguments or any unexpected sampling
                failure (the original exception is chained).
            ArtifactResolutionError, ModelLoadError: Re-raised from loading.
        """
        self._validate_length(length)
        self._validate_num_samples(num_samples)

        try:
            loaded = self._model_manager.load()
            classifier_ctx = self._classifier_manager.load(device=loaded.device)
            self._last_inference_error = None

            effective_guidance_scale, effective_target_class = (
                self._resolve_classifier_settings(
                    loaded.merged_cfg, guidance_scale, target_class
                )
            )
            effective_seed = self._seed_everything(loaded, seed)

            interpolant_cfg = self._resolve_interpolant_cfg(loaded.merged_cfg, num_timesteps)
            effective_timesteps = int(
                _cfg_get(interpolant_cfg, "sampling.num_timesteps", 100)
            )
            run_dir = self._create_run_dir(loaded.merged_cfg, "classifier")

            interpolant = Interpolant(interpolant_cfg)
            interpolant.set_device(loaded.device)

            def draw():
                return interpolant.sample_clf(
                    num_batch=1,
                    num_res=length,
                    model=loaded.model,
                    clf_model=classifier_ctx.classifier,
                    guidance_scale=effective_guidance_scale,
                    target_class=effective_target_class,
                )

            sample_files, trajectory_files, x0_trajectory_files = self._sample_and_save(
                run_dir=run_dir,
                num_res=length,
                num_samples=num_samples,
                diffuse_mask=np.ones(length, dtype=np.float32),
                sample_fn=draw,
            )

            return InferenceResult(
                mode="classifier",
                run_dir=str(run_dir),
                sample_files=sample_files,
                trajectory_files=trajectory_files,
                x0_trajectory_files=x0_trajectory_files,
                seed=effective_seed,
                artifacts_source=loaded.artifacts.source,
                guidance_scale=effective_guidance_scale,
                target_class=effective_target_class,
                num_timesteps=effective_timesteps,
            )
        except (ArtifactResolutionError, ModelLoadError, InferenceError) as exc:
            self._last_inference_error = str(exc)
            raise
        except Exception as exc:
            self._last_inference_error = str(exc)
            LOGGER.exception("Unexpected failure during classifier-guided inference.")
            raise InferenceError(str(exc)) from exc

    def generate_conditional(
        self,
        num_samples: int,
        reference_pdb_path: Optional[str],
        fixed_residues: Optional[List[int]] = None,
        seed: Optional[int] = None,
        num_timesteps: Optional[int] = None,
        use_classifier_guidance: bool = False,
        guidance_scale: Optional[float] = None,
        target_class: Optional[int] = None,
        chain_id: str = "A",
        temperature: float = 1.0,
    ) -> "InferenceResult":
        """Generate samples with selected residues pinned to a reference PDB.

        Args:
            num_samples: Number of samples to generate.
            reference_pdb_path: Path of the reference structure (required).
            fixed_residues: Residue numbers to hold fixed; when empty or
                ``None`` every residue of the reference is marked fixed.
            seed: Optional RNG seed.
            num_timesteps: Optional override for the sampling step count.
            use_classifier_guidance: Also apply classifier guidance.
            guidance_scale: Guidance scale (used only with guidance enabled).
            target_class: Target class (used only with guidance enabled).
            chain_id: Chain read from the reference PDB.
            temperature: Forwarded to the conditional sampler.

        Raises:
            InferenceError: On invalid arguments, an unusable reference, or any
                unexpected sampling failure (the original exception is chained).
            ArtifactResolutionError, ModelLoadError: Re-raised from loading.
        """
        self._validate_num_samples(num_samples)
        if not reference_pdb_path:
            raise InferenceError(
                "Conditional mode requires a reference PDB upload to define fixed positions."
            )
        reference_path = Path(reference_pdb_path)
        if not reference_path.exists():
            raise InferenceError(f"Reference PDB not found: {reference_path}")

        try:
            loaded = self._model_manager.load()
            classifier_ctx = None
            if use_classifier_guidance:
                if not self._classifier_enabled:
                    raise InferenceError(
                        "Classifier guidance requested but classifier is disabled."
                    )
                classifier_ctx = self._classifier_manager.load(device=loaded.device)
            self._last_inference_error = None

            effective_seed = self._seed_everything(loaded, seed)

            pdb_feats = parse_pdb_feats(
                "reference", str(reference_path), chain_id=chain_id, exclude_hetatm=True
            )
            if "bb_positions" not in pdb_feats:
                raise InferenceError("Reference PDB has no backbone positions to fix.")
            num_res = int(pdb_feats["aatype"].shape[0])
            min_length, max_length = self._length_bounds()
            if num_res < min_length or num_res > max_length:
                raise InferenceError(
                    f"Reference length {num_res} is outside [{min_length}, {max_length}]."
                )

            residue_indices = np.asarray(pdb_feats["residue_index"]).reshape(-1)
            fixed_positions = torch.tensor(
                np.asarray(pdb_feats["bb_positions"]), dtype=torch.float32, device=loaded.device
            )
            if fixed_positions.ndim != 2 or fixed_positions.shape[1] != 3:
                raise InferenceError(
                    f"Expected fixed_positions of shape [N, 3], got {tuple(fixed_positions.shape)}."
                )

            fixed_mask = self._build_fixed_mask(
                residue_indices, fixed_residues, num_res, loaded.device
            )
            fixed_residue_count = int(fixed_mask.sum().item())

            interpolant_cfg = self._resolve_interpolant_cfg(loaded.merged_cfg, num_timesteps)
            effective_timesteps = int(
                _cfg_get(interpolant_cfg, "sampling.num_timesteps", 100)
            )

            effective_guidance_scale = None
            effective_target_class = None
            clf_model = None
            if classifier_ctx is not None:
                clf_model = classifier_ctx.classifier
                effective_guidance_scale, effective_target_class = (
                    self._resolve_classifier_settings(
                        loaded.merged_cfg, guidance_scale, target_class
                    )
                )

            # Values handed to the sampler. Test for None explicitly: the old
            # ``value or default`` form turned a requested target_class=0 or
            # guidance_scale=0.0 into the default.
            sampler_guidance_scale = (
                self._DEFAULT_GUIDANCE_SCALE
                if effective_guidance_scale is None
                else effective_guidance_scale
            )
            sampler_target_class = (
                self._DEFAULT_TARGET_CLASS
                if effective_target_class is None
                else effective_target_class
            )

            run_dir = self._create_run_dir(loaded.merged_cfg, "conditional")

            interpolant = Interpolant(interpolant_cfg)
            interpolant.set_device(loaded.device)

            # Mask passed to save_traj: 1.0 where the residue is NOT fixed.
            flow_mask_np = to_numpy((~fixed_mask).float())
            fixed_positions_b = fixed_positions.unsqueeze(0)

            def draw():
                return interpolant.sample_conditional(
                    num_batch=1,
                    num_res=num_res,
                    model=loaded.model,
                    fixed_positions=fixed_positions_b,
                    fixed_mask=fixed_mask,
                    clf_model=clf_model,
                    guidance_scale=sampler_guidance_scale,
                    target_class=sampler_target_class,
                    temperature=temperature,
                )

            sample_files, trajectory_files, x0_trajectory_files = self._sample_and_save(
                run_dir=run_dir,
                num_res=num_res,
                num_samples=num_samples,
                diffuse_mask=flow_mask_np,
                sample_fn=draw,
            )

            return InferenceResult(
                mode="conditional",
                run_dir=str(run_dir),
                sample_files=sample_files,
                trajectory_files=trajectory_files,
                x0_trajectory_files=x0_trajectory_files,
                seed=effective_seed,
                artifacts_source=loaded.artifacts.source,
                guidance_scale=effective_guidance_scale,
                target_class=effective_target_class,
                fixed_residue_count=fixed_residue_count,
                num_timesteps=effective_timesteps,
            )
        except (ArtifactResolutionError, ModelLoadError, InferenceError) as exc:
            self._last_inference_error = str(exc)
            raise
        except Exception as exc:
            self._last_inference_error = str(exc)
            LOGGER.exception("Unexpected failure during conditional inference.")
            raise InferenceError(str(exc)) from exc
|
|
|