File size: 5,781 Bytes
41a65be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""PriorManager: orchestrates combining data priors and physics priors."""

from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor

from physics_informed_bo.priors.data_prior import DataPrior
from physics_informed_bo.priors.physics_prior import PhysicsPrior
from physics_informed_bo.models.hybrid_model import HybridSurrogate


class PriorManager:
    """Manage the combination of a physics prior and a data prior.

    Recommends a surrogate model mode from the information available:

    - No physics, at least ``min_data_for_gp`` observations → ``gp_only``
    - Physics, fewer than ``min_data_for_gp`` observations → ``physics_only``
    - Physics, fewer than ``data_threshold_for_ensemble`` observations →
      ``physics_as_mean`` (strong physics prior)
    - Physics, any larger amount of data → ``weighted_ensemble``

    Also exposes the physics prior's constraints and validates candidate
    points against them.

    NOTE(review): ``data_threshold_for_gp_only`` is stored on the instance but
    nothing in this class switches to ``gp_only`` while a physics prior is
    present; doing so would need a measure of how well the physics model
    agrees with the data, which is not available here.
    """

    def __init__(
        self,
        physics_prior: Optional[PhysicsPrior] = None,
        data_prior: Optional[DataPrior] = None,
        min_data_for_gp: int = 3,
        data_threshold_for_ensemble: int = 20,
        data_threshold_for_gp_only: int = 50,
    ):
        """Store the priors and the observation-count thresholds.

        Args:
            physics_prior: Physics model and constraints, or None if absent.
            data_prior: Container of observations. A fresh ``DataPrior()`` is
                created when None is passed.
            min_data_for_gp: Minimum number of observations needed before a
                GP is fitted.
            data_threshold_for_ensemble: Observation count at which the
                recommendation moves from ``physics_as_mean`` to
                ``weighted_ensemble``.
            data_threshold_for_gp_only: Stored for callers; currently not
                acted on by ``recommend_surrogate_mode``.
        """
        self.physics_prior = physics_prior
        # Compare with None explicitly: a caller-supplied prior that evaluates
        # as falsy (e.g. an empty container type) must not be swapped for a
        # new object the caller holds no reference to.
        self.data_prior = data_prior if data_prior is not None else DataPrior()
        self.min_data_for_gp = min_data_for_gp
        self.data_threshold_for_ensemble = data_threshold_for_ensemble
        self.data_threshold_for_gp_only = data_threshold_for_gp_only

    @staticmethod
    def _zero_mean(x: Tensor) -> Tensor:
        """Placeholder mean: one zero per row of ``x``, on x's dtype/device."""
        return torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)

    def recommend_surrogate_mode(self) -> str:
        """Recommend a surrogate model mode based on the available priors.

        Returns:
            One of ``"gp_only"``, ``"physics_only"``, ``"physics_as_mean"``
            or ``"weighted_ensemble"``.

        Raises:
            ValueError: If there is no physics prior and fewer than
                ``min_data_for_gp`` observations.
        """
        n_data = self.data_prior.n_observations

        if self.physics_prior is None:
            if n_data < self.min_data_for_gp:
                raise ValueError(
                    f"Need at least {self.min_data_for_gp} data points or a physics "
                    f"model. Got {n_data} data points and no physics model."
                )
            return "gp_only"

        if n_data < self.min_data_for_gp:
            return "physics_only"
        if n_data < self.data_threshold_for_ensemble:
            return "physics_as_mean"
        # Medium and large data sets both use the ensemble for now; see the
        # class-level note about data_threshold_for_gp_only.
        return "weighted_ensemble"

    def build_surrogate(
        self,
        mode: Optional[str] = None,
        kernel: str = "matern",
        noise_variance: float = 0.01,
        device: str = "cpu",
        dtype: torch.dtype = torch.float64,
    ) -> HybridSurrogate:
        """Build and optionally fit a HybridSurrogate from the available priors.

        Args:
            mode: Override the auto-recommended mode. If None, uses
                recommend_surrogate_mode().
            kernel: GP kernel type.
            noise_variance: Initial noise variance.
            device: Torch device.
            dtype: Torch dtype.

        Returns:
            A configured HybridSurrogate, fitted when at least
            ``min_data_for_gp`` observations are available.

        Raises:
            ValueError: If ``mode`` needs a physics model and none was
                provided, or if no mode can be recommended.
        """
        if mode is None:
            mode = self.recommend_surrogate_mode()

        if self.physics_prior is not None:
            physics_fn = self.physics_prior.evaluate
        elif mode in ("physics_only", "physics_as_mean", "weighted_ensemble"):
            raise ValueError(f"Mode '{mode}' requires a physics model but none was provided.")
        else:
            # No physics model: hand the surrogate a zero mean as placeholder.
            physics_fn = self._zero_mean

        surrogate = HybridSurrogate(
            physics_fn=physics_fn,
            mode=mode,
            kernel=kernel,
            noise_variance=noise_variance,
            device=device,
            dtype=dtype,
        )

        # Auto-fit if enough data is available.
        if self.data_prior.n_observations >= self.min_data_for_gp:
            surrogate.fit(self.data_prior.X, self.data_prior.y)

        return surrogate

    def get_all_constraints(self) -> list:
        """Return the physics prior's constraints, or [] without a physics prior."""
        if self.physics_prior is None:
            return []
        return self.physics_prior.constraints

    def validate_candidates(self, X: Tensor) -> Dict:
        """Validate candidate points against physics constraints.

        Args:
            X: Candidate points; its length gives the number of candidates.

        Returns:
            Dict with key ``"feasible"`` (overall feasibility mask) and key
            ``"violations"`` (per-constraint dict, keyed by constraint name,
            holding that constraint's ``"violation"`` and ``"feasible"``
            results). Without a physics prior every candidate is feasible
            and ``"violations"`` is empty.
        """
        if self.physics_prior is None:
            return {
                # Keep the mask on the same device as the candidates so it can
                # be combined with them directly.
                "feasible": torch.ones(len(X), dtype=torch.bool, device=X.device),
                "violations": {},
            }

        feasible = self.physics_prior.check_feasibility(X)
        violations = {
            constraint.name: {
                "violation": constraint.evaluate(X),
                "feasible": constraint.is_feasible(X),
            }
            for constraint in self.physics_prior.constraints
        }

        return {"feasible": feasible, "violations": violations}

    def update_with_observations(self, X_new: Tensor, y_new: Tensor) -> None:
        """Add new observations to the data prior."""
        self.data_prior.add_observations(X_new, y_new)

    def summary(self) -> Dict:
        """Return a summary of the current prior state.

        ``"recommended_mode"`` is ``"insufficient_data"`` when there is
        neither a physics prior nor enough observations, instead of raising.
        """
        has_physics = self.physics_prior is not None
        n_observations = self.data_prior.n_observations

        if has_physics or n_observations >= self.min_data_for_gp:
            recommended_mode = self.recommend_surrogate_mode()
        else:
            recommended_mode = "insufficient_data"

        return {
            "has_physics_model": has_physics,
            "n_physics_constraints": len(self.get_all_constraints()),
            "n_observations": n_observations,
            "recommended_mode": recommended_mode,
        }