File size: 5,781 Bytes
41a65be | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | """PriorManager: orchestrates combining data priors and physics priors."""
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
from physics_informed_bo.priors.data_prior import DataPrior
from physics_informed_bo.priors.physics_prior import PhysicsPrior
from physics_informed_bo.models.hybrid_model import HybridSurrogate
class PriorManager:
    """Manages the combination of physics priors and data priors.

    Recommends a surrogate model mode from the available information:

    - No physics model, at least ``min_data_for_gp`` observations → ``gp_only``
    - Physics model, fewer than ``min_data_for_gp`` observations → ``physics_only``
    - Physics model, fewer than ``data_threshold_for_ensemble`` observations →
      ``physics_as_mean`` (strong physics prior)
    - Physics model, more observations → ``weighted_ensemble``

    NOTE(review): a "large data + poor physics → gp_only" switch is intended
    (see ``data_threshold_for_gp_only``). No physics-quality check exists yet,
    so with a physics model the recommendation stays ``weighted_ensemble``
    for any data size past the ensemble threshold.

    Also aggregates physics constraints and validates candidate points
    against them.
    """

    def __init__(
        self,
        physics_prior: Optional[PhysicsPrior] = None,
        data_prior: Optional[DataPrior] = None,
        min_data_for_gp: int = 3,
        data_threshold_for_ensemble: int = 20,
        data_threshold_for_gp_only: int = 50,
    ):
        """Store the priors and the data-count thresholds used to pick a mode.

        Args:
            physics_prior: Optional physics model. Used for ``evaluate`` (as
                the surrogate's physics function), ``constraints`` and
                ``check_feasibility``.
            data_prior: Observation store. A fresh ``DataPrior()`` is created
                only when this is ``None``.
            min_data_for_gp: Minimum number of observations before a GP is
                fitted / recommended.
            data_threshold_for_ensemble: Observation count at which the
                recommendation moves from ``physics_as_mean`` to
                ``weighted_ensemble``.
            data_threshold_for_gp_only: Reserved for a future large-data /
                poor-physics switch to ``gp_only``. It does not currently
                change the recommended mode.
        """
        self.physics_prior = physics_prior
        # Explicit None check: a caller-supplied DataPrior must be kept even if
        # it evaluates falsy (e.g. if it defines __len__ and is empty).
        # Otherwise observations added later would land in a private copy.
        self.data_prior = data_prior if data_prior is not None else DataPrior()
        self.min_data_for_gp = min_data_for_gp
        self.data_threshold_for_ensemble = data_threshold_for_ensemble
        self.data_threshold_for_gp_only = data_threshold_for_gp_only

    def recommend_surrogate_mode(self) -> str:
        """Recommend a surrogate model mode based on the available priors.

        Returns:
            One of ``"gp_only"``, ``"physics_only"``, ``"physics_as_mean"``
            or ``"weighted_ensemble"``.

        Raises:
            ValueError: If there is no physics model and fewer than
                ``min_data_for_gp`` observations.
        """
        n_data = self.data_prior.n_observations

        if self.physics_prior is None:
            if n_data < self.min_data_for_gp:
                raise ValueError(
                    f"Need at least {self.min_data_for_gp} data points or a physics "
                    f"model. Got {n_data} data points and no physics model."
                )
            return "gp_only"

        if n_data < self.min_data_for_gp:
            return "physics_only"
        if n_data < self.data_threshold_for_ensemble:
            return "physics_as_mean"
        # TODO: once n_data >= data_threshold_for_gp_only, assess how well the
        # physics model explains the observed data and return "gp_only" when
        # it is poor. Until then, all larger data sizes use the ensemble.
        return "weighted_ensemble"

    def build_surrogate(
        self,
        mode: Optional[str] = None,
        kernel: str = "matern",
        noise_variance: float = 0.01,
        device: str = "cpu",
        dtype: torch.dtype = torch.float64,
    ) -> HybridSurrogate:
        """Build a HybridSurrogate from the available priors, fitting it if possible.

        Args:
            mode: Override the auto-recommended mode. If None, uses
                ``recommend_surrogate_mode()``.
            kernel: GP kernel type.
            noise_variance: Initial noise variance.
            device: Torch device.
            dtype: Torch dtype.

        Returns:
            A configured HybridSurrogate. It is fitted on the data prior's
            ``X`` / ``y`` when at least ``min_data_for_gp`` observations exist.

        Raises:
            ValueError: If ``mode`` is ``"physics_only"``, ``"physics_as_mean"``
                or ``"weighted_ensemble"`` but no physics model was provided,
                or if the auto-recommendation itself raises.
        """
        if mode is None:
            mode = self.recommend_surrogate_mode()

        has_physics = self.physics_prior is not None
        if not has_physics and mode in ("physics_only", "physics_as_mean", "weighted_ensemble"):
            raise ValueError(f"Mode '{mode}' requires a physics model but none was provided.")

        physics_fn: Callable[[Tensor], Tensor]
        if has_physics:
            physics_fn = self.physics_prior.evaluate
        else:
            # Placeholder zero mean for modes that need no physics (e.g. "gp_only").
            def physics_fn(x: Tensor) -> Tensor:
                return torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)

        surrogate = HybridSurrogate(
            physics_fn=physics_fn,
            mode=mode,
            kernel=kernel,
            noise_variance=noise_variance,
            device=device,
            dtype=dtype,
        )

        # Auto-fit only when there is enough data for a GP.
        if self.data_prior.n_observations >= self.min_data_for_gp:
            surrogate.fit(self.data_prior.X, self.data_prior.y)
        return surrogate

    def get_all_constraints(self) -> list:
        """Return the physics prior's constraints as a new list.

        Returns an empty list when there is no physics model. A copy is
        returned so callers cannot mutate the physics prior's own collection.
        """
        if self.physics_prior is None:
            return []
        return list(self.physics_prior.constraints)

    def validate_candidates(self, X: Tensor) -> Dict:
        """Validate candidate points against the physics constraints.

        Args:
            X: Candidate points; one candidate per entry along the first
                dimension (``len(X)`` candidates).

        Returns:
            Dict with:

            - ``"feasible"``: the result of ``physics_prior.check_feasibility(X)``.
              Without a physics model, it is an all-True bool tensor on
              ``X``'s device.
            - ``"violations"``: maps each constraint's ``name`` to
              ``{"violation": constraint.evaluate(X),
              "feasible": constraint.is_feasible(X)}``. It is empty without a
              physics model.
        """
        if self.physics_prior is None:
            return {
                # Match X's device so the mask can be combined with X-derived tensors.
                "feasible": torch.ones(len(X), dtype=torch.bool, device=X.device),
                "violations": {},
            }

        feasible = self.physics_prior.check_feasibility(X)
        violations = {
            constraint.name: {
                "violation": constraint.evaluate(X),
                "feasible": constraint.is_feasible(X),
            }
            for constraint in self.physics_prior.constraints
        }
        return {"feasible": feasible, "violations": violations}

    def update_with_observations(self, X_new: Tensor, y_new: Tensor) -> None:
        """Add new observations to the data prior."""
        self.data_prior.add_observations(X_new, y_new)

    def summary(self) -> Dict:
        """Return a summary of the current prior state.

        ``"recommended_mode"`` is ``"insufficient_data"`` exactly when
        ``recommend_surrogate_mode()`` would raise (no physics model and fewer
        than ``min_data_for_gp`` observations).
        """
        n_observations = self.data_prior.n_observations
        has_physics = self.physics_prior is not None

        if has_physics or n_observations >= self.min_data_for_gp:
            recommended_mode = self.recommend_surrogate_mode()
        else:
            recommended_mode = "insufficient_data"

        return {
            "has_physics_model": has_physics,
            "n_physics_constraints": len(self.get_all_constraints()),
            "n_observations": n_observations,
            "recommended_mode": recommended_mode,
        }
|