# Uploaded by cesarali: manual runtime bundle push from load_and_push.ipynb
# (commit 5686f5b, verified).
import torch
import gpytorch
def create_kernel(config):
    """Build a gpytorch kernel from a config object.

    Reads the kernel description from ``config.kernel_params`` and the number
    of input dimensions from ``config.input_dim``, then forwards both to
    ``create_kernel_mix``, which does the actual construction.

    Args:
        config: Object exposing ``kernel_params`` (a mapping with a ``'type'``
            key and, for ``'RBF'``, ``params['raw_lengthscale']``) and
            ``input_dim``.

    Returns:
        The kernel returned by ``create_kernel_mix``.

    Raises:
        Whatever ``create_kernel_mix`` raises, e.g. ValueError when the kernel
        type is missing or unsupported.
    """
    # Delegate so config-based and mapping-based construction share a single
    # implementation instead of two hand-maintained copies.
    return create_kernel_mix(config.kernel_params, input_dim=config.input_dim)
def _rbf_raw_lengthscale(value, input_dim):
    """Turn a configured raw lengthscale into a 1-D tensor of length ``input_dim``.

    ``value`` may be a scalar or a one-element sequence/tensor (shared by every
    input dimension), or a sequence/tensor with exactly ``input_dim`` elements
    (one value per dimension). The result uses torch's default float dtype.

    Raises:
        ValueError: if ``value`` has any other number of elements.
    """
    raw = torch.as_tensor(value, dtype=torch.get_default_dtype())
    if raw.numel() == 1:
        # A single value is broadcast across all ARD dimensions.
        return raw.reshape(1).repeat(input_dim)
    if raw.numel() == input_dim:
        return raw.reshape(input_dim)
    raise ValueError(
        f"raw_lengthscale has {raw.numel()} elements; "
        f"expected 1 or input_dim={input_dim}"
    )


def create_kernel_mix(kernel_params, input_dim=1, *, trainable_lengthscale=True):
    """Build a gpytorch kernel from a kernel-parameter mapping.

    Only ``'RBF'`` kernels are supported. The kernel is created with
    ``ard_num_dims=input_dim`` and its ``raw_lengthscale`` parameter (gpytorch's
    pre-constraint parameter, not the lengthscale itself) is initialised from
    ``kernel_params['params']['raw_lengthscale']``.

    Args:
        kernel_params: Mapping with a required ``'type'`` key and, for
            ``'RBF'``, a ``'params'`` mapping containing ``'raw_lengthscale'``.
            That value may be a scalar (shared by all dimensions) or contain
            one value per input dimension.
        input_dim: Number of input dimensions, used as ``ard_num_dims``.
        trainable_lengthscale: Keyword-only. When False, the kernel's
            ``raw_lengthscale`` parameter is frozen (``requires_grad=False``).
            Defaults to True, so the lengthscale stays trainable.

    Returns:
        A ``gpytorch.kernels.RBFKernel`` with its raw lengthscale initialised.

    Raises:
        ValueError: if ``'type'`` is missing or unsupported, or if
            ``raw_lengthscale`` has the wrong number of elements.
        KeyError: if an RBF kernel is requested without
            ``params['raw_lengthscale']``.
    """
    if 'type' not in kernel_params:
        raise ValueError("Kernel type must be specified in kernel_params")
    kernel_type = kernel_params['type']
    if kernel_type != 'RBF':
        raise ValueError(f"Unsupported kernel type: {kernel_type}")

    # `params` may be absent (or None); either way the lengthscale is required.
    rbf_params = kernel_params.get('params') or {}
    if 'raw_lengthscale' not in rbf_params:
        raise KeyError(
            "RBF kernel requires kernel_params['params']['raw_lengthscale']"
        )
    # Shape/validate before constructing the kernel so bad configs fail fast.
    raw_lengthscale = _rbf_raw_lengthscale(rbf_params['raw_lengthscale'], input_dim)

    # gpytorch's Kernel.__init__ accepts and ignores unknown keyword arguments,
    # so passing requires_grad=False to the constructor freezes nothing;
    # freezing is done explicitly on the parameter below when requested.
    kernel = gpytorch.kernels.RBFKernel(ard_num_dims=input_dim)
    kernel.initialize(raw_lengthscale=raw_lengthscale)
    if not trainable_lengthscale:
        kernel.raw_lengthscale.requires_grad_(False)
    return kernel