# Provenance: uploaded by ravimohan19 ("Upload priors/physics_prior.py with huggingface_hub"),
# commit 5907c46 (verified).
"""Physics-based prior: encode physical models and constraints."""
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
@dataclass
class PhysicsConstraint:
    """A physical constraint that candidate points must satisfy.

    The constraint is given as a function ``constraint_fn(X)``. How its output
    is read depends on ``constraint_type``:

    - ``"inequality"``: a point is feasible when the output is
      ``<= tolerance`` (i.e. ``g(x) <= 0`` up to the tolerance).
    - ``"equality"``: a point is feasible when ``|output| <= tolerance``
      (i.e. ``h(x) == 0`` up to the tolerance).

    Examples:
        - Conservation laws: sum of mass fractions = 1
        - Thermodynamic feasibility: Gibbs free energy < 0
        - Kinetic limits: reaction rate > 0

    Raises:
        ValueError: if ``constraint_type`` is neither ``"inequality"`` nor
            ``"equality"``.
    """

    name: str
    # Returns the constraint value; <=0 is feasible for inequalities, ==0 for equalities.
    constraint_fn: Callable[[Tensor], Tensor]
    constraint_type: str = "inequality"  # 'inequality' (<=0) or 'equality' (==0)
    tolerance: float = 1e-6  # Slack allowed when testing feasibility

    def __post_init__(self) -> None:
        # Reject unknown types up front. Otherwise a typo such as "Equality"
        # would silently be treated as an inequality by ``is_feasible``.
        if self.constraint_type not in ("inequality", "equality"):
            raise ValueError(
                f"constraint_type must be 'inequality' or 'equality', "
                f"got {self.constraint_type!r} for constraint {self.name!r}"
            )

    def evaluate(self, X: Tensor) -> Tensor:
        """Evaluate the raw constraint function at ``X``.

        For inequality constraints, values ``<= 0`` are feasible and positive
        values measure the violation. For equality constraints, the feasible
        value is ``0`` and the magnitude of the output measures the violation.
        """
        return self.constraint_fn(X)

    def is_feasible(self, X: Tensor) -> Tensor:
        """Return a boolean tensor marking points that satisfy the constraint.

        Equality constraints accept ``|value| <= tolerance``; inequality
        constraints accept ``value <= tolerance``.
        """
        violation = self.evaluate(X)
        if self.constraint_type == "equality":
            return violation.abs() <= self.tolerance
        return violation <= self.tolerance
@dataclass
class PhysicsPrior:
    """Encapsulates physics knowledge for Bayesian optimization.

    Combines:
        - A physics model function (used as GP mean)
        - Physical constraints (feasibility conditions)
        - Known parameter bounds from physics
        - Domain-specific knowledge about the objective landscape
    """

    physics_fn: Callable[[Tensor], Tensor]  # Physics model evaluated by ``evaluate``
    constraints: List[PhysicsConstraint] = field(default_factory=list)
    # Presumably maps parameter name -> (lower, upper); not read by the methods below.
    parameter_bounds: Optional[Dict[str, Tuple[float, float]]] = None
    known_optima: Optional[List[Dict]] = None  # Known good regions from physics
    model_fidelity: float = 1.0  # Confidence in physics model (0 to 1); not validated here

    def evaluate(self, X: Tensor) -> Tensor:
        """Evaluate the physics model at ``X`` (delegates to ``physics_fn``)."""
        return self.physics_fn(X)

    def add_constraint(
        self,
        name: str,
        constraint_fn: Callable[[Tensor], Tensor],
        constraint_type: str = "inequality",
        tolerance: float = 1e-6,
    ) -> None:
        """Build a ``PhysicsConstraint`` from the arguments and append it."""
        self.constraints.append(
            PhysicsConstraint(
                name=name,
                constraint_fn=constraint_fn,
                constraint_type=constraint_type,
                tolerance=tolerance,
            )
        )

    def check_feasibility(self, X: Tensor) -> Tensor:
        """Return a boolean mask that is True where every constraint holds.

        The mask starts all-True with one entry per row of ``X`` (``len(X)``).
        It is then AND-ed with each constraint's ``is_feasible`` result, so
        with no constraints every point is feasible.
        """
        # Allocate on X's device so the in-place AND below (and the boolean
        # indexing in get_feasible_subset) also works for non-CPU tensors.
        feasible = torch.ones(len(X), dtype=torch.bool, device=X.device)
        for constraint in self.constraints:
            feasible &= constraint.is_feasible(X)
        return feasible

    def constraint_violation(self, X: Tensor) -> Tensor:
        """Compute the total constraint violation for each point.

        How each constraint contributes:
            - Inequality constraints add ``max(value, 0)``; values <= 0 are feasible.
            - Equality constraints add ``|value|``; they are violated in either direction.

        The result has ``len(X)`` entries with X's dtype and device. It is all
        zeros when there are no constraints.
        """
        total_violation = torch.zeros(len(X), dtype=X.dtype, device=X.device)
        for constraint in self.constraints:
            violation = constraint.evaluate(X)
            if constraint.constraint_type == "equality":
                # A negative residual also breaks h(x) == 0, so count its magnitude.
                total_violation += violation.abs()
            else:
                # Only positive values are infeasible for g(x) <= 0.
                total_violation += torch.clamp(violation, min=0.0)
        return total_violation

    def get_feasible_subset(self, X: Tensor) -> Tensor:
        """Return only the rows of ``X`` that satisfy every constraint."""
        mask = self.check_feasibility(X)
        return X[mask]