# ravimohan19's picture
# Upload priors/prior_manager.py with huggingface_hub
# 41a65be verified
"""PriorManager: orchestrates combining data priors and physics priors."""
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
from physics_informed_bo.priors.data_prior import DataPrior
from physics_informed_bo.priors.physics_prior import PhysicsPrior
from physics_informed_bo.models.hybrid_model import HybridSurrogate
class PriorManager:
    """Manages the combination of physics priors and data priors.

    Recommends a surrogate model mode based on the available information:

    - No physics, fewer than ``min_data_for_gp`` observations -> ``ValueError``
    - No physics, enough observations -> ``gp_only``
    - Physics, fewer than ``min_data_for_gp`` observations -> ``physics_only``
    - Physics, fewer than ``data_threshold_for_ensemble`` observations
      -> ``physics_as_mean``
    - Physics, any larger amount of data -> ``weighted_ensemble``

    ``data_threshold_for_gp_only`` is stored but does not currently change the
    recommendation: no physics-quality check is performed here, so ``gp_only``
    is never auto-selected while a physics prior is present.

    Also exposes the physics prior's constraints and validates candidates
    against them.
    """

    # Surrogate modes that cannot be built without a physics model.
    _PHYSICS_MODES = ("physics_only", "physics_as_mean", "weighted_ensemble")

    def __init__(
        self,
        physics_prior: Optional["PhysicsPrior"] = None,
        data_prior: Optional["DataPrior"] = None,
        min_data_for_gp: int = 3,
        data_threshold_for_ensemble: int = 20,
        data_threshold_for_gp_only: int = 50,
    ):
        """Store the priors and the data-count thresholds.

        Args:
            physics_prior: Physics prior, or None when no physics model exists.
            data_prior: Data prior. A fresh ``DataPrior()`` is created only
                when this is None.
            min_data_for_gp: Minimum number of observations needed to fit on
                data.
            data_threshold_for_ensemble: Observation count from which
                ``weighted_ensemble`` is recommended instead of
                ``physics_as_mean``.
            data_threshold_for_gp_only: Stored for callers; not consulted by
                the recommendation logic at present.
        """
        self.physics_prior = physics_prior
        # Identity check, not truthiness: a caller-supplied prior that happens
        # to be falsy (e.g. an empty container-like object) must be kept,
        # otherwise later observations would go to a different object.
        self.data_prior = data_prior if data_prior is not None else DataPrior()
        self.min_data_for_gp = min_data_for_gp
        self.data_threshold_for_ensemble = data_threshold_for_ensemble
        self.data_threshold_for_gp_only = data_threshold_for_gp_only

    @staticmethod
    def _zero_physics(x: Tensor) -> Tensor:
        """Placeholder physics function: zeros, one per row of ``x``."""
        return torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)

    def recommend_surrogate_mode(self) -> str:
        """Recommend the best surrogate model mode based on available priors.

        Returns:
            One of ``"gp_only"``, ``"physics_only"``, ``"physics_as_mean"`` or
            ``"weighted_ensemble"``.

        Raises:
            ValueError: If there is no physics prior and fewer than
                ``min_data_for_gp`` observations.
        """
        n_data = self.data_prior.n_observations
        enough_for_gp = n_data >= self.min_data_for_gp

        if self.physics_prior is None:
            if not enough_for_gp:
                raise ValueError(
                    f"Need at least {self.min_data_for_gp} data points or a physics "
                    f"model. Got {n_data} data points and no physics model."
                )
            return "gp_only"

        if not enough_for_gp:
            return "physics_only"
        if n_data < self.data_threshold_for_ensemble:
            return "physics_as_mean"
        # Medium and large data sets both map to the ensemble. Above
        # data_threshold_for_gp_only the physics model might no longer help,
        # but nothing here measures its quality, so it is kept in the mix.
        return "weighted_ensemble"

    def build_surrogate(
        self,
        mode: Optional[str] = None,
        kernel: str = "matern",
        noise_variance: float = 0.01,
        device: str = "cpu",
        dtype: torch.dtype = torch.float64,
    ) -> "HybridSurrogate":
        """Build and optionally fit a HybridSurrogate from the available priors.

        Args:
            mode: Override the auto-recommended mode. If None, uses
                recommend_surrogate_mode().
            kernel: GP kernel type.
            noise_variance: Initial noise variance.
            device: Torch device.
            dtype: Torch dtype.

        Returns:
            A configured HybridSurrogate, fitted on the data prior when it
            holds at least ``min_data_for_gp`` observations.

        Raises:
            ValueError: If ``mode`` needs a physics model and none was
                provided, or if no mode can be recommended (see
                recommend_surrogate_mode()).
        """
        if mode is None:
            mode = self.recommend_surrogate_mode()

        # ``is not None`` mirrors recommend_surrogate_mode(), so both methods
        # agree on whether a physics model is present.
        if self.physics_prior is not None:
            physics_fn = self.physics_prior.evaluate
        elif mode in self._PHYSICS_MODES:
            raise ValueError(f"Mode '{mode}' requires a physics model but none was provided.")
        else:
            # Use a zero mean function as placeholder
            physics_fn = self._zero_physics

        surrogate = HybridSurrogate(
            physics_fn=physics_fn,
            mode=mode,
            kernel=kernel,
            noise_variance=noise_variance,
            device=device,
            dtype=dtype,
        )

        # Auto-fit if data is available
        if self.data_prior.n_observations >= self.min_data_for_gp:
            surrogate.fit(self.data_prior.X, self.data_prior.y)
        return surrogate

    def get_all_constraints(self) -> list:
        """Get all constraints from the physics prior.

        Returns:
            The physics prior's ``constraints`` attribute, or an empty list
            when there is no physics prior.
        """
        if self.physics_prior is None:
            return []
        return self.physics_prior.constraints

    def validate_candidates(self, X: Tensor) -> Dict:
        """Validate candidate points against physics constraints.

        Args:
            X: Candidate points; its length is taken as the number of
                candidates.

        Returns:
            Dict with key ``"feasible"`` (overall feasibility mask) and key
            ``"violations"`` (per-constraint dict, keyed by constraint name,
            holding that constraint's ``"violation"`` and ``"feasible"``
            results). Without a physics prior every candidate is feasible and
            ``"violations"`` is empty.
        """
        if self.physics_prior is None:
            # Put the mask on X's device so callers can combine it with X
            # directly; fall back to the default device for non-tensor input.
            all_feasible = torch.ones(
                len(X), dtype=torch.bool, device=getattr(X, "device", None)
            )
            return {"feasible": all_feasible, "violations": {}}

        feasible = self.physics_prior.check_feasibility(X)
        violations = {}
        for constraint in self.physics_prior.constraints:
            violations[constraint.name] = {
                "violation": constraint.evaluate(X),
                "feasible": constraint.is_feasible(X),
            }
        return {"feasible": feasible, "violations": violations}

    def update_with_observations(self, X_new: Tensor, y_new: Tensor) -> None:
        """Add new observations to the data prior."""
        self.data_prior.add_observations(X_new, y_new)

    def summary(self) -> Dict:
        """Return a summary of the current prior state.

        Returns:
            Dict with keys ``"has_physics_model"``, ``"n_physics_constraints"``,
            ``"n_observations"`` and ``"recommended_mode"``. The last is
            ``"insufficient_data"`` when no mode can be recommended, instead
            of raising.
        """
        has_physics = self.physics_prior is not None
        n_observations = self.data_prior.n_observations

        # Same presence test as recommend_surrogate_mode(), so the summary
        # reports "insufficient_data" exactly when that method would raise.
        if has_physics or n_observations >= self.min_data_for_gp:
            recommended = self.recommend_surrogate_mode()
        else:
            recommended = "insufficient_data"

        return {
            "has_physics_model": has_physics,
            "n_physics_constraints": len(self.get_all_constraints()),
            "n_observations": n_observations,
            "recommended_mode": recommended,
        }