ravimohan19 commited on
Commit
058f8a4
·
verified ·
1 Parent(s): 1f47014

Upload models/base.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/base.py +54 -0
models/base.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Abstract base classes for surrogate models."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ from torch import Tensor
8
+
9
+
10
+ class SurrogateModel(ABC):
11
+ """Abstract base class for all surrogate models in the platform.
12
+
13
+ A surrogate model provides predictions (mean + uncertainty) and can be
14
+ updated with new observations.
15
+ """
16
+
17
+ @abstractmethod
18
+ def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
19
+ """Return posterior mean and variance at input locations X.
20
+
21
+ Args:
22
+ X: Input tensor of shape (n, d).
23
+
24
+ Returns:
25
+ mean: Predicted mean of shape (n, 1).
26
+ variance: Predicted variance of shape (n, 1).
27
+ """
28
+
29
+ @abstractmethod
30
+ def fit(self, X: Tensor, y: Tensor) -> None:
31
+ """Fit/update the surrogate model with observed data.
32
+
33
+ Args:
34
+ X: Training inputs of shape (n, d).
35
+ y: Training targets of shape (n, 1).
36
+ """
37
+
38
+ @abstractmethod
39
+ def posterior(self, X: Tensor):
40
+ """Return the full posterior distribution at X (for BoTorch compatibility).
41
+
42
+ Args:
43
+ X: Input tensor of shape (batch, n, d).
44
+ """
45
+
46
+ def condition_on_observations(self, X: Tensor, y: Tensor) -> "SurrogateModel":
47
+ """Return a new model conditioned on additional observations.
48
+
49
+ Default implementation refits the model. Subclasses can override
50
+ for fantasy-based conditioning.
51
+ """
52
+ raise NotImplementedError(
53
+ "Fantasy conditioning not implemented for this model. Use fit() instead."
54
+ )