| """Physics-based prior: encode physical models and constraints."""
|
|
|
| from dataclasses import dataclass, field
|
| from typing import Callable, Dict, List, Optional, Tuple
|
|
|
| import torch
|
| from torch import Tensor
|
|
|
|
|
@dataclass
class PhysicsConstraint:
    """Physical condition that candidate points are required to meet.

    ``constraint_fn`` maps a batch of candidate points to a signed violation
    value per point. How that value is judged depends on ``constraint_type``:

    * ``"equality"``: a point passes when ``|violation| <= tolerance``.
    * any other value (default ``"inequality"``): a point passes when
      ``violation <= tolerance``, so non-positive values mean feasible.

    Typical uses:
        - Conservation laws: sum of mass fractions = 1
        - Thermodynamic feasibility: Gibbs free energy < 0
        - Kinetic limits: reaction rate > 0
    """

    name: str
    constraint_fn: Callable[[Tensor], Tensor]
    constraint_type: str = "inequality"
    tolerance: float = 1e-6

    def evaluate(self, X: Tensor) -> Tensor:
        """Return the raw output of ``constraint_fn`` applied to ``X``.

        For inequality constraints, non-positive entries are feasible. For
        equality constraints, entries close to zero are feasible.
        """
        return self.constraint_fn(X)

    def is_feasible(self, X: Tensor) -> Tensor:
        """Return a boolean tensor flagging points that satisfy the constraint."""
        raw = self.evaluate(X)
        # Equality constraints may be violated in either direction, so judge
        # their magnitude; inequality constraints are judged by signed value.
        measured = raw.abs() if self.constraint_type == "equality" else raw
        return measured <= self.tolerance
|
|
|
|
|
@dataclass
class PhysicsPrior:
    """Encapsulates physics knowledge for Bayesian optimization.

    Combines:
        - A physics model function (intended for use as the GP mean)
        - Physical constraints (feasibility conditions)
        - Known parameter bounds from physics
        - Domain-specific knowledge about the objective landscape

    Only ``physics_fn`` and ``constraints`` are read by the methods defined
    here; ``parameter_bounds``, ``known_optima`` and ``model_fidelity`` are
    stored but not used by this class itself.
    """

    physics_fn: Callable[[Tensor], Tensor]
    constraints: List[PhysicsConstraint] = field(default_factory=list)
    parameter_bounds: Optional[Dict[str, Tuple[float, float]]] = None
    known_optima: Optional[List[Dict]] = None
    model_fidelity: float = 1.0

    def evaluate(self, X: Tensor) -> Tensor:
        """Evaluate the physics model at ``X``."""
        return self.physics_fn(X)

    def add_constraint(
        self,
        name: str,
        constraint_fn: Callable[[Tensor], Tensor],
        constraint_type: str = "inequality",
        tolerance: float = 1e-6,
    ) -> None:
        """Append a new ``PhysicsConstraint`` built from the given arguments.

        Args:
            name: Human-readable constraint name.
            constraint_fn: Maps a batch of points to a signed violation value.
            constraint_type: ``"equality"`` or ``"inequality"``.
            tolerance: Slack allowed when judging feasibility.
        """
        self.constraints.append(
            PhysicsConstraint(name, constraint_fn, constraint_type, tolerance)
        )

    def check_feasibility(self, X: Tensor) -> Tensor:
        """Return a mask that is True where every constraint is satisfied.

        Args:
            X: Batch of candidate points; one mask entry is produced per
                element along the first dimension.

        Returns:
            Boolean tensor of length ``len(X)`` on ``X``'s device. It is all
            True when no constraints are registered.
        """
        # Allocate on X's device so the in-place AND below never mixes a CPU
        # mask with accelerator-resident constraint results.
        feasible = torch.ones(len(X), dtype=torch.bool, device=X.device)
        for constraint in self.constraints:
            feasible &= constraint.is_feasible(X)
        return feasible

    def constraint_violation(self, X: Tensor) -> Tensor:
        """Compute the total constraint violation for each point.

        Inequality constraints contribute ``max(violation, 0)``. Equality
        constraints contribute ``|violation|``, because a deviation in either
        direction breaks them (consistent with how feasibility is judged).
        Tolerances are not subtracted.

        Returns:
            Tensor of length ``len(X)`` with ``X``'s dtype and device. It is
            all zeros when no constraints are registered.
        """
        # Same dtype/device whether or not constraints exist, so callers get
        # a consistent tensor type.
        total_violation = torch.zeros(len(X), dtype=X.dtype, device=X.device)
        for constraint in self.constraints:
            violation = constraint.evaluate(X)
            if constraint.constraint_type == "equality":
                total_violation += violation.abs()
            else:
                total_violation += torch.clamp(violation, min=0.0)
        return total_violation

    def get_feasible_subset(self, X: Tensor) -> Tensor:
        """Return only the rows of ``X`` that satisfy every constraint."""
        mask = self.check_feasibility(X)
        return X[mask]
|
|
|