| import numpy as np |
| import gradio as gr |
| import pandas as pd |
| from sklearn.preprocessing import MinMaxScaler |
| from surrogate import CrabNetSurrogateModel, PARAM_BOUNDS |
| from pydantic import ( |
| BaseModel, |
| ValidationError, |
| ValidationInfo, |
| field_validator, |
| model_validator, |
| ) |
|
|
# Surrogate used for the example output below and by evaluate().
model = CrabNetSurrogateModel()


# Default parameterization, keyed by the real (unblinded) parameter names and
# expressed in native values. get_interface() reads it to set each
# component's initial position.
example_parameterization = dict(
    N=3,
    alpha=0.5,
    d_model=512,
    dim_feedforward=2048,
    dropout=0.1,
    emb_scaler=0.5,
    epochs_step=10,
    eps=1e-6,
    fudge=0.02,
    heads=4,
    k=6,
    lr=0.001,
    pe_resolution=5000,
    ple_resolution=5000,
    pos_scaler=0.5,
    weight_decay=0,
    batch_size=32,
    out_hidden4=128,
    betas1=0.9,
    betas2=0.999,
    bias=False,
    criterion="RobustL1",
    elem_prop="mat2vec",
    train_frac=0.5,
)


# Evaluate the default once at import time; the single result seeds the
# initial contents of the output table.
example_results = model.surrogate_evaluate([example_parameterization])
example_result = example_results[0]


# One (initially unfitted) MinMaxScaler per "range" parameter, keyed by real
# name. get_interface() fits each one on its bounds; evaluate() uses it to
# map slider positions back to native values.
scalers = {
    entry["name"]: MinMaxScaler()
    for entry in PARAM_BOUNDS
    if entry["type"] == "range"
}
|
|
|
|
def _blinded_field_lookup():
    """Map each blinded field name to its ``PARAM_BOUNDS`` entry.

    The numbering mirrors the labels ``get_interface`` gives the UI
    components: ``"range"`` parameters become ``x1``, ``x2``, ... in
    ``PARAM_BOUNDS`` order (``train_frac`` is labelled ``f1`` but still
    advances the numeric counter), and ``"choice"`` parameters become
    ``c1``, ``c2``, ... Entries of any other type are skipped.
    """
    lookup = {}
    numeric_index = 1
    choice_index = 1
    for param_info in PARAM_BOUNDS:
        if param_info["type"] == "range":
            if param_info["name"] == "train_frac":
                blinded_name = "f1"
            else:
                blinded_name = f"x{numeric_index}"
            numeric_index += 1
        elif param_info["type"] == "choice":
            blinded_name = f"c{choice_index}"
            choice_index += 1
        else:
            continue
        lookup[blinded_name] = param_info
    return lookup


# Built once at import time from the module-level PARAM_BOUNDS import.
_BLINDED_PARAM_INFO = _blinded_field_lookup()


class BlindedParameterization(BaseModel):
    """A parameterization expressed with blinded field names.

    ``x*`` and ``f1`` are numeric parameters; ``c*`` are categorical ones.
    Values are expected in decoded form: the categorical fields are typed as
    the underlying ``bool``/``str`` values, not as the UI's ``"c<k>_<i>"``
    labels. Each field is checked against the ``PARAM_BOUNDS`` entry its
    blinded name maps to, and two cross-field constraints are enforced.
    """

    x1: float
    x2: float
    x3: float
    x4: float
    x5: float
    x6: float
    x7: float
    x8: float
    x9: float
    x10: float
    x11: float
    x12: float
    x13: float
    x14: float
    x15: float
    x16: float
    x17: float
    x18: float
    x19: float
    x20: float
    c1: bool
    c2: str
    c3: str
    f1: float

    @field_validator("*")
    @classmethod
    def check_bounds(cls, v, info: ValidationInfo):
        """Reject a field value outside its ``PARAM_BOUNDS`` entry.

        ``"range"`` parameters must satisfy ``min <= v <= max``. ``"choice"``
        parameters must be one of the listed values. Fields whose blinded name
        has no matching entry pass through unchanged.

        Raises:
            ValueError: If the value violates its bounds or choices.
        """
        # Field names are blinded (x1, c1, f1, ...), whereas PARAM_BOUNDS is
        # keyed by real names, so resolve through the blinded lookup.
        param = _BLINDED_PARAM_INFO.get(info.field_name)
        if param is None:
            return v

        # NOTE(review): these messages echo the real bounds/choice values,
        # which partially un-blinds the parameter; confirm that is acceptable.
        if param["type"] == "range":
            min_val, max_val = param["bounds"]
            if not min_val <= v <= max_val:
                raise ValueError(
                    f"{info.field_name} must be between {min_val} and {max_val}"
                )
        elif param["type"] == "choice":
            if v not in param["values"]:
                raise ValueError(f"{info.field_name} must be one of {param['values']}")

        return v

    @model_validator(mode="after")
    def check_constraints(self) -> "BlindedParameterization":
        """Enforce the cross-field constraints.

        Raises:
            ValueError: If ``x19 > x20`` or if ``x6 + x15 > 1.0``.
        """
        if self.x19 > self.x20:
            raise ValueError(
                f"Received x19={self.x19} which should be less than or equal to x20={self.x20}"
            )
        if self.x6 + self.x15 > 1.0:
            raise ValueError(
                f"Received x6={self.x6} and x15={self.x15} which should sum to less than or equal to 1.0"
            )
        # After-mode model validators must hand back the validated instance.
        return self
|
|
|
|
def evaluate(*args):
    """Decode blinded UI values, validate them, and run the surrogate.

    Args:
        *args: One value per ``PARAM_BOUNDS`` entry, in that order (the order
            in which the Gradio inputs are built). Each ``"range"`` parameter
            is a min-max scaled number. Each ``"choice"`` parameter is a
            ``"c<k>_<i>"`` radio label.

    Returns:
        One row per evaluated parameterization. Each row is the list of the
        surrogate's output values.

    Raises:
        pydantic.ValidationError: If the decoded parameterization fails the
            checks declared on ``BlindedParameterization``.
    """
    params_df = pd.DataFrame([args], columns=[param["name"] for param in PARAM_BOUNDS])

    # Real parameter name -> blinded field name on BlindedParameterization.
    # The numbering mirrors the labels get_interface assigns: range
    # parameters are x1, x2, ... (train_frac is f1 but still advances the
    # counter) and choice parameters are c1, c2, ...
    blinded_names = {}
    numeric_index = 1
    choice_index = 1

    # Decode every column back to native values. The scalers are fitted in
    # get_interface, so this assumes the UI has already been built.
    for param_info in PARAM_BOUNDS:
        key = param_info["name"]
        if param_info["type"] == "range":
            blinded_names[key] = "f1" if key == "train_frac" else f"x{numeric_index}"
            numeric_index += 1
            # .to_numpy() matches how the scaler was fitted (plain arrays, no
            # feature names); [:, 0] flattens the (n, 1) result to a column.
            unscaled = scalers[key].inverse_transform(params_df[[key]].to_numpy())[:, 0]
            # Clip away round-off from the scale/unscale round trip so the
            # slider extremes cannot land fractionally outside the bounds.
            low, high = param_info["bounds"]
            params_df[key] = np.clip(unscaled, low, high)
        elif param_info["type"] == "choice":
            blinded_names[key] = f"c{choice_index}"
            choice_index += 1
            # Radio labels look like "c<k>_<i>"; <i> indexes param_info["values"].
            value_index = int(params_df[key].str.split("_").str[-1].iloc[0])
            params_df[key] = param_info["values"][value_index]

    params_list = params_df.to_dict("records")

    # Validate after decoding, using blinded names: the schema's field types
    # (e.g. c1: bool) and bounds describe raw values, not slider positions
    # or radio labels.
    for params in params_list:
        BlindedParameterization(
            **{blinded_names.get(name, name): value for name, value in params.items()}
        )

    results = model.surrogate_evaluate(params_list)
    return [list(result.values()) for result in results]
|
|
|
|
def get_interface(param_info, numeric_index, choice_index):
    """Build the blinded Gradio input component for one PARAM_BOUNDS entry.

    ``"range"`` parameters become a slider over the min-max scaled interval.
    The slider is labelled ``x<numeric_index>``, or ``f1`` for
    ``train_frac``. ``"choice"`` parameters become a radio group labelled
    ``c<choice_index>``. Its options are ``c<choice_index>_<i>``, where
    ``<i>`` indexes ``param_info["values"]``. Both kinds start at the value
    from ``example_parameterization``.

    Side effect: fits ``scalers[key]`` on the parameter's bounds. ``evaluate``
    relies on that fitted scaler to map slider values back to native values.

    Args:
        param_info: One ``PARAM_BOUNDS`` entry (``name``, ``type`` and either
            ``bounds`` or ``values``).
        numeric_index: Index to use for the next numeric label.
        choice_index: Index to use for the next choice label.

    Returns:
        ``(component, next_numeric_index, next_choice_index)``.

    Raises:
        ValueError: If ``param_info["type"]`` is neither ``"range"`` nor
            ``"choice"``.
    """
    key = param_info["name"]
    param_type = param_info["type"]
    default_value = example_parameterization[key]

    if param_type == "range":
        scaler = scalers[key]
        bounds = [[bound] for bound in param_info["bounds"]]
        scaler.fit(bounds)
        scaled_value = scaler.transform([[default_value]])[0][0]
        scaled_bounds = scaler.transform(bounds)
        scaled_min, scaled_max = scaled_bounds[0][0], scaled_bounds[1][0]
        # train_frac is shown as "f1"; it still advances numeric_index below.
        label = "f1" if key == "train_frac" else f"x{numeric_index}"
        slider = gr.Slider(
            value=scaled_value,
            minimum=scaled_min,
            maximum=scaled_max,
            label=label,
            step=(scaled_max - scaled_min) / 100,
        )
        return slider, numeric_index + 1, choice_index

    if param_type == "choice":
        values = param_info["values"]
        radio = gr.Radio(
            choices=[f"c{choice_index}_{i}" for i in range(len(values))],
            label=f"c{choice_index}",
            value=f"c{choice_index}_{values.index(default_value)}",
        )
        return radio, numeric_index, choice_index + 1

    # Fail loudly instead of returning None, which the caller's 3-tuple
    # unpacking would otherwise reject with an unhelpful TypeError.
    raise ValueError(
        f"Unsupported parameter type {param_type!r} for parameter {key!r}; "
        "expected 'range' or 'choice'."
    )
|
|
|
|
# Build one Gradio input component per PARAM_BOUNDS entry, in PARAM_BOUNDS
# order; evaluate() receives the component values positionally in that order.
numeric_index = 1
choice_index = 1
inputs = []
for param in PARAM_BOUNDS:
    # Named `component` rather than `input` to avoid shadowing the builtin.
    component, numeric_index, choice_index = get_interface(
        param, numeric_index, choice_index
    )
    inputs.append(component)


iface = gr.Interface(
    title="CrabNetSurrogateModel",
    fn=evaluate,
    inputs=inputs,
    outputs=gr.Numpy(
        value=np.array([list(example_result.values())]),
        headers=[f"y{i+1}" for i in range(len(example_result))],
        col_count=(len(example_result), "fixed"),
        datatype=["number"] * len(example_result),
    ),
    # Rendered as markdown, so the text is deliberately not indented.
    description="""
`y1`, `y2`, `y3`, and `y4` should all be minimized. `y1` and `y2` are
correlated, whereas `y1` and `y2` are both anticorrelated with `y3`. `y1`,
`y2`, and `y3` are stochastic (heteroskedastic, parameter-free noise),
whereas `y4` is deterministic, but still considered 'black-box'. In other
words, repeat calls with the same input arguments will result in different
values for `y1`, `y2`, and `y3`, but the same value for `y4`.

If `y1` is greater than 0.2, the result is considered "bad" no matter how good
the other values are. If `y2` is greater than 0.7, the result is considered
"bad" no matter how good the other values are. If `y3` is greater than 1800,
the result is considered "bad" no matter how good the other values are. If `y4`
is greater than 40e6, the result is considered "bad" no matter how good the
other values are.

`f1` is a fidelity parameter. 0 is the lowest fidelity, and 1 is the
highest fidelity. The higher the fidelity, typically the more expensive the
evaluation. However, this also typically means higher quality and relevance
to the optimization campaign goals. `f1` and `y3` are
correlated.
""",
)
iface.launch()
|
|