from typing import Dict, NoReturn, Optional, Type, Union

import yaml

from sim_priors_pk.config_classes.flow_pk_config import FlowPKExperimentConfig
from sim_priors_pk.config_classes.node_pk_config import NodePKExperimentConfig
from sim_priors_pk.config_classes.diffusion_pk_config import DiffusionPKExperimentConfig
from sim_priors_pk.config_classes.utils import TupleSafeLoader
from sim_priors_pk.models.amortized_inference.aicme import AICMEPK
from sim_priors_pk.models.amortized_inference.context_vae_pk import ContextVAEPK
from sim_priors_pk.models.amortized_inference.diffusion_pk import (
    ContinuousDiffusionPK,
    DiscreteDiffusionPK,
)
from sim_priors_pk.models.amortized_inference.flows_pk import FlowPK
from sim_priors_pk.models.amortized_inference.prediction_pk import PredictionPK

# Any of the experiment config classes this module can produce.
_ExperimentConfig = Union[
    NodePKExperimentConfig, FlowPKExperimentConfig, DiffusionPKExperimentConfig
]

# Any of the model classes ``get_model_class`` can return.
_ModelClass = Type[
    Union[
        PredictionPK,
        AICMEPK,
        FlowPK,
        ContinuousDiffusionPK,
        DiscreteDiffusionPK,
        ContextVAEPK,
    ]
]

# Model name (as written in ``name_str``) -> model class.
_MODEL_CLASSES: Dict[str, type] = {
    "PredictionPK": PredictionPK,
    "AICMEPK": AICMEPK,
    "FlowPK": FlowPK,
    "ContinuousDiffusionPK": ContinuousDiffusionPK,
    "DiscreteDiffusionPK": DiscreteDiffusionPK,
    "ContextVAEPK": ContextVAEPK,
}

# Lower-cased ``experiment_type`` value -> experiment config class.
_EXPERIMENT_CONFIG_CLASSES: Dict[str, type] = {
    "nodepk": NodePKExperimentConfig,
    "flowpk": FlowPKExperimentConfig,
    "diffusionpk": DiffusionPKExperimentConfig,
}

# Model names that, in legacy YAMLs without ``experiment_type``, select the
# diffusion config class. A tuple (not a set) so membership tests never hash
# the YAML value, which could be an unhashable type.
_LEGACY_DIFFUSION_NAMES = ("ContinuousDiffusionPK", "DiscreteDiffusionPK")


def get_model_class(
    model_config: Optional[_ExperimentConfig] = None,
    name_str: Optional[str] = None,
) -> _ModelClass:
    """Return the model class registered under a model name.

    Used by the basic experiments to turn a config's ``name_str`` into the
    class to instantiate.

    Args:
        model_config: Experiment config whose ``name_str`` selects the model.
            When given, it takes precedence over the ``name_str`` argument.
        name_str: Model name to look up when ``model_config`` is ``None``.

    Returns:
        The model class registered under the resolved name.

    Raises:
        ValueError: If the resolved name is not a known model (including when
            neither argument is provided).
    """
    if model_config is not None:
        name_str = model_config.name_str

    model_cls = _MODEL_CLASSES.get(name_str) if isinstance(name_str, str) else None
    if model_cls is None:
        # Report the resolved name; ``model_config`` may legitimately be None
        # here when the caller passed ``name_str`` directly.
        raise ValueError(f"Unknown model name: {name_str}")
    return model_cls


def _raise_flowpk_migration_error() -> NoReturn:
    """Reject a legacy FlowPK YAML and explain how to migrate it.

    Raises:
        ValueError: Always.
    """
    raise ValueError(
        "FlowPK configs now require 'experiment_type: flowpk' and a 'vector_field' section. "
        "Rename the old 'network' section to 'vector_field' and update your YAML accordingly."
    )


def get_model_config(
    yaml_path: str,
) -> _ExperimentConfig:
    """Load an experiment YAML with the config class matching its experiment type.

    The top-level ``experiment_type`` key (case-insensitive) selects the config
    class: ``nodepk``, ``flowpk`` or ``diffusionpk``.

    YAMLs without ``experiment_type`` are treated as legacy configs and are
    routed by ``name_str``:

    * ``FlowPK`` is rejected with migration instructions.
    * Diffusion model names map to ``DiffusionPKExperimentConfig``.
    * Anything else falls back to ``NodePKExperimentConfig``.

    Args:
        yaml_path: Path to the experiment YAML file.

    Returns:
        The result of the selected config class's ``from_yaml(yaml_path)``.

    Raises:
        TypeError: If the YAML's top level is not a mapping.
        ValueError: If ``experiment_type`` is not recognised, or the file is a
            legacy FlowPK config that must be migrated.
    """
    # The file is parsed here only to decide which config class to use; that
    # class's ``from_yaml`` is then handed the same path.
    # NOTE(review): assumes TupleSafeLoader is derived from yaml.SafeLoader;
    # confirm before loading YAML from untrusted sources.
    with open(yaml_path, "r") as file:
        # An empty file loads as None; treat it as an empty mapping.
        config_dict = yaml.load(file, Loader=TupleSafeLoader) or {}
    if not isinstance(config_dict, dict):
        raise TypeError("Expected experiment YAML to be a mapping.")

    experiment_type = config_dict.get("experiment_type")
    if experiment_type is None:
        # Legacy configs predate ``experiment_type``; route them by model name.
        name_str = config_dict.get("name_str")
        if name_str == "FlowPK":
            _raise_flowpk_migration_error()
        if name_str in _LEGACY_DIFFUSION_NAMES:
            return DiffusionPKExperimentConfig.from_yaml(yaml_path)
        return NodePKExperimentConfig.from_yaml(yaml_path)

    experiment_type = str(experiment_type).lower()
    config_cls = _EXPERIMENT_CONFIG_CLASSES.get(experiment_type)
    if config_cls is None:
        raise ValueError(f"Unknown experiment_type: {experiment_type}")
    return config_cls.from_yaml(yaml_path)