ravimohan19 commited on
Commit
5907c46
·
verified ·
1 Parent(s): d70a716

Upload priors/physics_prior.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. priors/physics_prior.py +95 -0
priors/physics_prior.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Physics-based prior: encode physical models and constraints."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Callable, Dict, List, Optional, Tuple
5
+
6
+ import torch
7
+ from torch import Tensor
8
+
9
+
10
+ @dataclass
11
+ class PhysicsConstraint:
12
+ """A physical constraint that candidate points must satisfy.
13
+
14
+ Examples:
15
+ - Conservation laws: sum of mass fractions = 1
16
+ - Thermodynamic feasibility: Gibbs free energy < 0
17
+ - Kinetic limits: reaction rate > 0
18
+ """
19
+
20
+ name: str
21
+ constraint_fn: Callable[[Tensor], Tensor] # Returns constraint violation (<=0 is feasible)
22
+ constraint_type: str = "inequality" # 'inequality' (<=0) or 'equality' (==0)
23
+ tolerance: float = 1e-6
24
+
25
+ def evaluate(self, X: Tensor) -> Tensor:
26
+ """Evaluate constraint. Returns violation amount (negative = feasible)."""
27
+ return self.constraint_fn(X)
28
+
29
+ def is_feasible(self, X: Tensor) -> Tensor:
30
+ """Check if points satisfy the constraint. Returns boolean tensor."""
31
+ violation = self.evaluate(X)
32
+ if self.constraint_type == "equality":
33
+ return violation.abs() <= self.tolerance
34
+ return violation <= self.tolerance
35
+
36
+
37
+ @dataclass
38
+ class PhysicsPrior:
39
+ """Encapsulates physics knowledge for Bayesian optimization.
40
+
41
+ Combines:
42
+ - A physics model function (used as GP mean)
43
+ - Physical constraints (feasibility conditions)
44
+ - Known parameter bounds from physics
45
+ - Domain-specific knowledge about the objective landscape
46
+ """
47
+
48
+ physics_fn: Callable[[Tensor], Tensor]
49
+ constraints: List[PhysicsConstraint] = field(default_factory=list)
50
+ parameter_bounds: Optional[Dict[str, Tuple[float, float]]] = None
51
+ known_optima: Optional[List[Dict]] = None # Known good regions from physics
52
+ model_fidelity: float = 1.0 # Confidence in physics model (0 to 1)
53
+
54
+ def evaluate(self, X: Tensor) -> Tensor:
55
+ """Evaluate the physics model at X."""
56
+ return self.physics_fn(X)
57
+
58
+ def add_constraint(
59
+ self,
60
+ name: str,
61
+ constraint_fn: Callable[[Tensor], Tensor],
62
+ constraint_type: str = "inequality",
63
+ tolerance: float = 1e-6,
64
+ ) -> None:
65
+ """Add a physical constraint."""
66
+ self.constraints.append(
67
+ PhysicsConstraint(name, constraint_fn, constraint_type, tolerance)
68
+ )
69
+
70
+ def check_feasibility(self, X: Tensor) -> Tensor:
71
+ """Check all constraints. Returns boolean tensor (True = all constraints satisfied)."""
72
+ if not self.constraints:
73
+ return torch.ones(len(X), dtype=torch.bool)
74
+
75
+ feasible = torch.ones(len(X), dtype=torch.bool)
76
+ for constraint in self.constraints:
77
+ feasible &= constraint.is_feasible(X)
78
+ return feasible
79
+
80
+ def constraint_violation(self, X: Tensor) -> Tensor:
81
+ """Compute total constraint violation for each point."""
82
+ if not self.constraints:
83
+ return torch.zeros(len(X))
84
+
85
+ total_violation = torch.zeros(len(X), dtype=X.dtype, device=X.device)
86
+ for constraint in self.constraints:
87
+ violation = constraint.evaluate(X)
88
+ # Only count positive violations (infeasible)
89
+ total_violation += torch.clamp(violation, min=0.0)
90
+ return total_violation
91
+
92
+ def get_feasible_subset(self, X: Tensor) -> Tensor:
93
+ """Filter X to only feasible points."""
94
+ mask = self.check_feasibility(X)
95
+ return X[mask]