| """PriorManager: orchestrates combining data priors and physics priors."""
|
|
|
| from typing import Callable, Dict, List, Optional, Tuple
|
|
|
| import torch
|
| from torch import Tensor
|
|
|
| from physics_informed_bo.priors.data_prior import DataPrior
|
| from physics_informed_bo.priors.physics_prior import PhysicsPrior
|
| from physics_informed_bo.models.hybrid_model import HybridSurrogate
|
|
|
|
|
class PriorManager:
    """Manage the combination of a physics prior and a data prior.

    Recommends a surrogate-model mode from the information available:

    - No (or too little) data, has physics → ``physics_only``
    - Small data + physics → ``physics_as_mean`` (strong physics prior)
    - More data + physics → ``weighted_ensemble``
    - Enough data, no physics → ``gp_only``

    This class does not assess the quality of the physics model, so it never
    recommends ``gp_only`` while a physics prior is present, however much
    data there is. Pass ``mode="gp_only"`` to :meth:`build_surrogate` to
    force it.

    Also exposes the physics constraints and validates candidate points
    against them.
    """

    # Modes that cannot be built without a physics model.
    _PHYSICS_MODES = ("physics_only", "physics_as_mean", "weighted_ensemble")

    def __init__(
        self,
        physics_prior: Optional["PhysicsPrior"] = None,
        data_prior: Optional["DataPrior"] = None,
        min_data_for_gp: int = 3,
        data_threshold_for_ensemble: int = 20,
        data_threshold_for_gp_only: int = 50,
    ):
        """Store the priors and the observation-count thresholds.

        Args:
            physics_prior: Physics model and constraints, or None.
            data_prior: Observed data. When None, an empty ``DataPrior()``
                is created.
            min_data_for_gp: Minimum number of observations needed before a
                GP is fitted or recommended.
            data_threshold_for_ensemble: Observation count at which the
                recommendation switches from ``physics_as_mean`` to
                ``weighted_ensemble``.
            data_threshold_for_gp_only: Stored for callers; it does not
                currently change the recommended mode.
        """
        self.physics_prior = physics_prior
        # Compare against None rather than relying on truthiness: a supplied
        # prior that happens to be falsy (e.g. empty) must not be replaced.
        self.data_prior = data_prior if data_prior is not None else DataPrior()
        self.min_data_for_gp = min_data_for_gp
        self.data_threshold_for_ensemble = data_threshold_for_ensemble
        self.data_threshold_for_gp_only = data_threshold_for_gp_only

    @staticmethod
    def _zero_physics(x: Tensor) -> Tensor:
        """Placeholder physics function: one zero per row of ``x``."""
        return torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)

    def recommend_surrogate_mode(self) -> str:
        """Recommend a surrogate mode from the available priors.

        Returns:
            One of ``"gp_only"``, ``"physics_only"``, ``"physics_as_mean"``
            or ``"weighted_ensemble"``.

        Raises:
            ValueError: If there is no physics prior and fewer than
                ``min_data_for_gp`` observations.
        """
        n_data = self.data_prior.n_observations

        if self.physics_prior is None:
            if n_data < self.min_data_for_gp:
                raise ValueError(
                    f"Need at least {self.min_data_for_gp} data points or a physics "
                    f"model. Got {n_data} data points and no physics model."
                )
            return "gp_only"

        if n_data < self.min_data_for_gp:
            return "physics_only"
        if n_data < self.data_threshold_for_ensemble:
            return "physics_as_mean"
        # With a physics prior present the ensemble is kept for any larger
        # data set, including beyond data_threshold_for_gp_only.
        return "weighted_ensemble"

    def build_surrogate(
        self,
        mode: Optional[str] = None,
        kernel: str = "matern",
        noise_variance: float = 0.01,
        device: str = "cpu",
        dtype: torch.dtype = torch.float64,
    ) -> "HybridSurrogate":
        """Build and, when enough data exists, fit a HybridSurrogate.

        Args:
            mode: Override the auto-recommended mode. If None, uses
                recommend_surrogate_mode().
            kernel: GP kernel type.
            noise_variance: Initial noise variance.
            device: Torch device.
            dtype: Torch dtype.

        Returns:
            A configured HybridSurrogate, fitted on the data prior's ``X``
            and ``y`` when it holds at least ``min_data_for_gp``
            observations.

        Raises:
            ValueError: If ``mode`` needs a physics model and none was
                provided, or if no mode was given and none can be
                recommended.
        """
        if mode is None:
            mode = self.recommend_surrogate_mode()

        if self.physics_prior is not None:
            physics_fn = self.physics_prior.evaluate
        elif mode in self._PHYSICS_MODES:
            raise ValueError(f"Mode '{mode}' requires a physics model but none was provided.")
        else:
            # No physics model: hand the surrogate an all-zeros function.
            physics_fn = self._zero_physics

        surrogate = HybridSurrogate(
            physics_fn=physics_fn,
            mode=mode,
            kernel=kernel,
            noise_variance=noise_variance,
            device=device,
            dtype=dtype,
        )

        if self.data_prior.n_observations >= self.min_data_for_gp:
            surrogate.fit(self.data_prior.X, self.data_prior.y)

        return surrogate

    def get_all_constraints(self) -> list:
        """Return the physics prior's constraints, or [] without one."""
        if self.physics_prior is None:
            return []
        return self.physics_prior.constraints

    def validate_candidates(self, X: Tensor) -> Dict:
        """Validate candidate points against the physics constraints.

        Args:
            X: Candidate points; ``len(X)`` is taken as the number of
                candidates.

        Returns:
            Dict with two keys. ``"feasible"`` is the result of the physics
            prior's ``check_feasibility(X)``, or an all-True boolean mask of
            length ``len(X)`` when there is no physics prior.
            ``"violations"`` maps each constraint name to a dict holding
            that constraint's ``"violation"`` (``evaluate(X)``) and
            ``"feasible"`` (``is_feasible(X)``) results; it is empty
            without a physics prior.
        """
        if self.physics_prior is None:
            # Create the mask on X's device so it can be combined with X
            # directly; non-tensor input falls back to the default device.
            mask = torch.ones(
                len(X), dtype=torch.bool, device=getattr(X, "device", None)
            )
            return {"feasible": mask, "violations": {}}

        feasible = self.physics_prior.check_feasibility(X)
        violations = {
            constraint.name: {
                "violation": constraint.evaluate(X),
                "feasible": constraint.is_feasible(X),
            }
            for constraint in self.physics_prior.constraints
        }
        return {"feasible": feasible, "violations": violations}

    def update_with_observations(self, X_new: Tensor, y_new: Tensor) -> None:
        """Add new observations to the data prior."""
        self.data_prior.add_observations(X_new, y_new)

    def summary(self) -> Dict:
        """Return a summary of the current prior state.

        Returns:
            Dict with ``"has_physics_model"``, ``"n_physics_constraints"``,
            ``"n_observations"`` and ``"recommended_mode"``. The last is
            ``"insufficient_data"`` when no mode can be recommended.
        """
        has_physics = self.physics_prior is not None
        n_constraints = len(self.get_all_constraints())
        n_obs = self.data_prior.n_observations

        # Same presence test as recommend_surrogate_mode, so the two can
        # never disagree about whether a physics model exists.
        if has_physics or n_obs >= self.min_data_for_gp:
            recommended = self.recommend_surrogate_mode()
        else:
            recommended = "insufficient_data"

        return {
            "has_physics_model": has_physics,
            "n_physics_constraints": n_constraints,
            "n_observations": n_obs,
            "recommended_mode": recommended,
        }
|
|
|