"""PriorManager: orchestrates combining data priors and physics priors."""

from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor

from physics_informed_bo.priors.data_prior import DataPrior
from physics_informed_bo.priors.physics_prior import PhysicsPrior
from physics_informed_bo.models.hybrid_model import HybridSurrogate


# Surrogate modes that cannot be built without a physics model.
_PHYSICS_MODES = ("physics_only", "physics_as_mean", "weighted_ensemble")


def _zero_mean(x: Tensor) -> Tensor:
    """Placeholder physics function returning zeros, one per row of ``x``."""
    return torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)


class PriorManager:
    """Manages the combination of physics priors and data priors.

    Recommends a surrogate model mode from the available information, where
    ``n`` is the number of observations held by the data prior:

    - No physics model, ``n >= min_data_for_gp`` → ``gp_only``
    - No physics model, ``n < min_data_for_gp`` → ``ValueError``
    - Physics model, ``n < min_data_for_gp`` → ``physics_only``
    - Physics model, ``n < data_threshold_for_ensemble`` → ``physics_as_mean``
    - Physics model, ``n >= data_threshold_for_ensemble`` → ``weighted_ensemble``

    The quality of the physics model is not assessed here, so a provided
    physics model is never automatically dropped in favour of ``gp_only``.
    ``data_threshold_for_gp_only`` is stored for that purpose but does not
    currently change the recommendation.

    Also handles constraint aggregation and candidate validation.
    """

    def __init__(
        self,
        physics_prior: Optional[PhysicsPrior] = None,
        data_prior: Optional[DataPrior] = None,
        min_data_for_gp: int = 3,
        data_threshold_for_ensemble: int = 20,
        data_threshold_for_gp_only: int = 50,
    ):
        """Store the priors and the data-count thresholds.

        Args:
            physics_prior: Optional physics model supplying a mean function
                (``evaluate``) and constraints.
            data_prior: Observed data. When None, a new empty ``DataPrior``
                is created.
            min_data_for_gp: Minimum number of observations required before
                a GP is fitted.
            data_threshold_for_ensemble: Observation count at which, with a
                physics model, the recommendation switches from
                ``physics_as_mean`` to ``weighted_ensemble``.
            data_threshold_for_gp_only: Large-data threshold; stored but it
                does not currently alter the recommendation.
        """
        self.physics_prior = physics_prior
        # Compare against None explicitly: an empty DataPrior could be falsy
        # (if it defines __len__/__bool__) and would otherwise be silently
        # replaced, so observations added through this manager would not
        # reach the caller's object.
        self.data_prior = data_prior if data_prior is not None else DataPrior()
        self.min_data_for_gp = min_data_for_gp
        self.data_threshold_for_ensemble = data_threshold_for_ensemble
        self.data_threshold_for_gp_only = data_threshold_for_gp_only

    def recommend_surrogate_mode(self) -> str:
        """Recommend the best surrogate model mode based on available priors.

        Returns:
            One of ``"gp_only"``, ``"physics_only"``, ``"physics_as_mean"``
            or ``"weighted_ensemble"``.

        Raises:
            ValueError: If there is no physics model and fewer than
                ``min_data_for_gp`` observations.
        """
        n_data = self.data_prior.n_observations
        has_physics = self.physics_prior is not None

        if not has_physics:
            if n_data < self.min_data_for_gp:
                raise ValueError(
                    f"Need at least {self.min_data_for_gp} data points or a physics "
                    f"model. Got {n_data} data points and no physics model."
                )
            return "gp_only"

        if n_data < self.min_data_for_gp:
            return "physics_only"
        if n_data < self.data_threshold_for_ensemble:
            return "physics_as_mean"
        # At or above data_threshold_for_ensemble the ensemble is kept even
        # past data_threshold_for_gp_only: dropping physics would require a
        # physics-quality check, which is not implemented here.
        return "weighted_ensemble"

    def build_surrogate(
        self,
        mode: Optional[str] = None,
        kernel: str = "matern",
        noise_variance: float = 0.01,
        device: str = "cpu",
        dtype: torch.dtype = torch.float64,
    ) -> HybridSurrogate:
        """Build and optionally fit a HybridSurrogate from the available priors.

        Args:
            mode: Override the auto-recommended mode. If None, uses
                recommend_surrogate_mode().
            kernel: GP kernel type.
            noise_variance: Initial noise variance.
            device: Torch device.
            dtype: Torch dtype.

        Returns:
            A configured HybridSurrogate, fitted on the data prior when it
            holds at least ``min_data_for_gp`` observations.

        Raises:
            ValueError: If ``mode`` needs a physics model but none was
                provided, or (when ``mode`` is None) if
                recommend_surrogate_mode() cannot recommend a mode.
        """
        if mode is None:
            mode = self.recommend_surrogate_mode()

        physics_fn: Callable[[Tensor], Tensor]
        if self.physics_prior is not None:
            physics_fn = self.physics_prior.evaluate
        elif mode in _PHYSICS_MODES:
            raise ValueError(f"Mode '{mode}' requires a physics model but none was provided.")
        else:
            # No physics model: HybridSurrogate still receives a callable,
            # so a zero mean function is passed as a placeholder.
            physics_fn = _zero_mean

        surrogate = HybridSurrogate(
            physics_fn=physics_fn,
            mode=mode,
            kernel=kernel,
            noise_variance=noise_variance,
            device=device,
            dtype=dtype,
        )

        # Auto-fit only once there is enough data for a GP.
        if self.data_prior.n_observations >= self.min_data_for_gp:
            surrogate.fit(self.data_prior.X, self.data_prior.y)

        return surrogate

    def get_all_constraints(self) -> list:
        """Return the physics prior's constraints, or an empty list if none."""
        if self.physics_prior is None:
            return []
        return self.physics_prior.constraints

    def validate_candidates(self, X: Tensor) -> Dict:
        """Validate candidate points against physics constraints.

        Args:
            X: Candidate points, one per row along the first dimension.

        Returns:
            Dict with ``"feasible"`` (per-candidate feasibility mask) and
            ``"violations"`` (per-constraint-name dict holding the
            constraint's ``evaluate`` and ``is_feasible`` outputs). Without a
            physics model every candidate is feasible and ``"violations"`` is
            empty.
        """
        if self.physics_prior is None:
            # Allocate on X's device so callers can combine the mask with
            # X (or masks from other sources) without a device mismatch.
            return {
                "feasible": torch.ones(X.shape[0], dtype=torch.bool, device=X.device),
                "violations": {},
            }

        feasible = self.physics_prior.check_feasibility(X)
        violations = {
            constraint.name: {
                "violation": constraint.evaluate(X),
                "feasible": constraint.is_feasible(X),
            }
            for constraint in self.physics_prior.constraints
        }

        return {"feasible": feasible, "violations": violations}

    def update_with_observations(self, X_new: Tensor, y_new: Tensor) -> None:
        """Add new observations to the data prior."""
        self.data_prior.add_observations(X_new, y_new)

    def summary(self) -> Dict:
        """Return a summary of the current prior state.

        ``"recommended_mode"`` is ``"insufficient_data"`` when
        recommend_surrogate_mode() would raise (no physics model and too
        few observations).
        """
        has_physics = self.physics_prior is not None
        can_recommend = (
            has_physics or self.data_prior.n_observations >= self.min_data_for_gp
        )
        return {
            "has_physics_model": has_physics,
            "n_physics_constraints": len(self.get_all_constraints()),
            "n_observations": self.data_prior.n_observations,
            "recommended_mode": (
                self.recommend_surrogate_mode() if can_recommend else "insufficient_data"
            ),
        }