Tolerance
Browse files- surrogate.py +4 -3
surrogate.py
CHANGED
|
@@ -39,6 +39,7 @@ PARAM_BOUNDS = [
|
|
| 39 |
{"name": "train_frac", "type": "range", "bounds": [0.01, 1.0]},
|
| 40 |
]
|
| 41 |
|
|
|
|
| 42 |
|
| 43 |
class Parameterization(BaseModel):
|
| 44 |
N: float # int
|
|
@@ -77,7 +78,7 @@ class Parameterization(BaseModel):
|
|
| 77 |
|
| 78 |
if param["type"] == "range":
|
| 79 |
min_val, max_val = param["bounds"]
|
| 80 |
-
if not min_val <= v <= max_val:
|
| 81 |
raise ValueError(
|
| 82 |
f"{info.field_name} must be between {min_val} and {max_val}"
|
| 83 |
)
|
|
@@ -89,11 +90,11 @@ class Parameterization(BaseModel):
|
|
| 89 |
|
| 90 |
@model_validator(mode="after")
|
| 91 |
def check_constraints(self) -> "Parameterization":
|
| 92 |
-
if self.betas1 > self.betas2:
|
| 93 |
raise ValueError(
|
| 94 |
f"Received betas1={self.betas1} which should be less than betas2={self.betas2}"
|
| 95 |
)
|
| 96 |
-
if self.emb_scaler + self.pos_scaler > 1.0:
|
| 97 |
raise ValueError(
|
| 98 |
f"Received emb_scaler={self.emb_scaler} and pos_scaler={self.pos_scaler} which should sum to less than or equal to 1.0" # noqa: E501
|
| 99 |
)
|
|
|
|
| 39 |
{"name": "train_frac", "type": "range", "bounds": [0.01, 1.0]},
|
| 40 |
]
|
| 41 |
|
| 42 |
+
tol = 1e-6
|
| 43 |
|
| 44 |
class Parameterization(BaseModel):
|
| 45 |
N: float # int
|
|
|
|
| 78 |
|
| 79 |
if param["type"] == "range":
|
| 80 |
min_val, max_val = param["bounds"]
|
| 81 |
+
if not (min_val - tol) <= v <= (max_val + tol):
|
| 82 |
raise ValueError(
|
| 83 |
f"{info.field_name} must be between {min_val} and {max_val}"
|
| 84 |
)
|
|
|
|
| 90 |
|
| 91 |
@model_validator(mode="after")
|
| 92 |
def check_constraints(self) -> "Parameterization":
|
| 93 |
+
if (self.betas1 - tol) > (self.betas2 + tol):
|
| 94 |
raise ValueError(
|
| 95 |
f"Received betas1={self.betas1} which should be less than betas2={self.betas2}"
|
| 96 |
)
|
| 97 |
+
if self.emb_scaler + self.pos_scaler - tol > 1.0:
|
| 98 |
raise ValueError(
|
| 99 |
f"Received emb_scaler={self.emb_scaler} and pos_scaler={self.pos_scaler} which should sum to less than or equal to 1.0" # noqa: E501
|
| 100 |
)
|