File size: 3,535 Bytes
5907c46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""Physics-based prior: encode physical models and constraints."""

from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor


@dataclass
class PhysicsConstraint:
    """A physical constraint that candidate points must satisfy.

    ``constraint_fn`` maps a batch of points to one violation value per point.
    How that value is interpreted depends on ``constraint_type``:

    - ``"inequality"``: feasible where the value is ``<= tolerance``
      (i.e. ``g(x) <= 0`` up to tolerance).
    - ``"equality"``: feasible where ``|value| <= tolerance``
      (i.e. ``h(x) == 0`` up to tolerance).

    Examples:
        - Conservation laws: sum of mass fractions = 1
        - Thermodynamic feasibility: Gibbs free energy < 0
        - Kinetic limits: reaction rate > 0

    Raises:
        ValueError: If ``constraint_type`` is neither ``"inequality"`` nor
            ``"equality"``.
    """

    name: str
    constraint_fn: Callable[[Tensor], Tensor]  # Returns constraint violation (<=0 is feasible)
    constraint_type: str = "inequality"  # 'inequality' (<=0) or 'equality' (==0)
    tolerance: float = 1e-6

    def __post_init__(self) -> None:
        # Reject unknown types up front. Otherwise a typo such as "eq" would
        # silently be treated as an inequality constraint by ``is_feasible``.
        if self.constraint_type not in ("inequality", "equality"):
            raise ValueError(
                "constraint_type must be 'inequality' or 'equality', "
                f"got {self.constraint_type!r}"
            )

    def evaluate(self, X: Tensor) -> Tensor:
        """Evaluate the raw constraint value at ``X``.

        For inequality constraints, values <= 0 are feasible. For equality
        constraints, values equal to 0 are feasible. In both cases this holds
        up to ``tolerance``.
        """
        return self.constraint_fn(X)

    def is_feasible(self, X: Tensor) -> Tensor:
        """Return a boolean tensor marking points that satisfy the constraint."""
        violation = self.evaluate(X)
        if self.constraint_type == "equality":
            # Equality: a deviation in either direction is infeasible.
            return violation.abs() <= self.tolerance
        return violation <= self.tolerance


@dataclass
class PhysicsPrior:
    """Encapsulates physics knowledge for Bayesian optimization.



    Combines:

    - A physics model function (used as GP mean)

    - Physical constraints (feasibility conditions)

    - Known parameter bounds from physics

    - Domain-specific knowledge about the objective landscape

    """

    physics_fn: Callable[[Tensor], Tensor]
    constraints: List[PhysicsConstraint] = field(default_factory=list)
    parameter_bounds: Optional[Dict[str, Tuple[float, float]]] = None
    known_optima: Optional[List[Dict]] = None  # Known good regions from physics
    model_fidelity: float = 1.0  # Confidence in physics model (0 to 1)

    def evaluate(self, X: Tensor) -> Tensor:
        """Evaluate the physics model at X."""
        return self.physics_fn(X)

    def add_constraint(

        self,

        name: str,

        constraint_fn: Callable[[Tensor], Tensor],

        constraint_type: str = "inequality",

        tolerance: float = 1e-6,

    ) -> None:
        """Add a physical constraint."""
        self.constraints.append(
            PhysicsConstraint(name, constraint_fn, constraint_type, tolerance)
        )

    def check_feasibility(self, X: Tensor) -> Tensor:
        """Check all constraints. Returns boolean tensor (True = all constraints satisfied)."""
        if not self.constraints:
            return torch.ones(len(X), dtype=torch.bool)

        feasible = torch.ones(len(X), dtype=torch.bool)
        for constraint in self.constraints:
            feasible &= constraint.is_feasible(X)
        return feasible

    def constraint_violation(self, X: Tensor) -> Tensor:
        """Compute total constraint violation for each point."""
        if not self.constraints:
            return torch.zeros(len(X))

        total_violation = torch.zeros(len(X), dtype=X.dtype, device=X.device)
        for constraint in self.constraints:
            violation = constraint.evaluate(X)
            # Only count positive violations (infeasible)
            total_violation += torch.clamp(violation, min=0.0)
        return total_violation

    def get_feasible_subset(self, X: Tensor) -> Tensor:
        """Filter X to only feasible points."""
        mask = self.check_feasibility(X)
        return X[mask]