# cesarali's picture
# manual runtime bundle push from load_and_push.ipynb
# 5686f5b verified
from typing import Type, Union
import yaml
from sim_priors_pk.config_classes.flow_pk_config import FlowPKExperimentConfig
from sim_priors_pk.config_classes.node_pk_config import NodePKExperimentConfig
from sim_priors_pk.config_classes.diffusion_pk_config import DiffusionPKExperimentConfig
from sim_priors_pk.config_classes.utils import TupleSafeLoader
from sim_priors_pk.models.amortized_inference.aicme import AICMEPK
from sim_priors_pk.models.amortized_inference.context_vae_pk import ContextVAEPK
from sim_priors_pk.models.amortized_inference.diffusion_pk import (
ContinuousDiffusionPK,
DiscreteDiffusionPK,
)
from sim_priors_pk.models.amortized_inference.flows_pk import FlowPK
from sim_priors_pk.models.amortized_inference.prediction_pk import PredictionPK
def get_model_class(
    model_config: Union[
        NodePKExperimentConfig, FlowPKExperimentConfig, DiffusionPKExperimentConfig, None
    ] = None,
    name_str: str = None,
) -> Type[
    Union[
        PredictionPK,
        AICMEPK,
        FlowPK,
        ContinuousDiffusionPK,
        DiscreteDiffusionPK,
        ContextVAEPK,
    ]
]:
    """
    Return the model class to be used by the basic experiments.

    The model is selected by name. When ``model_config`` is given, its
    ``name_str`` attribute is used and overrides the ``name_str`` argument.
    Otherwise the ``name_str`` argument is used directly.

    Args:
        model_config: Experiment config whose ``name_str`` selects the model.
        name_str: Model name, used only when ``model_config`` is ``None``.

    Returns:
        The model class registered under the resolved name.

    Raises:
        ValueError: If the resolved name does not match a known model.
    """
    if model_config is not None:
        name_str = model_config.name_str

    model_classes = {
        "PredictionPK": PredictionPK,
        "AICMEPK": AICMEPK,
        "FlowPK": FlowPK,
        "ContinuousDiffusionPK": ContinuousDiffusionPK,
        "DiscreteDiffusionPK": DiscreteDiffusionPK,
        "ContextVAEPK": ContextVAEPK,
    }
    if name_str not in model_classes:
        # Report the resolved name. ``model_config`` may be None here, so it
        # must not be dereferenced.
        raise ValueError(f"Unknown model name: {name_str}")
    return model_classes[name_str]
def _raise_flowpk_migration_error() -> None:
raise ValueError(
"FlowPK configs now require 'experiment_type: flowpk' and a 'vector_field' section. "
"Rename the old 'network' section to 'vector_field' and update your YAML accordingly."
)
def get_model_config(
    yaml_path: str,
) -> Union[NodePKExperimentConfig, FlowPKExperimentConfig, DiffusionPKExperimentConfig]:
    """
    Load an experiment YAML and build the matching experiment config object.

    Dispatch uses the top-level ``experiment_type`` key. Matching ignores
    case and surrounding whitespace:
    ``nodepk`` -> NodePKExperimentConfig,
    ``flowpk`` -> FlowPKExperimentConfig,
    ``diffusionpk`` -> DiffusionPKExperimentConfig.

    YAMLs without ``experiment_type`` (legacy layout) are routed by
    ``name_str``:
    ``ContinuousDiffusionPK`` / ``DiscreteDiffusionPK`` -> DiffusionPKExperimentConfig,
    ``FlowPK`` -> migration error,
    anything else (including an empty file) -> NodePKExperimentConfig.

    Args:
        yaml_path: Path to the experiment YAML file.

    Returns:
        The config object produced by the selected class's ``from_yaml``.

    Raises:
        TypeError: If the YAML document is not a mapping.
        ValueError: If ``experiment_type`` is not recognized, or if the file
            is a legacy FlowPK config that must be migrated.
    """
    # The file is parsed here only to pick the config class. The chosen
    # class's ``from_yaml`` is then given the path (presumably re-reading it).
    # An explicit encoding avoids locale-dependent decoding.
    with open(yaml_path, "r", encoding="utf-8") as file:
        config_dict = yaml.load(file, Loader=TupleSafeLoader) or {}
    if not isinstance(config_dict, dict):
        raise TypeError("Expected experiment YAML to be a mapping.")

    experiment_type = config_dict.get("experiment_type")
    if experiment_type is None:
        # Legacy YAML without ``experiment_type``: infer from the model name.
        name_str = config_dict.get("name_str")
        if name_str == "FlowPK":
            _raise_flowpk_migration_error()
        if name_str in ("ContinuousDiffusionPK", "DiscreteDiffusionPK"):
            return DiffusionPKExperimentConfig.from_yaml(yaml_path)
        return NodePKExperimentConfig.from_yaml(yaml_path)

    experiment_type = str(experiment_type).strip().lower()
    if experiment_type == "nodepk":
        return NodePKExperimentConfig.from_yaml(yaml_path)
    if experiment_type == "flowpk":
        return FlowPKExperimentConfig.from_yaml(yaml_path)
    if experiment_type == "diffusionpk":
        return DiffusionPKExperimentConfig.from_yaml(yaml_path)
    raise ValueError(f"Unknown experiment_type: {experiment_type}")