"""Physics-based prior: encode physical models and constraints."""

from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor


@dataclass
class PhysicsConstraint:
    """A physical constraint that candidate points must satisfy.

    Examples:
        - Conservation laws: sum of mass fractions = 1
        - Thermodynamic feasibility: Gibbs free energy < 0
        - Kinetic limits: reaction rate > 0

    ``constraint_fn`` maps points ``X`` to a signed residual per point.

    - ``constraint_type="inequality"``: a point is feasible when the residual
      is ``<= tolerance``.
    - ``constraint_type="equality"``: a point is feasible when
      ``|residual| <= tolerance``.

    Any other ``constraint_type`` string is treated as an inequality.
    """

    name: str
    # Signed residual: <= 0 is feasible for inequalities, == 0 for equalities.
    constraint_fn: Callable[[Tensor], Tensor]
    constraint_type: str = "inequality"  # 'inequality' (<=0) or 'equality' (==0)
    tolerance: float = 1e-6

    def evaluate(self, X: Tensor) -> Tensor:
        """Return the raw residual ``constraint_fn(X)``.

        For inequality constraints, negative values are feasible. For equality
        constraints, only values near zero are feasible; the sign itself
        carries no feasibility meaning.
        """
        return self.constraint_fn(X)

    def is_feasible(self, X: Tensor) -> Tensor:
        """Return a boolean tensor: True where the constraint is satisfied
        within ``tolerance``."""
        violation = self.evaluate(X)
        if self.constraint_type == "equality":
            return violation.abs() <= self.tolerance
        return violation <= self.tolerance

    def violation_amount(self, X: Tensor) -> Tensor:
        """Return the non-negative violation magnitude at each point.

        Inequality: ``max(residual, 0)``, so satisfied points contribute 0.
        Equality: ``|residual|``, because both positive and negative
        deviations from zero violate the constraint.

        ``tolerance`` is not subtracted; this is the raw violation size.
        """
        violation = self.evaluate(X)
        if self.constraint_type == "equality":
            return violation.abs()
        return torch.clamp(violation, min=0.0)


@dataclass
class PhysicsPrior:
    """Encapsulates physics knowledge for Bayesian optimization.

    Combines:
        - A physics model function (used as GP mean)
        - Physical constraints (feasibility conditions)
        - Known parameter bounds from physics
        - Domain-specific knowledge about the objective landscape
    """

    physics_fn: Callable[[Tensor], Tensor]
    constraints: List[PhysicsConstraint] = field(default_factory=list)
    parameter_bounds: Optional[Dict[str, Tuple[float, float]]] = None
    known_optima: Optional[List[Dict]] = None  # Known good regions from physics
    model_fidelity: float = 1.0  # Confidence in physics model (0 to 1)

    def evaluate(self, X: Tensor) -> Tensor:
        """Evaluate the physics model at X."""
        return self.physics_fn(X)

    def add_constraint(
        self,
        name: str,
        constraint_fn: Callable[[Tensor], Tensor],
        constraint_type: str = "inequality",
        tolerance: float = 1e-6,
    ) -> None:
        """Append a new ``PhysicsConstraint`` built from the given arguments."""
        self.constraints.append(
            PhysicsConstraint(name, constraint_fn, constraint_type, tolerance)
        )

    def check_feasibility(self, X: Tensor) -> Tensor:
        """Check all constraints.

        Returns:
            Boolean tensor of length ``len(X)`` on ``X.device``; True where
            every constraint is satisfied (all True if there are none).
        """
        # Allocate on X's device so the in-place AND with constraint outputs
        # also works for non-CPU inputs.
        feasible = torch.ones(len(X), dtype=torch.bool, device=X.device)
        for constraint in self.constraints:
            feasible &= constraint.is_feasible(X)
        return feasible

    def constraint_violation(self, X: Tensor) -> Tensor:
        """Compute the total constraint violation for each point.

        Sums ``PhysicsConstraint.violation_amount`` over all constraints.
        Satisfied inequalities contribute 0. Equality constraints contribute
        their absolute residual.

        Returns:
            Tensor of length ``len(X)`` with ``X.dtype`` on ``X.device``
            (zeros if there are no constraints).
        """
        total_violation = torch.zeros(len(X), dtype=X.dtype, device=X.device)
        for constraint in self.constraints:
            total_violation += constraint.violation_amount(X)
        return total_violation

    def get_feasible_subset(self, X: Tensor) -> Tensor:
        """Filter X to only the rows that satisfy every constraint."""
        mask = self.check_feasibility(X)
        return X[mask]