| from typing import Type, Union |
|
|
| import yaml |
|
|
| from sim_priors_pk.config_classes.flow_pk_config import FlowPKExperimentConfig |
| from sim_priors_pk.config_classes.node_pk_config import NodePKExperimentConfig |
| from sim_priors_pk.config_classes.diffusion_pk_config import DiffusionPKExperimentConfig |
| from sim_priors_pk.config_classes.utils import TupleSafeLoader |
| from sim_priors_pk.models.amortized_inference.aicme import AICMEPK |
| from sim_priors_pk.models.amortized_inference.context_vae_pk import ContextVAEPK |
| from sim_priors_pk.models.amortized_inference.diffusion_pk import ( |
| ContinuousDiffusionPK, |
| DiscreteDiffusionPK, |
| ) |
| from sim_priors_pk.models.amortized_inference.flows_pk import FlowPK |
| from sim_priors_pk.models.amortized_inference.prediction_pk import PredictionPK |
|
|
|
|
def get_model_class(
    model_config: Union[
        NodePKExperimentConfig, FlowPKExperimentConfig, DiffusionPKExperimentConfig, None
    ] = None,
    name_str: Union[str, None] = None,
) -> Type[
    Union[
        PredictionPK,
        AICMEPK,
        FlowPK,
        ContinuousDiffusionPK,
        DiscreteDiffusionPK,
        ContextVAEPK,
    ]
]:
    """Return the model class registered under a model name.

    The name is read from ``model_config.name_str`` when ``model_config`` is
    given (it takes precedence over the ``name_str`` argument); otherwise the
    ``name_str`` argument is used directly.

    Args:
        model_config: Experiment config whose ``name_str`` selects the model.
        name_str: Model name, used only when ``model_config`` is ``None``.

    Returns:
        The model class itself (not an instance) matching the resolved name.

    Raises:
        ValueError: If the resolved name is not one of the known models.
    """
    if model_config is not None:
        name_str = model_config.name_str

    model_classes = {
        "PredictionPK": PredictionPK,
        "AICMEPK": AICMEPK,
        "FlowPK": FlowPK,
        "ContinuousDiffusionPK": ContinuousDiffusionPK,
        "DiscreteDiffusionPK": DiscreteDiffusionPK,
        "ContextVAEPK": ContextVAEPK,
    }
    try:
        return model_classes[name_str]
    except KeyError:
        # Report the *resolved* name: model_config may be None here when the
        # caller supplied only name_str, so it must not be dereferenced.
        raise ValueError(f"Unknown model name: {name_str}") from None
|
|
|
|
| def _raise_flowpk_migration_error() -> None: |
| raise ValueError( |
| "FlowPK configs now require 'experiment_type: flowpk' and a 'vector_field' section. " |
| "Rename the old 'network' section to 'vector_field' and update your YAML accordingly." |
| ) |
|
|
|
|
def get_model_config(
    yaml_path: str,
) -> Union[NodePKExperimentConfig, FlowPKExperimentConfig, DiffusionPKExperimentConfig]:
    """Load an experiment YAML file and build the matching experiment config.

    The config class is chosen from the top-level ``experiment_type`` key
    (case-insensitive):

    - ``"nodepk"`` -> ``NodePKExperimentConfig``
    - ``"flowpk"`` -> ``FlowPKExperimentConfig``
    - ``"diffusionpk"`` -> ``DiffusionPKExperimentConfig``

    Legacy files without ``experiment_type`` are dispatched on ``name_str``:

    - ``"ContinuousDiffusionPK"`` or ``"DiscreteDiffusionPK"`` yields a
      ``DiffusionPKExperimentConfig``.
    - ``"FlowPK"`` is rejected with a migration error.
    - Anything else falls back to ``NodePKExperimentConfig``.

    The chosen class re-reads ``yaml_path`` via its ``from_yaml`` constructor.

    Args:
        yaml_path: Path to the experiment YAML file.

    Returns:
        The experiment config instance for the detected experiment type.

    Raises:
        TypeError: If the YAML document is not a mapping.
        ValueError: If ``experiment_type`` is unknown, or the file is a legacy
            FlowPK config that must be migrated.
    """
    # NOTE(review): yaml.load with a custom loader; TupleSafeLoader is assumed
    # to be a SafeLoader subclass — confirm before loading untrusted files.
    with open(yaml_path, "r", encoding="utf-8") as file:
        # An empty file yields None; treat it as an empty mapping.
        config_dict = yaml.load(file, Loader=TupleSafeLoader) or {}

    if not isinstance(config_dict, dict):
        raise TypeError("Expected experiment YAML to be a mapping.")

    experiment_type = config_dict.get("experiment_type")
    if experiment_type is None:
        # Legacy config format: infer the config class from the model name.
        name_str = config_dict.get("name_str")
        if name_str == "FlowPK":
            _raise_flowpk_migration_error()
        if name_str in ("ContinuousDiffusionPK", "DiscreteDiffusionPK"):
            return DiffusionPKExperimentConfig.from_yaml(yaml_path)
        return NodePKExperimentConfig.from_yaml(yaml_path)

    experiment_type = str(experiment_type).lower()
    if experiment_type == "nodepk":
        return NodePKExperimentConfig.from_yaml(yaml_path)
    if experiment_type == "flowpk":
        return FlowPKExperimentConfig.from_yaml(yaml_path)
    if experiment_type == "diffusionpk":
        return DiffusionPKExperimentConfig.from_yaml(yaml_path)
    raise ValueError(f"Unknown experiment_type: {experiment_type}")
|
|