"""Model and artifact loading utilities for the FlowProt Space MVP.""" from __future__ import annotations import logging import os import sys from dataclasses import dataclass from pathlib import Path from typing import Dict, Optional import torch from huggingface_hub import snapshot_download from omegaconf import DictConfig, OmegaConf LOGGER = logging.getLogger(__name__) REPO_ROOT = Path(__file__).resolve().parent MODEL_ROOT = REPO_ROOT / "model" DEFAULT_APP_CONFIG = REPO_ROOT / "config.yaml" def ensure_model_pythonpath() -> None: """Ensure `model/` package imports resolve from the Space root.""" model_root_str = str(MODEL_ROOT) if model_root_str not in sys.path: sys.path.insert(0, model_root_str) ensure_model_pythonpath() from models.classifier_wrapper_v2 import ClasfModule # noqa: E402 from models.proteinflow import ProteinFlow # noqa: E402 class ArtifactResolutionError(RuntimeError): """Raised when checkpoint/config artifacts cannot be resolved.""" class ModelLoadError(RuntimeError): """Raised when model instantiation or weights loading fails.""" @dataclass class ResolvedArtifacts: ckpt_path: Path config_path: Path source: str @dataclass class LoadedModelContext: model: ProteinFlow device: torch.device merged_cfg: DictConfig artifacts: ResolvedArtifacts @dataclass class ResolvedClassifierArtifacts: ckpt_path: Path source: str @dataclass class LoadedClassifierContext: classifier: ClasfModule device: torch.device artifacts: ResolvedClassifierArtifacts def _as_path(path_value: str) -> Path: raw = Path(path_value).expanduser() return raw if raw.is_absolute() else (REPO_ROOT / raw).resolve() def _require_file(path: Path, label: str) -> None: if not path.exists() or not path.is_file(): raise ArtifactResolutionError(f"{label} does not exist: {path}") def load_runtime_config(config_path: Optional[str] = None) -> DictConfig: """Load app/runtime config from file.""" explicit_path = config_path or os.getenv("FLOWPROT_APP_CONFIG") cfg_path = _as_path(explicit_path) if explicit_path else DEFAULT_APP_CONFIG if not cfg_path.exists(): raise ArtifactResolutionError( f"App config file is missing: {cfg_path}. " "Set FLOWPROT_APP_CONFIG or add config.yaml at repo root." ) cfg = OmegaConf.load(cfg_path) LOGGER.info("Loaded runtime config from %s", cfg_path) return cfg def resolve_artifacts(runtime_cfg: Optional[DictConfig] = None) -> ResolvedArtifacts: """Resolve checkpoint + checkpoint config. Resolution precedence (first match wins): 1. Env vars (FLOWPROT_CKPT_PATH / FLOWPROT_CKPT_DIR / FLOWPROT_HF_REPO_ID) so deployments (e.g. HF Space) can override without editing files. 2. Runtime config file (inference.ckpt_path + optional inference.ckpt_config_path). """ ckpt_path_env = os.getenv("FLOWPROT_CKPT_PATH") ckpt_dir_env = os.getenv("FLOWPROT_CKPT_DIR") hf_repo_id = os.getenv("FLOWPROT_HF_REPO_ID") config_filename = os.getenv("FLOWPROT_CKPT_CONFIG_FILENAME", "config.yaml") cfg_ckpt_path = ( OmegaConf.select(runtime_cfg, "inference.ckpt_path") if runtime_cfg is not None else None ) cfg_ckpt_config_path = ( OmegaConf.select(runtime_cfg, "inference.ckpt_config_path") if runtime_cfg is not None else None ) if ckpt_path_env: ckpt_path = _as_path(ckpt_path_env) config_path = _as_path( os.getenv("FLOWPROT_CKPT_CONFIG_PATH", str(ckpt_path.parent / config_filename)) ) source = "local_ckpt_path" elif ckpt_dir_env: ckpt_dir = _as_path(ckpt_dir_env) ckpt_filename = os.getenv("FLOWPROT_CKPT_FILENAME", "epoch.ckpt") ckpt_path = ckpt_dir / ckpt_filename config_path = _as_path( os.getenv("FLOWPROT_CKPT_CONFIG_PATH", str(ckpt_dir / config_filename)) ) source = "local_ckpt_dir" elif hf_repo_id: ckpt_filename = os.getenv("FLOWPROT_CKPT_FILENAME") if not ckpt_filename: raise ArtifactResolutionError( "FLOWPROT_CKPT_FILENAME is required when FLOWPROT_HF_REPO_ID is set." ) revision = os.getenv("FLOWPROT_HF_REVISION") token = os.getenv("HF_TOKEN") local_dir = snapshot_download( repo_id=hf_repo_id, revision=revision, token=token, allow_patterns=[ckpt_filename, config_filename], ) ckpt_path = Path(local_dir) / ckpt_filename config_path = Path(local_dir) / config_filename source = "hf_hub_snapshot" elif cfg_ckpt_path: ckpt_path = _as_path(str(cfg_ckpt_path)) config_path = ( _as_path(str(cfg_ckpt_config_path)) if cfg_ckpt_config_path else (ckpt_path.parent / config_filename) ) source = "runtime_config" else: raise ArtifactResolutionError( "No model artifact source configured. Set inference.ckpt_path in config.yaml, " "or one of the env vars: FLOWPROT_CKPT_PATH, FLOWPROT_CKPT_DIR, or " "FLOWPROT_HF_REPO_ID (with FLOWPROT_CKPT_FILENAME)." ) _require_file(ckpt_path, "Checkpoint file") _require_file(config_path, "Checkpoint config") LOGGER.info("Resolved artifacts from %s", source) LOGGER.info("Checkpoint: %s", ckpt_path) LOGGER.info("Checkpoint config: %s", config_path) return ResolvedArtifacts(ckpt_path=ckpt_path, config_path=config_path, source=source) def resolve_classifier_artifacts(runtime_cfg: Optional[DictConfig] = None) -> ResolvedClassifierArtifacts: """Resolve classifier checkpoint via env var or runtime config.""" ckpt_path_env = os.getenv("FLOWPROT_CLASSIFIER_CKPT_PATH") if ckpt_path_env: ckpt_path = _as_path(ckpt_path_env) source = "env_classifier_ckpt_path" else: cfg_path = None if runtime_cfg is not None: cfg_path = OmegaConf.select(runtime_cfg, "inference.classifier.ckpt_path") if cfg_path: ckpt_path = _as_path(str(cfg_path)) source = "runtime_config" else: ckpt_path = ( MODEL_ROOT / "ckpt" / "classifier_ckpt" / "epoch=90-step=728000.ckpt" ).resolve() source = "default_classifier_ckpt" _require_file(ckpt_path, "Classifier checkpoint file") LOGGER.info("Resolved classifier artifacts from %s", source) LOGGER.info("Classifier checkpoint: %s", ckpt_path) return ResolvedClassifierArtifacts(ckpt_path=ckpt_path, source=source) def _resolve_device(merged_cfg: DictConfig) -> torch.device: app_cfg = merged_cfg.get("app", {}) configured = os.getenv("FLOWPROT_DEVICE", str(app_cfg.get("device", "auto"))).strip().lower() if configured in {"", "auto"}: return torch.device("cuda" if torch.cuda.is_available() else "cpu") if configured.startswith("cuda") and not torch.cuda.is_available(): raise ModelLoadError( f"FLOWPROT_DEVICE={configured} requested, but CUDA is not available." ) return torch.device(configured) def _normalize_state_dict(raw_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Normalize Lightning-style checkpoints to raw ProteinFlow module keys.""" candidates: Dict[str, torch.Tensor] = {} for key, value in raw_state_dict.items(): if key.startswith("model."): candidates[key[len("model."):]] = value if candidates: return candidates candidates = {} for key, value in raw_state_dict.items(): if key.startswith("module.model."): candidates[key[len("module.model."):]] = value if candidates: return candidates # Fall back to de-DDP or already-normalized keys. normalized: Dict[str, torch.Tensor] = {} for key, value in raw_state_dict.items(): normalized[key[len("module."):] if key.startswith("module.") else key] = value return normalized def _merge_runtime_and_checkpoint_cfg( runtime_cfg: DictConfig, checkpoint_cfg_path: Path ) -> DictConfig: ckpt_cfg = OmegaConf.load(checkpoint_cfg_path) OmegaConf.set_struct(runtime_cfg, False) OmegaConf.set_struct(ckpt_cfg, False) merged = OmegaConf.merge(ckpt_cfg, runtime_cfg) if "inference" not in merged: merged.inference = OmegaConf.create({}) if "interpolant" in merged and "interpolant" not in merged.inference: merged.inference.interpolant = merged.interpolant return merged class FlowProtModelManager: """Lazy model manager with cached loaded context.""" def __init__(self, config_path: Optional[str] = None): self._config_path = config_path self._loaded: Optional[LoadedModelContext] = None self._last_error: Optional[str] = None @property def is_loaded(self) -> bool: return self._loaded is not None @property def last_error(self) -> Optional[str]: return self._last_error def peek_loaded(self) -> Optional[LoadedModelContext]: return self._loaded def load(self, force_reload: bool = False) -> LoadedModelContext: if self._loaded is not None and not force_reload: return self._loaded try: runtime_cfg = load_runtime_config(self._config_path) artifacts = resolve_artifacts(runtime_cfg) merged_cfg = _merge_runtime_and_checkpoint_cfg(runtime_cfg, artifacts.config_path) device = _resolve_device(merged_cfg) # PyTorch >=2.6 defaults to weights_only=True, which breaks older # Lightning checkpoints that store OmegaConf objects in the payload. checkpoint_payload = torch.load( artifacts.ckpt_path, map_location="cpu", weights_only=False, ) state_dict = checkpoint_payload.get("state_dict", checkpoint_payload) if not isinstance(state_dict, dict): raise ModelLoadError( "Checkpoint payload does not include a valid state_dict dictionary." ) model = ProteinFlow(merged_cfg.model) normalized_state_dict = _normalize_state_dict(state_dict) missing, unexpected = model.load_state_dict(normalized_state_dict, strict=False) if missing: LOGGER.warning("Missing checkpoint keys while loading model: %s", missing[:20]) if unexpected: LOGGER.warning( "Unexpected checkpoint keys while loading model: %s", unexpected[:20] ) model.to(device) model.eval() self._loaded = LoadedModelContext( model=model, device=device, merged_cfg=merged_cfg, artifacts=artifacts, ) self._last_error = None LOGGER.info("Model loaded successfully on %s", device) return self._loaded except Exception as exc: self._last_error = str(exc) LOGGER.exception("Failed to load FlowProt model artifacts.") if isinstance(exc, (ArtifactResolutionError, ModelLoadError)): raise raise ModelLoadError(str(exc)) from exc class FlowProtClassifierManager: """Lazy classifier manager with cached loaded context.""" def __init__(self, config_path: Optional[str] = None): self._config_path = config_path self._loaded: Optional[LoadedClassifierContext] = None self._last_error: Optional[str] = None @property def is_loaded(self) -> bool: return self._loaded is not None @property def last_error(self) -> Optional[str]: return self._last_error def peek_loaded(self) -> Optional[LoadedClassifierContext]: return self._loaded def load( self, device: torch.device, force_reload: bool = False, ) -> LoadedClassifierContext: if self._loaded is not None and not force_reload: if self._loaded.device == device: return self._loaded try: runtime_cfg = load_runtime_config(self._config_path) artifacts = resolve_classifier_artifacts(runtime_cfg) # Bypass Lightning's load_from_checkpoint: PyTorch >=2.6 defaults to # weights_only=True and Lightning explicitly forwards that flag, which # rejects the OmegaConf objects pickled in this checkpoint. We load the # payload directly (trusted source) and rebuild the module ourselves. checkpoint_payload = torch.load( str(artifacts.ckpt_path), map_location="cpu", weights_only=False, ) classifier_cfg = checkpoint_payload.get("hyper_parameters", {}).get("cfg") if classifier_cfg is None: sibling_config = artifacts.ckpt_path.parent / "config.yaml" _require_file(sibling_config, "Classifier checkpoint config") classifier_cfg = OmegaConf.load(sibling_config) state_dict = checkpoint_payload.get("state_dict", checkpoint_payload) if not isinstance(state_dict, dict): raise ModelLoadError( "Classifier checkpoint payload does not include a valid state_dict." ) classifier = ClasfModule(classifier_cfg) missing, unexpected = classifier.load_state_dict(state_dict, strict=False) if missing: LOGGER.warning("Missing classifier checkpoint keys: %s", missing[:20]) if unexpected: LOGGER.warning("Unexpected classifier checkpoint keys: %s", unexpected[:20]) for param in classifier.parameters(): param.requires_grad_(True) classifier.to(device) classifier.eval() self._loaded = LoadedClassifierContext( classifier=classifier, device=device, artifacts=artifacts, ) self._last_error = None LOGGER.info("Classifier loaded successfully on %s", device) return self._loaded except Exception as exc: self._last_error = str(exc) LOGGER.exception("Failed to load FlowProt classifier artifacts.") if isinstance(exc, (ArtifactResolutionError, ModelLoadError)): raise raise ModelLoadError(str(exc)) from exc