FlowProt / model_loader.py
alibtsd's picture
Deploy FlowProt Docker Space
f34af6f verified
Raw
History Blame Contribute Delete
15.1 kB
"""Model and artifact loading utilities for the FlowProt Space MVP."""
from __future__ import annotations
import logging
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional
import torch
from huggingface_hub import snapshot_download
from omegaconf import DictConfig, OmegaConf
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
# Directory that contains this file; relative artifact/config paths are
# anchored here (see `_as_path`).
REPO_ROOT = Path(__file__).resolve().parent
# Directory that `ensure_model_pythonpath` puts on `sys.path`.
MODEL_ROOT = REPO_ROOT / "model"
# Runtime config used by `load_runtime_config` when neither an explicit path
# nor the FLOWPROT_APP_CONFIG env var is provided.
DEFAULT_APP_CONFIG = REPO_ROOT / "config.yaml"
def ensure_model_pythonpath() -> None:
    """Put MODEL_ROOT at the front of `sys.path` unless it is already listed.

    Calling this more than once is harmless: the entry is inserted only when
    it is not yet present.
    """
    entry = str(MODEL_ROOT)
    if entry in sys.path:
        return
    sys.path.insert(0, entry)
# The sys.path tweak has to run before the two imports below, which is why
# they sit after executable code and carry `noqa: E402`. Presumably the
# `models` package lives under MODEL_ROOT - confirm against the repo layout.
ensure_model_pythonpath()
from models.classifier_wrapper_v2 import ClasfModule # noqa: E402
from models.proteinflow import ProteinFlow # noqa: E402
class ArtifactResolutionError(RuntimeError):
    """Raised when checkpoint/config artifacts cannot be resolved.

    In this module that covers: a missing app config file, a checkpoint or
    checkpoint-config path that is not an existing file, a Hub source without
    FLOWPROT_CKPT_FILENAME, and no artifact source being configured at all.
    """
class ModelLoadError(RuntimeError):
    """Raised when model instantiation or weights loading fails.

    The manager classes in this module also re-raise any other unexpected
    exception from their `load()` as this type (chained to the original).
    """
@dataclass
class ResolvedArtifacts:
    """Where the main model checkpoint and its config were found."""

    # Checkpoint file the weights are read from.
    ckpt_path: Path
    # Config file that accompanies the checkpoint.
    config_path: Path
    # Tag naming the route used to find the files; `resolve_artifacts` sets it to
    # "local_ckpt_path", "local_ckpt_dir", "hf_hub_snapshot" or "runtime_config".
    source: str
@dataclass
class LoadedModelContext:
    """Bundle describing a loaded ProteinFlow model.

    As built by `FlowProtModelManager.load`, the model has already been moved
    to `device` and switched to eval mode.
    """

    # The instantiated model with checkpoint weights applied.
    model: ProteinFlow
    # Device the model was moved to.
    device: torch.device
    # Checkpoint config merged with the runtime config.
    merged_cfg: DictConfig
    # Paths and origin of the files the model was loaded from.
    artifacts: ResolvedArtifacts
@dataclass
class ResolvedClassifierArtifacts:
    """Where the classifier checkpoint was found."""

    # Classifier checkpoint file.
    ckpt_path: Path
    # Tag naming the route used; `resolve_classifier_artifacts` sets it to
    # "env_classifier_ckpt_path", "runtime_config" or "default_classifier_ckpt".
    source: str
@dataclass
class LoadedClassifierContext:
    """Bundle describing a loaded classifier.

    As built by `FlowProtClassifierManager.load`, the classifier has already
    been moved to `device` and switched to eval mode.
    """

    # The instantiated classifier with checkpoint weights applied.
    classifier: ClasfModule
    # Device the classifier was moved to; the manager compares it against the
    # requested device to decide whether the cached instance can be reused.
    device: torch.device
    # Path and origin of the checkpoint the classifier was loaded from.
    artifacts: ResolvedClassifierArtifacts
def _as_path(path_value: str) -> Path:
    """Convert a path string to a Path, anchoring relative paths at REPO_ROOT.

    A leading `~` is expanded first. An absolute result is returned untouched
    (it is not resolved); a relative one is joined onto REPO_ROOT and resolved.
    """
    candidate = Path(path_value).expanduser()
    if candidate.is_absolute():
        return candidate
    return (REPO_ROOT / candidate).resolve()
def _require_file(path: Path, label: str) -> None:
    """Check that `path` points at an existing file.

    Args:
        path: Location to check.
        label: Human-readable name of the artifact, used in the error message.

    Raises:
        ArtifactResolutionError: if `path` is missing or is not a file.
    """
    if path.exists() and path.is_file():
        return
    raise ArtifactResolutionError(f"{label} does not exist: {path}")
def load_runtime_config(config_path: Optional[str] = None) -> DictConfig:
    """Read the app/runtime config file.

    The location is taken from `config_path` if given, otherwise from the
    FLOWPROT_APP_CONFIG env var, otherwise DEFAULT_APP_CONFIG is used.

    Raises:
        ArtifactResolutionError: if the selected config path does not exist.
    """
    requested = config_path or os.getenv("FLOWPROT_APP_CONFIG")
    if requested:
        location = _as_path(requested)
    else:
        location = DEFAULT_APP_CONFIG

    if location.exists():
        runtime_cfg = OmegaConf.load(location)
        LOGGER.info("Loaded runtime config from %s", location)
        return runtime_cfg

    raise ArtifactResolutionError(
        f"App config file is missing: {location}. "
        "Set FLOWPROT_APP_CONFIG or add config.yaml at repo root."
    )
def resolve_artifacts(runtime_cfg: Optional[DictConfig] = None) -> ResolvedArtifacts:
    """Locate the model checkpoint and the config file that goes with it.

    Sources are tried in this order; the first one that is configured wins:
      1. FLOWPROT_CKPT_PATH  - explicit checkpoint file.
      2. FLOWPROT_CKPT_DIR   - directory, combined with FLOWPROT_CKPT_FILENAME
         (default "epoch.ckpt").
      3. FLOWPROT_HF_REPO_ID - Hub snapshot download; FLOWPROT_CKPT_FILENAME is
         mandatory here, FLOWPROT_HF_REVISION and HF_TOKEN are optional.
      4. `inference.ckpt_path` (plus optional `inference.ckpt_config_path`)
         from the runtime config.

    The config file name defaults to "config.yaml" and can be changed with
    FLOWPROT_CKPT_CONFIG_FILENAME. For sources 1 and 2 the full config location
    can instead be given via FLOWPROT_CKPT_CONFIG_PATH.

    Raises:
        ArtifactResolutionError: if no source is configured, if the Hub source
            lacks a checkpoint filename, or if either resolved path is not an
            existing file.
    """
    explicit_ckpt = os.getenv("FLOWPROT_CKPT_PATH")
    ckpt_dir_value = os.getenv("FLOWPROT_CKPT_DIR")
    repo_id = os.getenv("FLOWPROT_HF_REPO_ID")
    cfg_name = os.getenv("FLOWPROT_CKPT_CONFIG_FILENAME", "config.yaml")

    from_cfg_ckpt = None
    from_cfg_config = None
    if runtime_cfg is not None:
        from_cfg_ckpt = OmegaConf.select(runtime_cfg, "inference.ckpt_path")
        from_cfg_config = OmegaConf.select(runtime_cfg, "inference.ckpt_config_path")

    if explicit_ckpt:
        source = "local_ckpt_path"
        ckpt_path = _as_path(explicit_ckpt)
        sibling_config = ckpt_path.parent / cfg_name
        config_path = _as_path(os.getenv("FLOWPROT_CKPT_CONFIG_PATH", str(sibling_config)))
    elif ckpt_dir_value:
        source = "local_ckpt_dir"
        base_dir = _as_path(ckpt_dir_value)
        ckpt_path = base_dir / os.getenv("FLOWPROT_CKPT_FILENAME", "epoch.ckpt")
        config_path = _as_path(os.getenv("FLOWPROT_CKPT_CONFIG_PATH", str(base_dir / cfg_name)))
    elif repo_id:
        source = "hf_hub_snapshot"
        remote_ckpt_name = os.getenv("FLOWPROT_CKPT_FILENAME")
        if not remote_ckpt_name:
            raise ArtifactResolutionError(
                "FLOWPROT_CKPT_FILENAME is required when FLOWPROT_HF_REPO_ID is set."
            )
        # Only the two needed files are fetched from the repo.
        snapshot_root = Path(
            snapshot_download(
                repo_id=repo_id,
                revision=os.getenv("FLOWPROT_HF_REVISION"),
                token=os.getenv("HF_TOKEN"),
                allow_patterns=[remote_ckpt_name, cfg_name],
            )
        )
        ckpt_path = snapshot_root / remote_ckpt_name
        config_path = snapshot_root / cfg_name
    elif from_cfg_ckpt:
        source = "runtime_config"
        ckpt_path = _as_path(str(from_cfg_ckpt))
        if from_cfg_config:
            config_path = _as_path(str(from_cfg_config))
        else:
            config_path = ckpt_path.parent / cfg_name
    else:
        raise ArtifactResolutionError(
            "No model artifact source configured. Set inference.ckpt_path in config.yaml, "
            "or one of the env vars: FLOWPROT_CKPT_PATH, FLOWPROT_CKPT_DIR, or "
            "FLOWPROT_HF_REPO_ID (with FLOWPROT_CKPT_FILENAME)."
        )

    for candidate, label in (
        (ckpt_path, "Checkpoint file"),
        (config_path, "Checkpoint config"),
    ):
        _require_file(candidate, label)

    LOGGER.info("Resolved artifacts from %s", source)
    LOGGER.info("Checkpoint: %s", ckpt_path)
    LOGGER.info("Checkpoint config: %s", config_path)
    return ResolvedArtifacts(ckpt_path=ckpt_path, config_path=config_path, source=source)
def resolve_classifier_artifacts(runtime_cfg: Optional[DictConfig] = None) -> ResolvedClassifierArtifacts:
    """Locate the classifier checkpoint.

    Lookup order: the FLOWPROT_CLASSIFIER_CKPT_PATH env var, then
    `inference.classifier.ckpt_path` from the runtime config, then a default
    checkpoint under MODEL_ROOT.

    Raises:
        ArtifactResolutionError: if the resolved path is not an existing file.
    """
    env_value = os.getenv("FLOWPROT_CLASSIFIER_CKPT_PATH")

    # The runtime config is only consulted when the env var is not set.
    configured = None
    if not env_value and runtime_cfg is not None:
        configured = OmegaConf.select(runtime_cfg, "inference.classifier.ckpt_path")

    if env_value:
        ckpt_path, source = _as_path(env_value), "env_classifier_ckpt_path"
    elif configured:
        ckpt_path, source = _as_path(str(configured)), "runtime_config"
    else:
        default_ckpt = MODEL_ROOT / "ckpt" / "classifier_ckpt" / "epoch=90-step=728000.ckpt"
        ckpt_path, source = default_ckpt.resolve(), "default_classifier_ckpt"

    _require_file(ckpt_path, "Classifier checkpoint file")
    LOGGER.info("Resolved classifier artifacts from %s", source)
    LOGGER.info("Classifier checkpoint: %s", ckpt_path)
    return ResolvedClassifierArtifacts(ckpt_path=ckpt_path, source=source)
def _resolve_device(merged_cfg: DictConfig) -> torch.device:
    """Choose the torch device to run on.

    The FLOWPROT_DEVICE env var takes precedence over `app.device` from the
    merged config. A missing, null, empty or "auto" value selects CUDA when it
    is available and CPU otherwise.

    Raises:
        ModelLoadError: if a CUDA device is requested explicitly but CUDA is
            not available.
    """
    # `app:` can be present but null in YAML; `.get("app", {})` would then hand
    # back None, so fall back to an empty mapping explicitly.
    app_cfg = merged_cfg.get("app") or {}
    requested = os.getenv("FLOWPROT_DEVICE")
    if requested is None:
        requested = app_cfg.get("device")
    # A null `app.device` means "not configured"; stringifying it would yield
    # the bogus device name "none".
    configured = "auto" if requested is None else str(requested).strip().lower()
    if configured in {"", "auto"}:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if configured.startswith("cuda") and not torch.cuda.is_available():
        raise ModelLoadError(
            f"FLOWPROT_DEVICE={configured} requested, but CUDA is not available."
        )
    return torch.device(configured)
def _normalize_state_dict(raw_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Strip wrapper prefixes so the keys line up with the bare ProteinFlow module.

    Prefixes are tried in order: "model." first, then "module.model.". As soon
    as one of them matches at least one key, only the matching entries are
    returned (prefix removed) and everything else is dropped. If neither
    matches, every entry is kept and a leading "module." is removed where
    present.
    """
    for prefix in ("model.", "module.model."):
        cut = len(prefix)
        stripped = {
            name[cut:]: tensor
            for name, tensor in raw_state_dict.items()
            if name.startswith(prefix)
        }
        if stripped:
            return stripped

    # Neither wrapper prefix is present: keys are DDP-wrapped or already bare.
    ddp_cut = len("module.")
    return {
        (name[ddp_cut:] if name.startswith("module.") else name): tensor
        for name, tensor in raw_state_dict.items()
    }
def _merge_runtime_and_checkpoint_cfg(
    runtime_cfg: DictConfig, checkpoint_cfg_path: Path
) -> DictConfig:
    """Merge the checkpoint's config with the runtime config.

    The checkpoint config is loaded from `checkpoint_cfg_path` and passed to
    `OmegaConf.merge` first, with the runtime config second. Struct mode is
    switched off on both inputs beforehand, which changes the flag on the
    caller's `runtime_cfg` object. The result always has an `inference` node;
    a top-level `interpolant` section is copied to `inference.interpolant`
    when that key is not already set.
    """
    checkpoint_cfg = OmegaConf.load(checkpoint_cfg_path)
    for node in (runtime_cfg, checkpoint_cfg):
        OmegaConf.set_struct(node, False)

    merged = OmegaConf.merge(checkpoint_cfg, runtime_cfg)
    if "inference" not in merged:
        merged.inference = OmegaConf.create({})

    needs_interpolant = "interpolant" in merged and "interpolant" not in merged.inference
    if needs_interpolant:
        merged.inference.interpolant = merged.interpolant
    return merged
class FlowProtModelManager:
    """Lazy model manager with cached loaded context.

    Nothing is read from disk until `load()` is called. A successful load is
    cached and returned by later calls unless `force_reload` is set. The text
    of the most recent failure is kept in `last_error`.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Remember the optional runtime-config path; no loading happens here."""
        self._config_path = config_path
        self._loaded: Optional[LoadedModelContext] = None
        self._last_error: Optional[str] = None

    @property
    def is_loaded(self) -> bool:
        """True once a load has succeeded and its context is cached."""
        return self._loaded is not None

    @property
    def last_error(self) -> Optional[str]:
        """Message of the most recent failed load, or None after a success."""
        return self._last_error

    def peek_loaded(self) -> Optional[LoadedModelContext]:
        """Return the cached context without triggering a load."""
        return self._loaded

    def load(self, force_reload: bool = False) -> LoadedModelContext:
        """Load the model (or return the cached one).

        Args:
            force_reload: Re-read config and checkpoint even if a model is cached.

        Returns:
            The context holding the model (eval mode, on the resolved device),
            the merged config and the resolved artifact paths.

        Raises:
            ArtifactResolutionError: if config or checkpoint files cannot be found.
            ModelLoadError: for any other failure, including a checkpoint whose
                keys match none of the model's keys.
        """
        if self._loaded is not None and not force_reload:
            return self._loaded
        try:
            runtime_cfg = load_runtime_config(self._config_path)
            artifacts = resolve_artifacts(runtime_cfg)
            merged_cfg = _merge_runtime_and_checkpoint_cfg(runtime_cfg, artifacts.config_path)
            device = _resolve_device(merged_cfg)
            # PyTorch >=2.6 defaults to weights_only=True, which breaks older
            # Lightning checkpoints that store OmegaConf objects in the payload.
            # SECURITY: weights_only=False unpickles arbitrary objects, so the
            # checkpoint must come from a trusted source.
            checkpoint_payload = torch.load(
                artifacts.ckpt_path,
                map_location="cpu",
                weights_only=False,
            )
            # Check the container type first so a non-dict payload produces a
            # clear message instead of an AttributeError on `.get`.
            if not isinstance(checkpoint_payload, dict):
                raise ModelLoadError(
                    "Checkpoint payload does not include a valid state_dict dictionary."
                )
            state_dict = checkpoint_payload.get("state_dict", checkpoint_payload)
            if not isinstance(state_dict, dict):
                raise ModelLoadError(
                    "Checkpoint payload does not include a valid state_dict dictionary."
                )
            model = ProteinFlow(merged_cfg.model)
            normalized_state_dict = _normalize_state_dict(state_dict)
            # strict=False tolerates partial mismatches, but with zero overlap
            # it would leave every weight at its random initial value and only
            # log a warning. Treat that as a hard failure.
            expected_keys = set(model.state_dict())
            if expected_keys and expected_keys.isdisjoint(normalized_state_dict):
                raise ModelLoadError(
                    "Checkpoint state_dict shares no keys with the ProteinFlow model; "
                    "refusing to continue with randomly initialised weights."
                )
            missing, unexpected = model.load_state_dict(normalized_state_dict, strict=False)
            if missing:
                LOGGER.warning("Missing checkpoint keys while loading model: %s", missing[:20])
            if unexpected:
                LOGGER.warning(
                    "Unexpected checkpoint keys while loading model: %s", unexpected[:20]
                )
            model.to(device)
            model.eval()
            self._loaded = LoadedModelContext(
                model=model,
                device=device,
                merged_cfg=merged_cfg,
                artifacts=artifacts,
            )
            self._last_error = None
            LOGGER.info("Model loaded successfully on %s", device)
            return self._loaded
        except Exception as exc:
            # Record the failure for status reporting, then surface it as one
            # of the two module-level error types.
            self._last_error = str(exc)
            LOGGER.exception("Failed to load FlowProt model artifacts.")
            if isinstance(exc, (ArtifactResolutionError, ModelLoadError)):
                raise
            raise ModelLoadError(str(exc)) from exc
class FlowProtClassifierManager:
    """Lazy classifier manager with cached loaded context.

    The classifier is only built when `load()` is called. The cached instance
    is reused as long as the requested device equals the one it was loaded on
    and no reload is forced.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Store the optional runtime-config path; loading is deferred."""
        self._config_path = config_path
        self._loaded: Optional[LoadedClassifierContext] = None
        self._last_error: Optional[str] = None

    @property
    def is_loaded(self) -> bool:
        """Whether a classifier context is currently cached."""
        return self._loaded is not None

    @property
    def last_error(self) -> Optional[str]:
        """Text of the latest load failure; None once a load has succeeded."""
        return self._last_error

    def peek_loaded(self) -> Optional[LoadedClassifierContext]:
        """Hand back the cached context (possibly None) without loading."""
        return self._loaded

    def load(
        self,
        device: torch.device,
        force_reload: bool = False,
    ) -> LoadedClassifierContext:
        """Build the classifier on `device`, or reuse the cached one.

        Args:
            device: Device the classifier should live on.
            force_reload: Rebuild even when a matching cached instance exists.

        Raises:
            ArtifactResolutionError: if the checkpoint (or its fallback config)
                cannot be found.
            ModelLoadError: for any other failure during loading.
        """
        cached = self._loaded
        if cached is not None and not force_reload and cached.device == device:
            return cached
        try:
            app_cfg = load_runtime_config(self._config_path)
            artifacts = resolve_classifier_artifacts(app_cfg)
            # Bypass Lightning's load_from_checkpoint: PyTorch >=2.6 defaults to
            # weights_only=True and Lightning explicitly forwards that flag, which
            # rejects the OmegaConf objects pickled in this checkpoint. We load the
            # payload directly (trusted source) and rebuild the module ourselves.
            payload = torch.load(
                str(artifacts.ckpt_path),
                map_location="cpu",
                weights_only=False,
            )

            # Prefer the config stored in the checkpoint; otherwise fall back
            # to a config.yaml sitting next to it.
            classifier_cfg = payload.get("hyper_parameters", {}).get("cfg")
            if classifier_cfg is None:
                fallback_cfg_path = artifacts.ckpt_path.parent / "config.yaml"
                _require_file(fallback_cfg_path, "Classifier checkpoint config")
                classifier_cfg = OmegaConf.load(fallback_cfg_path)

            weights = payload.get("state_dict", payload)
            if not isinstance(weights, dict):
                raise ModelLoadError(
                    "Classifier checkpoint payload does not include a valid state_dict."
                )

            classifier = ClasfModule(classifier_cfg)
            load_report = classifier.load_state_dict(weights, strict=False)
            missing, unexpected = load_report
            if missing:
                LOGGER.warning("Missing classifier checkpoint keys: %s", missing[:20])
            if unexpected:
                LOGGER.warning("Unexpected classifier checkpoint keys: %s", unexpected[:20])

            # Gradients are explicitly enabled on every parameter.
            for weight in classifier.parameters():
                weight.requires_grad_(True)
            classifier.to(device)
            classifier.eval()

            context = LoadedClassifierContext(
                classifier=classifier,
                device=device,
                artifacts=artifacts,
            )
            self._loaded = context
            self._last_error = None
            LOGGER.info("Classifier loaded successfully on %s", device)
            return context
        except Exception as exc:
            self._last_error = str(exc)
            LOGGER.exception("Failed to load FlowProt classifier artifacts.")
            if not isinstance(exc, (ArtifactResolutionError, ModelLoadError)):
                raise ModelLoadError(str(exc)) from exc
            raise