| """Model and artifact loading utilities for the FlowProt Space MVP."""
|
|
|
| from __future__ import annotations
|
|
|
| import logging
|
| import os
|
| import sys
|
| from dataclasses import dataclass
|
| from pathlib import Path
|
| from typing import Dict, Optional
|
|
|
| import torch
|
| from huggingface_hub import snapshot_download
|
| from omegaconf import DictConfig, OmegaConf
|
|
|
# Module-wide logger, named after this module.
LOGGER = logging.getLogger(__name__)

# Directory containing this file; relative artifact paths are anchored here.
REPO_ROOT = Path(__file__).resolve().parent

# Bundled `model/` sources that get prepended to sys.path.
MODEL_ROOT = REPO_ROOT.joinpath("model")

# Runtime config used when no explicit path or FLOWPROT_APP_CONFIG is given.
DEFAULT_APP_CONFIG = REPO_ROOT.joinpath("config.yaml")
|
|
|
|
|
def ensure_model_pythonpath() -> None:
    """Put the bundled ``model/`` directory at the front of ``sys.path``.

    The call is idempotent: the entry is inserted only when ``sys.path`` does
    not already contain it, so repeated calls never add duplicates.
    """
    entry = str(MODEL_ROOT)
    if entry in sys.path:
        return
    sys.path.insert(0, entry)


# Must run before the `models.*` imports that follow so they can resolve.
ensure_model_pythonpath()
|
|
|
| from models.classifier_wrapper_v2 import ClasfModule
|
| from models.proteinflow import ProteinFlow
|
|
|
|
|
class ArtifactResolutionError(RuntimeError):
    """Signal that a checkpoint or config artifact could not be resolved.

    Raised when no artifact source is configured or when a resolved path does
    not point at an existing regular file.
    """
|
|
|
|
|
class ModelLoadError(RuntimeError):
    """Signal that building a model or loading its weights failed.

    Also used to wrap unexpected exceptions raised while loading, so callers
    only need to handle this type and ``ArtifactResolutionError``.
    """
|
|
|
|
|
@dataclass
class ResolvedArtifacts:
    """Resolved locations of the model checkpoint and its config file."""

    # Checkpoint file holding the model weights.
    ckpt_path: Path
    # Config file that accompanies the checkpoint.
    config_path: Path
    # Tag naming the resolution route, e.g. "local_ckpt_path" or "runtime_config".
    source: str
|
|
|
|
|
@dataclass
class LoadedModelContext:
    """Everything produced by a successful model load, bundled for reuse."""

    # Instantiated network, already moved to `device` and put in eval mode.
    model: ProteinFlow
    # Device the model was moved to.
    device: torch.device
    # Checkpoint config merged with the runtime config.
    merged_cfg: DictConfig
    # Files the model was loaded from.
    artifacts: ResolvedArtifacts
|
|
|
|
|
@dataclass
class ResolvedClassifierArtifacts:
    """Resolved location of the classifier checkpoint."""

    # Checkpoint file holding the classifier weights.
    ckpt_path: Path
    # Tag naming the resolution route, e.g. "env_classifier_ckpt_path".
    source: str
|
|
|
|
|
@dataclass
class LoadedClassifierContext:
    """Everything produced by a successful classifier load."""

    # Instantiated classifier, already moved to `device` and put in eval mode.
    classifier: ClasfModule
    # Device the classifier was moved to.
    device: torch.device
    # File the classifier was loaded from.
    artifacts: ResolvedClassifierArtifacts
|
|
|
|
|
def _as_path(path_value: str) -> Path:
    """Turn a user-supplied path string into an absolute ``Path``.

    ``~`` is expanded first. An absolute path is returned as-is (it is not
    resolved); a relative path is anchored at ``REPO_ROOT`` and resolved.
    """
    candidate = Path(path_value).expanduser()
    if candidate.is_absolute():
        return candidate
    return (REPO_ROOT / candidate).resolve()
|
|
|
|
|
def _require_file(path: Path, label: str) -> None:
    """Check that ``path`` exists and is a regular file.

    Args:
        path: Location to check.
        label: Human-readable name of the artifact, used in the error message.

    Raises:
        ArtifactResolutionError: If ``path`` is missing or is not a file.
    """
    if path.exists() and path.is_file():
        return
    raise ArtifactResolutionError(f"{label} does not exist: {path}")
|
|
|
|
|
def load_runtime_config(config_path: Optional[str] = None) -> DictConfig:
    """Load the app/runtime config from a YAML file.

    The path is taken from, in order: the ``config_path`` argument, the
    ``FLOWPROT_APP_CONFIG`` environment variable, then ``DEFAULT_APP_CONFIG``.
    Explicit paths go through ``_as_path`` (relative paths are anchored at the
    repo root).

    Args:
        config_path: Optional explicit path to the config file.

    Returns:
        The config loaded by ``OmegaConf.load``.

    Raises:
        ArtifactResolutionError: If the chosen path is not an existing file.
    """
    explicit_path = config_path or os.getenv("FLOWPROT_APP_CONFIG")
    cfg_path = _as_path(explicit_path) if explicit_path else DEFAULT_APP_CONFIG
    # is_file() rather than exists(): a directory at this path would otherwise
    # slip through and fail inside OmegaConf.load with a raw OS error.
    if not cfg_path.is_file():
        raise ArtifactResolutionError(
            f"App config file is missing: {cfg_path}. "
            "Set FLOWPROT_APP_CONFIG or add config.yaml at repo root."
        )
    cfg = OmegaConf.load(cfg_path)
    LOGGER.info("Loaded runtime config from %s", cfg_path)
    return cfg
|
|
|
|
|
def resolve_artifacts(runtime_cfg: Optional[DictConfig] = None) -> ResolvedArtifacts:
    """Locate the model checkpoint and the config file that belongs to it.

    Sources are tried in this order and the first one configured wins:

    1. ``FLOWPROT_CKPT_PATH``: explicit checkpoint file.
    2. ``FLOWPROT_CKPT_DIR``: directory plus ``FLOWPROT_CKPT_FILENAME``
       (default ``epoch.ckpt``).
    3. ``FLOWPROT_HF_REPO_ID``: Hugging Face Hub snapshot; requires
       ``FLOWPROT_CKPT_FILENAME``.
    4. Runtime config: ``inference.ckpt_path`` with optional
       ``inference.ckpt_config_path``.

    Env vars come first so a deployment (e.g. an HF Space) can override the
    config file without editing it. In the two local env-var routes the config
    file location may be overridden with ``FLOWPROT_CKPT_CONFIG_PATH``;
    otherwise the config is expected next to the checkpoint under the name
    given by ``FLOWPROT_CKPT_CONFIG_FILENAME`` (default ``config.yaml``).

    Args:
        runtime_cfg: Loaded runtime config, or ``None`` to use env vars only.

    Returns:
        The resolved checkpoint path, config path and a tag naming the source.

    Raises:
        ArtifactResolutionError: If no source is configured, if the Hub route
            lacks ``FLOWPROT_CKPT_FILENAME``, or if a resolved file is missing.
    """
    env_ckpt_path = os.getenv("FLOWPROT_CKPT_PATH")
    env_ckpt_dir = os.getenv("FLOWPROT_CKPT_DIR")
    env_hf_repo = os.getenv("FLOWPROT_HF_REPO_ID")
    cfg_name = os.getenv("FLOWPROT_CKPT_CONFIG_FILENAME", "config.yaml")

    # Config-file values are looked up eagerly, before any branch is chosen.
    file_ckpt = None
    file_ckpt_cfg = None
    if runtime_cfg is not None:
        file_ckpt = OmegaConf.select(runtime_cfg, "inference.ckpt_path")
        file_ckpt_cfg = OmegaConf.select(runtime_cfg, "inference.ckpt_config_path")

    if env_ckpt_path:
        source = "local_ckpt_path"
        ckpt_path = _as_path(env_ckpt_path)
        sibling_cfg = str(ckpt_path.parent / cfg_name)
        config_path = _as_path(os.getenv("FLOWPROT_CKPT_CONFIG_PATH", sibling_cfg))
    elif env_ckpt_dir:
        source = "local_ckpt_dir"
        ckpt_dir = _as_path(env_ckpt_dir)
        ckpt_path = ckpt_dir / os.getenv("FLOWPROT_CKPT_FILENAME", "epoch.ckpt")
        sibling_cfg = str(ckpt_dir / cfg_name)
        config_path = _as_path(os.getenv("FLOWPROT_CKPT_CONFIG_PATH", sibling_cfg))
    elif env_hf_repo:
        source = "hf_hub_snapshot"
        ckpt_name = os.getenv("FLOWPROT_CKPT_FILENAME")
        if not ckpt_name:
            raise ArtifactResolutionError(
                "FLOWPROT_CKPT_FILENAME is required when FLOWPROT_HF_REPO_ID is set."
            )
        # Only the two files that are needed are requested from the snapshot.
        snapshot_dir = Path(
            snapshot_download(
                repo_id=env_hf_repo,
                revision=os.getenv("FLOWPROT_HF_REVISION"),
                token=os.getenv("HF_TOKEN"),
                allow_patterns=[ckpt_name, cfg_name],
            )
        )
        ckpt_path = snapshot_dir / ckpt_name
        config_path = snapshot_dir / cfg_name
    elif file_ckpt:
        source = "runtime_config"
        ckpt_path = _as_path(str(file_ckpt))
        if file_ckpt_cfg:
            config_path = _as_path(str(file_ckpt_cfg))
        else:
            config_path = ckpt_path.parent / cfg_name
    else:
        raise ArtifactResolutionError(
            "No model artifact source configured. Set inference.ckpt_path in config.yaml, "
            "or one of the env vars: FLOWPROT_CKPT_PATH, FLOWPROT_CKPT_DIR, or "
            "FLOWPROT_HF_REPO_ID (with FLOWPROT_CKPT_FILENAME)."
        )

    _require_file(ckpt_path, "Checkpoint file")
    _require_file(config_path, "Checkpoint config")
    LOGGER.info("Resolved artifacts from %s", source)
    LOGGER.info("Checkpoint: %s", ckpt_path)
    LOGGER.info("Checkpoint config: %s", config_path)
    return ResolvedArtifacts(ckpt_path=ckpt_path, config_path=config_path, source=source)
|
|
|
|
|
def resolve_classifier_artifacts(runtime_cfg: Optional[DictConfig] = None) -> ResolvedClassifierArtifacts:
    """Locate the classifier checkpoint.

    Sources, first match wins: the ``FLOWPROT_CLASSIFIER_CKPT_PATH`` env var,
    then ``inference.classifier.ckpt_path`` in the runtime config, then the
    checkpoint bundled under ``model/ckpt/classifier_ckpt``.

    Args:
        runtime_cfg: Loaded runtime config, or ``None`` to skip the config lookup.

    Returns:
        The resolved checkpoint path and a tag naming the source.

    Raises:
        ArtifactResolutionError: If the resolved checkpoint file is missing.
    """
    env_value = os.getenv("FLOWPROT_CLASSIFIER_CKPT_PATH")

    # The runtime config is consulted only when the env var is not set.
    configured = None
    if not env_value and runtime_cfg is not None:
        configured = OmegaConf.select(runtime_cfg, "inference.classifier.ckpt_path")

    if env_value:
        source = "env_classifier_ckpt_path"
        ckpt_path = _as_path(env_value)
    elif configured:
        source = "runtime_config"
        ckpt_path = _as_path(str(configured))
    else:
        source = "default_classifier_ckpt"
        bundled = MODEL_ROOT / "ckpt" / "classifier_ckpt" / "epoch=90-step=728000.ckpt"
        ckpt_path = bundled.resolve()

    _require_file(ckpt_path, "Classifier checkpoint file")
    LOGGER.info("Resolved classifier artifacts from %s", source)
    LOGGER.info("Classifier checkpoint: %s", ckpt_path)
    return ResolvedClassifierArtifacts(ckpt_path=ckpt_path, source=source)
|
|
|
|
|
def _resolve_device(merged_cfg: DictConfig) -> torch.device:
    """Pick the torch device to run on.

    The ``FLOWPROT_DEVICE`` env var wins whenever it is set (even to an empty
    string); otherwise ``app.device`` from the config is used. An empty,
    missing or null value and the literal ``auto`` all mean "CUDA if
    available, else CPU".

    Args:
        merged_cfg: Config exposing an optional ``app`` mapping with an
            optional ``device`` entry.

    Returns:
        The selected ``torch.device``.

    Raises:
        ModelLoadError: If a CUDA device is requested but CUDA is unavailable.
    """
    env_value = os.getenv("FLOWPROT_DEVICE")
    if env_value is not None:
        origin, requested = "FLOWPROT_DEVICE", env_value
    else:
        # `app:` may be absent or explicitly null in the config; both count as empty.
        app_cfg = merged_cfg.get("app") or {}
        origin, requested = "app.device", app_cfg.get("device")

    # A null device means "auto" rather than the literal string "none".
    configured = "auto" if requested is None else str(requested).strip().lower()

    if configured in {"", "auto"}:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if configured.startswith("cuda") and not torch.cuda.is_available():
        # Name the setting that actually supplied the value.
        raise ModelLoadError(
            f"{origin}={configured} requested, but CUDA is not available."
        )
    return torch.device(configured)
|
|
|
|
|
def _normalize_state_dict(raw_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Rewrite checkpoint keys so they match a bare ProteinFlow module.

    Three layouts are tried in order:

    1. Keys prefixed with ``model.``: only those keys are kept, prefix removed.
    2. Keys prefixed with ``module.model.``: same treatment.
    3. Otherwise every key is kept, with a leading ``module.`` removed if present.

    Args:
        raw_state_dict: State dict as stored in the checkpoint.

    Returns:
        A new dict with normalized keys; the input is not modified.
    """
    # Wrapper prefixes, most specific interpretation first; the first prefix
    # that matches at least one key decides the result.
    for prefix in ("model.", "module.model."):
        cut = len(prefix)
        stripped = {
            name[cut:]: tensor
            for name, tensor in raw_state_dict.items()
            if name.startswith(prefix)
        }
        if stripped:
            return stripped

    # No wrapped layout found: keep everything, dropping only a "module." prefix.
    wrapper = "module."
    return {
        (name[len(wrapper):] if name.startswith(wrapper) else name): tensor
        for name, tensor in raw_state_dict.items()
    }
|
|
|
|
|
def _merge_runtime_and_checkpoint_cfg(
    runtime_cfg: DictConfig, checkpoint_cfg_path: Path
) -> DictConfig:
    """Overlay the runtime config on top of the checkpoint's own config.

    Runtime values take precedence over checkpoint values. After merging, an
    ``inference`` section is guaranteed to exist, and a top-level
    ``interpolant`` section is mirrored to ``inference.interpolant`` unless
    that entry is already present.

    Note: struct mode is switched off on ``runtime_cfg`` itself, so the
    caller's object is modified in that respect.

    Args:
        runtime_cfg: Loaded runtime config.
        checkpoint_cfg_path: Config file stored alongside the checkpoint.

    Returns:
        The merged config.
    """
    checkpoint_cfg = OmegaConf.load(checkpoint_cfg_path)

    # Struct mode is disabled on both inputs before merging.
    for cfg in (runtime_cfg, checkpoint_cfg):
        OmegaConf.set_struct(cfg, False)

    # Later arguments win, so runtime settings override the checkpoint's.
    merged = OmegaConf.merge(checkpoint_cfg, runtime_cfg)

    if "inference" not in merged:
        merged.inference = OmegaConf.create({})

    mirror_interpolant = "interpolant" in merged and "interpolant" not in merged.inference
    if mirror_interpolant:
        merged.inference.interpolant = merged.interpolant
    return merged
|
|
|
|
|
class FlowProtModelManager:
    """Lazy model manager with cached loaded context.

    Nothing is loaded at construction time. The first successful ``load()``
    builds the model and caches the resulting ``LoadedModelContext``; later
    calls return the cached context unless ``force_reload`` is set. A failed
    load leaves any previously cached context in place and records the error
    message in ``last_error``.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Remember where to find the runtime config; do no I/O.

        Args:
            config_path: Optional runtime config path passed on to
                ``load_runtime_config`` when loading.
        """
        self._config_path = config_path
        self._loaded: Optional[LoadedModelContext] = None
        self._last_error: Optional[str] = None

    @property
    def is_loaded(self) -> bool:
        """Whether a loaded context is currently cached."""
        return self._loaded is not None

    @property
    def last_error(self) -> Optional[str]:
        """Message of the most recent load failure, or ``None``."""
        return self._last_error

    def peek_loaded(self) -> Optional[LoadedModelContext]:
        """Return the cached context without triggering a load."""
        return self._loaded

    def load(self, force_reload: bool = False) -> LoadedModelContext:
        """Return the loaded model context, loading it if necessary.

        Args:
            force_reload: Reload from disk even when a context is cached.

        Returns:
            The cached or freshly loaded context.

        Raises:
            ArtifactResolutionError: If config or checkpoint files cannot be resolved.
            ModelLoadError: If the checkpoint is malformed or any other error
                occurs while loading (the original exception is chained).
        """
        if self._loaded is not None and not force_reload:
            return self._loaded

        try:
            runtime_cfg = load_runtime_config(self._config_path)
            artifacts = resolve_artifacts(runtime_cfg)
            merged_cfg = _merge_runtime_and_checkpoint_cfg(runtime_cfg, artifacts.config_path)
            device = _resolve_device(merged_cfg)

            # SECURITY: weights_only=False lets torch.load unpickle arbitrary
            # objects, which can execute code. Only point this at trusted
            # checkpoints (this includes files fetched from the HF Hub).
            checkpoint_payload = torch.load(
                artifacts.ckpt_path,
                map_location="cpu",
                weights_only=False,
            )
            # Validate the container first: a non-dict payload has no `.get`
            # and would otherwise fail with an opaque AttributeError.
            if not isinstance(checkpoint_payload, dict):
                raise ModelLoadError(
                    "Checkpoint payload does not include a valid state_dict dictionary."
                )
            # Use the "state_dict" entry when present; otherwise treat the
            # payload itself as the state dict.
            state_dict = checkpoint_payload.get("state_dict", checkpoint_payload)
            if not isinstance(state_dict, dict):
                raise ModelLoadError(
                    "Checkpoint payload does not include a valid state_dict dictionary."
                )

            model = ProteinFlow(merged_cfg.model)
            normalized_state_dict = _normalize_state_dict(state_dict)
            # strict=False: key mismatches are logged (first 20) rather than fatal.
            missing, unexpected = model.load_state_dict(normalized_state_dict, strict=False)
            if missing:
                LOGGER.warning("Missing checkpoint keys while loading model: %s", missing[:20])
            if unexpected:
                LOGGER.warning(
                    "Unexpected checkpoint keys while loading model: %s", unexpected[:20]
                )

            model.to(device)
            model.eval()

            self._loaded = LoadedModelContext(
                model=model,
                device=device,
                merged_cfg=merged_cfg,
                artifacts=artifacts,
            )
            self._last_error = None
            LOGGER.info("Model loaded successfully on %s", device)
            return self._loaded
        except Exception as exc:
            # Boundary handler: record and log every failure, then surface it
            # as one of the two module-level error types.
            self._last_error = str(exc)
            LOGGER.exception("Failed to load FlowProt model artifacts.")
            if isinstance(exc, (ArtifactResolutionError, ModelLoadError)):
                raise
            raise ModelLoadError(str(exc)) from exc
|
|
|
|
|
class FlowProtClassifierManager:
    """Lazy classifier manager with cached loaded context.

    Nothing is loaded at construction time. ``load()`` caches the resulting
    ``LoadedClassifierContext`` and returns it on later calls as long as the
    requested device matches the cached one. A failed load leaves any
    previously cached context in place and records the error message in
    ``last_error``.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Remember where to find the runtime config; do no I/O.

        Args:
            config_path: Optional runtime config path passed on to
                ``load_runtime_config`` when loading.
        """
        self._config_path = config_path
        self._loaded: Optional[LoadedClassifierContext] = None
        self._last_error: Optional[str] = None

    @property
    def is_loaded(self) -> bool:
        """Whether a loaded context is currently cached."""
        return self._loaded is not None

    @property
    def last_error(self) -> Optional[str]:
        """Message of the most recent load failure, or ``None``."""
        return self._last_error

    def peek_loaded(self) -> Optional[LoadedClassifierContext]:
        """Return the cached context without triggering a load."""
        return self._loaded

    def load(
        self,
        device: torch.device,
        force_reload: bool = False,
    ) -> LoadedClassifierContext:
        """Return the loaded classifier context for ``device``.

        The cached context is reused only when it was loaded for the same
        device; a different device (or ``force_reload``) triggers a full
        reload from disk.

        Args:
            device: Device the classifier should live on.
            force_reload: Reload from disk even when a matching context is cached.

        Returns:
            The cached or freshly loaded context.

        Raises:
            ArtifactResolutionError: If the checkpoint (or its sibling
                ``config.yaml`` fallback) cannot be resolved.
            ModelLoadError: If the checkpoint is malformed or any other error
                occurs while loading (the original exception is chained).
        """
        if self._loaded is not None and not force_reload:
            if self._loaded.device == device:
                return self._loaded

        try:
            runtime_cfg = load_runtime_config(self._config_path)
            artifacts = resolve_classifier_artifacts(runtime_cfg)

            # SECURITY: weights_only=False lets torch.load unpickle arbitrary
            # objects, which can execute code. Only point this at trusted
            # checkpoints.
            checkpoint_payload = torch.load(
                str(artifacts.ckpt_path),
                map_location="cpu",
                weights_only=False,
            )
            # Validate the container first: a non-dict payload has no `.get`
            # and would otherwise fail with an opaque AttributeError.
            if not isinstance(checkpoint_payload, dict):
                raise ModelLoadError(
                    "Classifier checkpoint payload does not include a valid state_dict."
                )

            # Prefer the config stored in the checkpoint; `or {}` also covers
            # a "hyper_parameters" entry that is present but None.
            hyper_parameters = checkpoint_payload.get("hyper_parameters") or {}
            classifier_cfg = hyper_parameters.get("cfg")
            if classifier_cfg is None:
                # Fall back to a config.yaml stored next to the checkpoint.
                sibling_config = artifacts.ckpt_path.parent / "config.yaml"
                _require_file(sibling_config, "Classifier checkpoint config")
                classifier_cfg = OmegaConf.load(sibling_config)

            # Use the "state_dict" entry when present; otherwise treat the
            # payload itself as the state dict.
            state_dict = checkpoint_payload.get("state_dict", checkpoint_payload)
            if not isinstance(state_dict, dict):
                raise ModelLoadError(
                    "Classifier checkpoint payload does not include a valid state_dict."
                )

            classifier = ClasfModule(classifier_cfg)
            # strict=False: key mismatches are logged (first 20) rather than fatal.
            missing, unexpected = classifier.load_state_dict(state_dict, strict=False)
            if missing:
                LOGGER.warning("Missing classifier checkpoint keys: %s", missing[:20])
            if unexpected:
                LOGGER.warning("Unexpected classifier checkpoint keys: %s", unexpected[:20])

            # NOTE(review): gradients are enabled on every parameter even
            # though the module is put in eval mode; presumably callers
            # differentiate through the classifier. Confirm against callers.
            for param in classifier.parameters():
                param.requires_grad_(True)
            classifier.to(device)
            classifier.eval()

            self._loaded = LoadedClassifierContext(
                classifier=classifier,
                device=device,
                artifacts=artifacts,
            )
            self._last_error = None
            LOGGER.info("Classifier loaded successfully on %s", device)
            return self._loaded
        except Exception as exc:
            # Boundary handler: record and log every failure, then surface it
            # as one of the two module-level error types.
            self._last_error = str(exc)
            LOGGER.exception("Failed to load FlowProt classifier artifacts.")
            if isinstance(exc, (ArtifactResolutionError, ModelLoadError)):
                raise
            raise ModelLoadError(str(exc)) from exc
|
|
|