"""Shared runtime-contract helpers for consumer-facing Hub bundles. This module is imported both by the local exporter and by the copied package inside the generated Hugging Face runtime bundle. Keep dependencies limited to modules that are already required for model inference. """ from __future__ import annotations from copy import deepcopy from dataclasses import asdict, dataclass from pathlib import Path from typing import Any, Dict, List, Mapping, Optional, Sequence, Union, get_args, get_origin import torch from transformers import PretrainedConfig from sim_priors_pk.config_classes.data_config import ( MetaDosingConfig, MetaStudyConfig, MixDataConfig, ObservationsConfig, SimpleMetaStudyConfig, ) from sim_priors_pk.config_classes.diffusion_pk_config import DiffusionPKExperimentConfig from sim_priors_pk.config_classes.flow_pk_config import FlowPKExperimentConfig, VectorFieldPKConfig from sim_priors_pk.config_classes.node_pk_config import ( EncoderDecoderNetworkConfig, NodePKExperimentConfig, ) from sim_priors_pk.config_classes.source_process_config import SourceProcessConfig from sim_priors_pk.config_classes.training_config import TrainingConfig from sim_priors_pk.data.data_empirical.builder import EmpiricalBatchConfig, JSON2AICMEBuilder from sim_priors_pk.data.data_empirical.json_schema import IndividualJSON, StudyJSON, canonicalize_study from sim_priors_pk.data.data_generation.observations_classes import ObservationStrategyFactory from sim_priors_pk.models import get_model_class from sim_priors_pk.models.amortized_inference.generative_pk import ( NewGenerativeMixin, NewPredictiveMixin, ) SUPPORTED_RUNTIME_ARCHITECTURES = { "AICMEPK", "ContextVAEPK", "FlowPK", "PredictionPK", } STUDY_JSON_IO_VERSION = "studyjson-v1" @dataclass class RuntimeBuilderConfig: """Fixed builder capacities serialized into the Hub runtime config.""" max_context_individuals: int max_target_individuals: int max_context_observations: int max_target_observations: int max_context_remaining: int max_target_remaining: int def to_dict(self) -> Dict[str, int]: """Return a JSON-serializable representation.""" return asdict(self) @classmethod def from_dict(cls, payload: Mapping[str, Any]) -> "RuntimeBuilderConfig": """Instantiate the builder capacities from serialized config payload.""" return cls( max_context_individuals=int(payload["max_context_individuals"]), max_target_individuals=int(payload["max_target_individuals"]), max_context_observations=int(payload["max_context_observations"]), max_target_observations=int(payload["max_target_observations"]), max_context_remaining=int(payload["max_context_remaining"]), max_target_remaining=int(payload["max_target_remaining"]), ) def to_empirical_batch_config(self, *, max_databatch_size: int) -> EmpiricalBatchConfig: """Translate runtime capacities to the builder used by StudyJSON IO.""" return EmpiricalBatchConfig( max_databatch_size=int(max_databatch_size), max_individuals=max(self.max_context_individuals, self.max_target_individuals), max_observations=max(self.max_context_observations, self.max_target_observations), max_remaining=max(self.max_context_remaining, self.max_target_remaining), max_context_individuals=self.max_context_individuals, max_target_individuals=self.max_target_individuals, max_context_observations=self.max_context_observations, max_target_observations=self.max_target_observations, max_context_remaining=self.max_context_remaining, max_target_remaining=self.max_target_remaining, ) def _coerce_annotation(annotation: Any, value: Any) -> Any: """Best-effort coercion of JSON-loaded values into dataclass field types.""" if value is None: return None origin = get_origin(annotation) args = get_args(annotation) if origin is Union: non_none = [arg for arg in args if arg is not type(None)] for candidate in non_none: if candidate in (dict, Dict, Any, Mapping): continue try: return _coerce_annotation(candidate, value) except Exception: continue return value if origin in (list, List, Sequence): (inner_type,) = args if args else (Any,) return [_coerce_annotation(inner_type, item) for item in value] if origin in (tuple,): if not args: return tuple(value) if len(args) == 2 and args[1] is Ellipsis: return tuple(_coerce_annotation(args[0], item) for item in value) return tuple(_coerce_annotation(inner, item) for inner, item in zip(args, value)) if origin in (dict, Dict, Mapping): return dict(value) if annotation is Any: return value if annotation is MetaStudyConfig and isinstance(value, Mapping) and value.get("simple_mode"): return SimpleMetaStudyConfig(**dict(value)) if hasattr(annotation, "__dataclass_fields__") and isinstance(value, Mapping): kwargs = {} for field_name, field_def in annotation.__dataclass_fields__.items(): if field_name in value: kwargs[field_name] = _coerce_annotation(field_def.type, value[field_name]) return annotation(**kwargs) return value def _rebuild_node_config(payload: Mapping[str, Any]) -> NodePKExperimentConfig: """Reconstruct a ``NodePKExperimentConfig`` from serialized dict content.""" return NodePKExperimentConfig( experiment_type=str(payload.get("experiment_type", "nodepk")).lower(), name_str=str(payload.get("name_str", "NodePK")), comet_ai_key=payload.get("comet_ai_key"), experiment_name=str(payload.get("experiment_name", "node_pk_compartments")), hugging_face_token=payload.get("hugging_face_token"), upload_to_hf_hub=bool(payload.get("upload_to_hf_hub", False)), hf_model_name=str(payload.get("hf_model_name", "NodePK_runtime")), hf_model_card_path=tuple(payload.get("hf_model_card_path", ("hf_model_cards", "README.md"))), tags=list(payload.get("tags", [])), experiment_indentifier=payload.get("experiment_indentifier"), my_results_path=payload.get("my_results_path"), experiment_dir=payload.get("experiment_dir"), verbose=bool(payload.get("verbose", False)), run_index=int(payload.get("run_index", 0)), debug_test=bool(payload.get("debug_test", False)), network=_coerce_annotation(EncoderDecoderNetworkConfig, payload.get("network", {})), mix_data=_coerce_annotation(MixDataConfig, payload.get("mix_data", {})), context_observations=_coerce_annotation( ObservationsConfig, payload.get("context_observations", {}) ), target_observations=_coerce_annotation( ObservationsConfig, payload.get("target_observations", {}) ), meta_study=_coerce_annotation(MetaStudyConfig, payload.get("meta_study", {})), dosing=_coerce_annotation(MetaDosingConfig, payload.get("dosing", {})), train=_coerce_annotation(TrainingConfig, payload.get("train", {})), ) def _rebuild_flow_config(payload: Mapping[str, Any]) -> FlowPKExperimentConfig: """Reconstruct a ``FlowPKExperimentConfig`` from serialized dict content.""" return FlowPKExperimentConfig( experiment_type=str(payload.get("experiment_type", "flowpk")).lower(), name_str=str(payload.get("name_str", "FlowPK")), comet_ai_key=payload.get("comet_ai_key"), experiment_name=str(payload.get("experiment_name", "flow_pk_compartments")), hugging_face_token=payload.get("hugging_face_token"), upload_to_hf_hub=bool(payload.get("upload_to_hf_hub", False)), hf_model_name=str(payload.get("hf_model_name", "FlowPK_runtime")), hf_model_card_path=tuple(payload.get("hf_model_card_path", ("hf_model_cards", "README.md"))), tags=list(payload.get("tags", [])), experiment_indentifier=payload.get("experiment_indentifier"), my_results_path=payload.get("my_results_path"), experiment_dir=payload.get("experiment_dir"), verbose=bool(payload.get("verbose", False)), run_index=int(payload.get("run_index", 0)), debug_test=bool(payload.get("debug_test", False)), flow_num_steps=int(payload.get("flow_num_steps", 50)), vector_field=_coerce_annotation(VectorFieldPKConfig, payload.get("vector_field", {})), source_process=_coerce_annotation(SourceProcessConfig, payload.get("source_process", {})), mix_data=_coerce_annotation(MixDataConfig, payload.get("mix_data", {})), context_observations=_coerce_annotation( ObservationsConfig, payload.get("context_observations", {}) ), target_observations=_coerce_annotation( ObservationsConfig, payload.get("target_observations", {}) ), meta_study=_coerce_annotation(MetaStudyConfig, payload.get("meta_study", {})), dosing=_coerce_annotation(MetaDosingConfig, payload.get("dosing", {})), train=_coerce_annotation(TrainingConfig, payload.get("train", {})), ) def _rebuild_diffusion_config(payload: Mapping[str, Any]) -> DiffusionPKExperimentConfig: """Reconstruct a ``DiffusionPKExperimentConfig`` from serialized dict content.""" return DiffusionPKExperimentConfig( experiment_type=str(payload.get("experiment_type", "diffusionpk")).lower(), name_str=str(payload.get("name_str", "ContinuousDiffusionPK")), diffusion_type=str(payload.get("diffusion_type", "continuous")), comet_ai_key=payload.get("comet_ai_key"), experiment_name=str(payload.get("experiment_name", "diffusion_pk_compartments")), hugging_face_token=payload.get("hugging_face_token"), upload_to_hf_hub=bool(payload.get("upload_to_hf_hub", False)), hf_model_name=str(payload.get("hf_model_name", "DiffusionPK_runtime")), hf_model_card_path=tuple(payload.get("hf_model_card_path", ("hf_model_cards", "README.md"))), tags=list(payload.get("tags", [])), experiment_indentifier=payload.get("experiment_indentifier"), my_results_path=payload.get("my_results_path"), experiment_dir=payload.get("experiment_dir"), verbose=bool(payload.get("verbose", False)), run_index=int(payload.get("run_index", 0)), debug_test=bool(payload.get("debug_test", False)), predict_gaussian_noise=bool(payload.get("predict_gaussian_noise", True)), network=_coerce_annotation(EncoderDecoderNetworkConfig, payload.get("network", {})), source_process=_coerce_annotation(SourceProcessConfig, payload.get("source_process", {})), mix_data=_coerce_annotation(MixDataConfig, payload.get("mix_data", {})), context_observations=_coerce_annotation( ObservationsConfig, payload.get("context_observations", {}) ), target_observations=_coerce_annotation( ObservationsConfig, payload.get("target_observations", {}) ), meta_study=_coerce_annotation(MetaStudyConfig, payload.get("meta_study", {})), dosing=_coerce_annotation(MetaDosingConfig, payload.get("dosing", {})), train=_coerce_annotation(TrainingConfig, payload.get("train", {})), ) def rebuild_experiment_config( payload: Mapping[str, Any], ) -> Union[NodePKExperimentConfig, FlowPKExperimentConfig, DiffusionPKExperimentConfig]: """Rebuild the serialized experiment config stored in the Hub config.""" experiment_type = str(payload.get("experiment_type", "nodepk")).lower() if experiment_type == "nodepk": return _rebuild_node_config(payload) if experiment_type == "flowpk": return _rebuild_flow_config(payload) if experiment_type == "diffusionpk": return _rebuild_diffusion_config(payload) raise ValueError(f"Unsupported experiment_type for runtime bundle: {experiment_type!r}.") def compute_runtime_builder_config( exp_config: Union[NodePKExperimentConfig, FlowPKExperimentConfig, DiffusionPKExperimentConfig], ) -> RuntimeBuilderConfig: """Compute fixed empirical StudyJSON capacities from the experiment config.""" context_strategy = ObservationStrategyFactory.from_config( exp_config.context_observations, exp_config.meta_study, ) target_strategy = ObservationStrategyFactory.from_config( exp_config.target_observations, exp_config.meta_study, ) ctx_obs_cap, ctx_rem_cap = context_strategy.get_shapes() tgt_obs_cap, tgt_rem_cap = target_strategy.get_shapes() max_context_individuals = int(exp_config.meta_study.num_individuals_range[-1]) max_target_individuals = int(getattr(exp_config.mix_data, "n_of_target_individuals", 1)) if max_target_individuals < 0: raise ValueError("n_of_target_individuals must be >= 0 for Hub runtime export.") return RuntimeBuilderConfig( max_context_individuals=max_context_individuals, max_target_individuals=max_target_individuals, max_context_observations=int(ctx_obs_cap), max_target_observations=int(tgt_obs_cap), max_context_remaining=int(ctx_rem_cap), max_target_remaining=int(tgt_rem_cap), ) def infer_supported_tasks(backbone: torch.nn.Module) -> List[str]: """Infer the public task surface supported by the wrapped model.""" tasks: List[str] = [] if isinstance(backbone, NewGenerativeMixin): tasks.append("generate") if isinstance(backbone, NewPredictiveMixin): tasks.append("predict") return tasks def validate_runtime_architecture(backbone: torch.nn.Module) -> str: """Ensure the loaded architecture is supported by the runtime bundle v1.""" architecture_name = backbone.__class__.__name__ if architecture_name not in SUPPORTED_RUNTIME_ARCHITECTURES: raise ValueError( "Runtime Hub export only supports " f"{sorted(SUPPORTED_RUNTIME_ARCHITECTURES)}, got {architecture_name!r}." ) return architecture_name def build_runtime_config_payload( *, backbone: torch.nn.Module, exp_config: Union[NodePKExperimentConfig, FlowPKExperimentConfig, DiffusionPKExperimentConfig], original_repo_id: Optional[str], runtime_repo_id: Optional[str], ) -> Dict[str, Any]: """Build the serializable fields stored in the Hub config.""" architecture_name = validate_runtime_architecture(backbone) supported_tasks = infer_supported_tasks(backbone) if not supported_tasks: raise ValueError(f"Model {architecture_name!r} does not expose runtime tasks.") builder_config = compute_runtime_builder_config(exp_config) return { "architecture_name": architecture_name, "experiment_type": str(getattr(exp_config, "experiment_type", "nodepk")).lower(), "experiment_config": asdict(exp_config), "builder_config": builder_config.to_dict(), "supported_tasks": supported_tasks, "default_task": supported_tasks[0], "io_schema_version": STUDY_JSON_IO_VERSION, "original_repo_id": original_repo_id, "runtime_repo_id": runtime_repo_id, } def instantiate_backbone_from_hub_config(config: PretrainedConfig) -> torch.nn.Module: """Rebuild the internal PK model represented by the public Hub wrapper.""" experiment_config_payload = getattr(config, "experiment_config", None) if not isinstance(experiment_config_payload, Mapping): raise ValueError("Hub config is missing the serialized experiment_config payload.") exp_config = rebuild_experiment_config(experiment_config_payload) model_cls = get_model_class(exp_config) backbone = model_cls(exp_config) backbone.eval() return backbone def normalize_studies_input( studies: Union[StudyJSON, Sequence[StudyJSON]], ) -> List[StudyJSON]: """Normalize runtime input to a mutable list of canonicalized studies.""" if isinstance(studies, Mapping): raw_studies = [dict(studies)] else: raw_studies = [dict(study) for study in studies] return [canonicalize_study(study, drop_tgt_too_few=False) for study in raw_studies] def validate_studies_for_task( studies: Sequence[StudyJSON], *, task: str, builder_config: RuntimeBuilderConfig, ) -> None: """Validate task semantics and reject inputs that exceed runtime capacities.""" for study_idx, study in enumerate(studies): context = list(study.get("context", [])) target = list(study.get("target", [])) if task == "generate": if not context: raise ValueError("`generate` requires at least one context individual per study.") if target: raise ValueError("`generate` expects target to be empty in the input StudyJSON.") elif task == "predict": if not target: raise ValueError("`predict` requires at least one target individual per study.") else: raise ValueError(f"Unsupported task {task!r}.") if len(context) > builder_config.max_context_individuals: raise ValueError( f"Study {study_idx} exceeds context individual capacity " f"({len(context)} > {builder_config.max_context_individuals})." ) if len(target) > builder_config.max_target_individuals: raise ValueError( f"Study {study_idx} exceeds target individual capacity " f"({len(target)} > {builder_config.max_target_individuals})." ) _validate_individual_block( study_idx=study_idx, block_name="context", individuals=context, max_observations=builder_config.max_context_observations, max_remaining=builder_config.max_context_remaining, ) _validate_individual_block( study_idx=study_idx, block_name="target", individuals=target, max_observations=builder_config.max_target_observations, max_remaining=builder_config.max_target_remaining, ) def _validate_individual_block( *, study_idx: int, block_name: str, individuals: Sequence[IndividualJSON], max_observations: int, max_remaining: int, ) -> None: """Reject studies that would otherwise be truncated by the empirical builder.""" for ind_idx, individual in enumerate(individuals): obs_len = len(individual.get("observations", [])) rem_len = len(individual.get("remaining", [])) if obs_len > max_observations: raise ValueError( f"Study {study_idx} {block_name}[{ind_idx}] exceeds observation capacity " f"({obs_len} > {max_observations})." ) if rem_len > max_remaining: raise ValueError( f"Study {study_idx} {block_name}[{ind_idx}] exceeds remaining capacity " f"({rem_len} > {max_remaining})." ) def build_batch_from_studies( studies: Sequence[StudyJSON], *, builder_config: RuntimeBuilderConfig, meta_dosing: MetaDosingConfig, ): """Convert canonical studies into the internal PK databatch representation.""" builder = JSON2AICMEBuilder( builder_config.to_empirical_batch_config(max_databatch_size=max(1, len(studies))) ) return builder.build_one_aicmebatch(list(studies), meta_dosing) def split_runtime_samples(task: str, study: StudyJSON) -> List[StudyJSON]: """Convert model-specific StudyJSON outputs into per-sample StudyJSONs.""" if task == "generate": return _split_generate_samples(study) if task == "predict": return _split_predict_samples(study) raise ValueError(f"Unsupported task {task!r}.") def _split_generate_samples(study: StudyJSON) -> List[StudyJSON]: """Split generated target individuals into one StudyJSON per sample.""" targets = list(study.get("target", [])) if not targets: return [deepcopy(study)] split: List[StudyJSON] = [] for target in targets: split.append( { "context": deepcopy(study.get("context", [])), "target": [deepcopy(target)], "meta_data": deepcopy(study.get("meta_data", {})), } ) return split def _split_predict_samples(study: StudyJSON) -> List[StudyJSON]: """Split target prediction samples into one StudyJSON per sample index.""" targets = list(study.get("target", [])) if not targets: return [deepcopy(study)] sample_count = 0 for target in targets: sample_count = max(sample_count, len(target.get("prediction_samples", []))) if sample_count == 0: return [deepcopy(study)] split: List[StudyJSON] = [] for sample_idx in range(sample_count): target_block: List[IndividualJSON] = [] for target in targets: target_copy: IndividualJSON = deepcopy(target) samples = list(target.get("prediction_samples", [])) if samples: if sample_idx >= len(samples): raise ValueError( "All target individuals must expose the same number of prediction samples." ) target_copy["prediction_samples"] = [deepcopy(samples[sample_idx])] target_block.append(target_copy) split.append( { "context": deepcopy(study.get("context", [])), "target": target_block, "meta_data": deepcopy(study.get("meta_data", {})), } ) return split def runtime_readme_text( *, base_model_card: str, runtime_repo_id: str, original_repo_id: Optional[str], supported_tasks: Sequence[str], default_task: str, ) -> str: """Compose the README uploaded with the consumer-facing runtime bundle.""" original_line = ( f"- Native training/artifact repo: `{original_repo_id}`" if original_repo_id else "- Native training/artifact repo: not recorded" ) tasks_literal = ", ".join(f"`{task}`" for task in supported_tasks) usage = f""" ## Runtime Bundle This repository is the consumer-facing runtime bundle for this PK model. - Runtime repo: `{runtime_repo_id}` {original_line} - Supported tasks: {tasks_literal} - Default task: `{default_task}` - Load path: `AutoModel.from_pretrained(..., trust_remote_code=True)` ### Installation You do **not** need to install `sim_priors_pk` to use this runtime bundle. `transformers` is the public loading entrypoint, but `transformers` alone is not sufficient because this is a PyTorch model with custom runtime code. A reliable consumer environment is: ```bash pip install torch transformers huggingface_hub lightning datasets pandas torchtyping gpytorch pot torchdiffeq torchsde ruamel.yaml pyyaml ``` ### Python Usage ```python from transformers import AutoModel model = AutoModel.from_pretrained("{runtime_repo_id}", trust_remote_code=True) studies = [ {{ "context": [ {{ "name_id": "ctx_0", "observations": [0.2, 0.5, 0.3], "observation_times": [0.5, 1.0, 2.0], "dosing": [1.0], "dosing_type": ["oral"], "dosing_times": [0.0], "dosing_name": ["oral"], }} ], "target": [], "meta_data": {{"study_name": "demo", "substance_name": "drug_x"}}, }} ] outputs = model.run_task( task="{default_task}", studies=studies, num_samples=4, ) print(outputs["results"][0]["samples"]) ``` ### Predictive Sampling ```python from transformers import AutoModel model = AutoModel.from_pretrained("{runtime_repo_id}", trust_remote_code=True) predict_studies = [ {{ "context": [ {{ "name_id": "ctx_0", "observations": [0.2, 0.5, 0.3], "observation_times": [0.5, 1.0, 2.0], "dosing": [1.0], "dosing_type": ["oral"], "dosing_times": [0.0], "dosing_name": ["oral"], }} ], "target": [ {{ "name_id": "tgt_0", "observations": [0.25, 0.31], "observation_times": [0.5, 1.0], "remaining": [0.0, 0.0, 0.0], "remaining_times": [2.0, 4.0, 8.0], "dosing": [1.0], "dosing_type": ["oral"], "dosing_times": [0.0], "dosing_name": ["oral"], }} ], "meta_data": {{"study_name": "demo", "substance_name": "drug_x"}}, }} ] outputs = model.run_task( task="predict", studies=predict_studies, num_samples=4, ) print(outputs["results"][0]["samples"][0]["target"][0]["prediction_samples"]) ``` ### Notes - `trust_remote_code=True` is required because this model uses custom Hugging Face Hub runtime code. - The consumer API is `transformers` + `run_task(...)`; the consumer does not need a local clone of this repository. - This runtime bundle is intentionally separate from the native training export so you can evaluate both distribution paths in parallel. """ return base_model_card.rstrip() + "\n" + usage.strip() + "\n" def resolve_model_card_text(model_card_path: Path) -> str: """Read and validate the model card that seeds the runtime README.""" if not model_card_path.is_file(): raise FileNotFoundError(f"Model card not found at: {model_card_path}") return model_card_path.read_text(encoding="utf-8")