"""Hybrid surrogate model combining physics models with data-driven GP."""

from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor

from physics_informed_bo.models.base import SurrogateModel
from physics_informed_bo.models.gp_model import PhysicsInformedGP, StandardGP
from physics_informed_bo.models.physics_model import PhysicsModel


class HybridSurrogate(SurrogateModel):
    """Hybrid model that combines a physics model with a GP.

    Provides multiple operating modes:

    1. **Physics-as-mean** (default): Physics function is the GP mean,
       GP learns the residual/discrepancy.
    2. **Weighted ensemble**: Weighted combination of a physics-informed GP
       and a standard (physics-free) GP, with the physics weight optionally
       re-estimated from the training data on every ``fit``.
    3. **Physics-only**: Pure physics model when no data is available.
    4. **GP-only**: Pure GP when physics model is unreliable.

    The mode is fixed at construction and is not switched automatically.
    Before :meth:`fit` has been called, :meth:`predict` falls back to the
    physics model in every mode.
    """

    # Supported values for the ``mode`` argument / attribute.
    _VALID_MODES = ("physics_as_mean", "weighted_ensemble", "physics_only", "gp_only")

    def __init__(
        self,
        physics_fn: Callable[[Tensor], Tensor],
        mode: str = "physics_as_mean",
        kernel: str = "matern",
        noise_variance: float = 0.01,
        learn_noise: bool = True,
        initial_physics_weight: float = 1.0,
        adapt_weight: bool = True,
        device: str = "cpu",
        dtype: torch.dtype = torch.float64,
    ):
        """
        Args:
            physics_fn: Physics model callable. Takes (n, d) tensor, returns (n,) tensor.
            mode: One of 'physics_as_mean', 'weighted_ensemble', 'physics_only', 'gp_only'.
            kernel: GP kernel type ('rbf' or 'matern').
            noise_variance: Initial observation noise variance.
            learn_noise: Whether to learn noise variance from data.
            initial_physics_weight: Starting weight for physics model (0 to 1).
                Only used in 'weighted_ensemble' mode.
            adapt_weight: Auto-adapt physics weight from the physics model's
                relative error on the training data (see ``_adapt_physics_weight``).
            device: Torch device.
            dtype: Torch dtype.

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        self._check_mode(mode)
        self.physics_fn = physics_fn
        self.mode = mode
        self.kernel = kernel
        self.noise_variance = noise_variance
        self.learn_noise = learn_noise
        self.physics_weight = initial_physics_weight
        self.adapt_weight = adapt_weight
        self.device = torch.device(device)
        self.dtype = dtype

        # Internal models
        self._physics_model = PhysicsModel(physics_fn, noise_std=noise_variance**0.5)
        self._gp_model: Optional[PhysicsInformedGP] = None
        self._standard_gp: Optional[StandardGP] = None

        self._is_fitted = False
        # Training data as stored by the last ``fit`` (y kept as (n, 1)).
        self._train_X: Optional[Tensor] = None
        self._train_y: Optional[Tensor] = None

    @classmethod
    def _check_mode(cls, mode: str) -> None:
        """Raise ValueError if *mode* is not a supported operating mode."""
        if mode not in cls._VALID_MODES:
            raise ValueError(
                f"Unknown mode {mode!r}; expected one of {cls._VALID_MODES}."
            )

    def _build_physics_gp(self) -> PhysicsInformedGP:
        """Construct an unfitted PhysicsInformedGP from this surrogate's settings."""
        return PhysicsInformedGP(
            physics_fn=self.physics_fn,
            kernel=self.kernel,
            noise_variance=self.noise_variance,
            learn_noise=self.learn_noise,
            device=str(self.device),
            dtype=self.dtype,
        )

    def _build_standard_gp(self) -> StandardGP:
        """Construct an unfitted StandardGP from this surrogate's settings."""
        return StandardGP(
            kernel=self.kernel,
            noise_variance=self.noise_variance,
            learn_noise=self.learn_noise,
            device=str(self.device),
            dtype=self.dtype,
        )

    def fit(
        self,
        X: Tensor,
        y: Tensor,
        training_iterations: int = 200,
        lr: float = 0.05,
    ) -> None:
        """Fit the hybrid model.

        If mode is 'physics_as_mean', fits a PhysicsInformedGP.
        If mode is 'weighted_ensemble', fits both a PhysicsInformedGP and a
        StandardGP, then (if ``adapt_weight``) re-estimates the physics weight.
        If mode is 'gp_only', fits a StandardGP.
        If mode is 'physics_only', fits the physics model only.

        Args:
            X: Training inputs.
            y: Training targets; a 1-D tensor is reshaped to (n, 1).
            training_iterations: Optimizer iterations passed to the GP fit.
            lr: Learning rate passed to the GP fit.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode (checked
                before any state is modified).
        """
        # ``mode`` is a public attribute, so re-check it in case it was
        # reassigned after construction.
        self._check_mode(self.mode)

        X = X.to(device=self.device, dtype=self.dtype)
        y = y.to(device=self.device, dtype=self.dtype)
        if y.dim() == 1:
            y = y.unsqueeze(-1)

        self._train_X = X
        self._train_y = y

        if self.mode == "physics_only":
            self._physics_model.fit(X, y)
        elif self.mode == "physics_as_mean":
            self._gp_model = self._build_physics_gp()
            self._gp_model.fit(X, y, training_iterations, lr)
        elif self.mode == "weighted_ensemble":
            # Fit physics-informed GP
            self._gp_model = self._build_physics_gp()
            self._gp_model.fit(X, y, training_iterations, lr)
            # Fit standard GP
            self._standard_gp = self._build_standard_gp()
            self._standard_gp.fit(X, y, training_iterations, lr)

            if self.adapt_weight:
                self._adapt_physics_weight(X, y)
        else:  # "gp_only" -- the only remaining valid mode
            self._standard_gp = self._build_standard_gp()
            self._standard_gp.fit(X, y, training_iterations, lr)

        self._is_fitted = True

    def _adapt_physics_weight(self, X: Tensor, y: Tensor) -> None:
        """Set ``physics_weight`` from the physics model's in-sample relative error.

        Computes ``mean(|y - f(X)| / (|y| + 1e-8))`` over the training points
        (no cross-validation is performed) and maps it through
        ``sigmoid(-5 * (error - 0.5))``: a mean relative error of 0.5 gives
        weight 0.5, zero error gives about 0.92, and larger errors push the
        weight toward 0.
        """
        with torch.no_grad():
            # Flatten both sides so an (n, 1) physics output cannot broadcast
            # against the (n,) targets into an (n, n) matrix.
            physics_pred = self.physics_fn(X).reshape(-1)
            targets = y.reshape(-1)
            residuals = targets - physics_pred
            # The 1e-8 guard only prevents division by zero; targets near zero
            # still inflate the relative error and thus lower the weight.
            relative_error = (residuals.abs() / (targets.abs() + 1e-8)).mean()

        # Sigmoid mapping: high error → low physics weight
        self.physics_weight = float(torch.sigmoid(-5.0 * (relative_error - 0.5)))

    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the predictive ``(mean, variance)`` at ``X``.

        Uses the physics model when the mode is 'physics_only' or when
        :meth:`fit` has not been called yet.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        X = X.to(device=self.device, dtype=self.dtype)

        if self.mode == "physics_only" or not self._is_fitted:
            return self._physics_model.predict(X)
        if self.mode == "physics_as_mean":
            return self._gp_model.predict(X)
        if self.mode == "weighted_ensemble":
            gp_mean, gp_var = self._gp_model.predict(X)
            std_mean, std_var = self._standard_gp.predict(X)
            w = self.physics_weight
            mean = w * gp_mean + (1 - w) * std_mean
            # NOTE(review): this is the variance of a weighted sum of
            # *independent* Gaussians. Both GPs are trained on the same data,
            # so their errors are likely positively correlated and this value
            # may understate the true predictive uncertainty.
            variance = w**2 * gp_var + (1 - w) ** 2 * std_var
            return mean, variance
        if self.mode == "gp_only":
            return self._standard_gp.predict(X)
        raise ValueError(
            f"Unknown mode {self.mode!r}; expected one of {self._VALID_MODES}."
        )

    def posterior(self, X: Tensor):
        """Return a posterior object at ``X`` from the mode's underlying model.

        Falls back to the physics model's posterior when the GP required by
        the current mode has not been fitted.
        """
        # Match predict(): evaluate on the model's device/dtype.
        X = X.to(device=self.device, dtype=self.dtype)
        if self.mode in ("physics_as_mean", "weighted_ensemble") and self._gp_model is not None:
            # NOTE(review): in 'weighted_ensemble' mode this is the
            # physics-informed GP's posterior alone; unlike predict(), the
            # standard GP and physics_weight are not blended in here.
            return self._gp_model.posterior(X)
        if self.mode == "gp_only" and self._standard_gp is not None:
            return self._standard_gp.posterior(X)
        return self._physics_model.posterior(X)

    @property
    def model(self):
        """Return the primary BoTorch-compatible model for optimization.

        Prefers the physics-informed GP, then the standard GP; returns None
        if neither has been fitted.
        """
        if self._gp_model is not None:
            return self._gp_model.model
        if self._standard_gp is not None:
            return self._standard_gp.model
        return None

    def get_physics_residuals(self) -> Optional[Tensor]:
        """Return 1-D residuals ``y - physics_fn(X)`` on the training data.

        Returns None if :meth:`fit` has not been called.
        """
        if self._train_X is None or self._train_y is None:
            return None
        with torch.no_grad():
            # Flatten both sides to avoid (n,) vs (n, 1) broadcasting.
            physics_pred = self.physics_fn(self._train_X).reshape(-1)
        return self._train_y.reshape(-1) - physics_pred

    def physics_model_quality(self) -> Dict:
        """Assess how well the physics model matches the data.

        Returns:
            ``{"status": "no_data"}`` before fitting; otherwise a dict with the
            physics model's ``rmse``, ``mae`` and ``r2`` on the training data,
            the current ``physics_weight`` and ``n_observations``. ``r2`` is
            NaN when all training targets are identical (R² is undefined).
        """
        if self._train_X is None:
            return {"status": "no_data"}

        residuals = self.get_physics_residuals()
        targets = self._train_y.reshape(-1)

        rmse = float((residuals**2).mean().sqrt())
        mae = float(residuals.abs().mean())

        ss_res = (residuals**2).sum()
        ss_tot = ((targets - targets.mean()) ** 2).sum()
        # Guard against 0/0 or x/0 when the targets have zero variance.
        r2 = float(1 - ss_res / ss_tot) if ss_tot > 0 else float("nan")

        return {
            "rmse": rmse,
            "mae": mae,
            "r2": r2,
            "physics_weight": self.physics_weight,
            "n_observations": len(self._train_X),
        }