import torch
import gpytorch


def create_kernel(config):
    """Build a GPyTorch kernel from a configuration object.

    Thin wrapper around ``create_kernel_mix`` that pulls the kernel
    specification and input dimensionality out of ``config``, so both entry
    points share one implementation.

    Args:
        config: Object exposing ``kernel_params`` (a mapping with at least a
            ``'type'`` key) and ``input_dim`` (number of input dimensions).

    Returns:
        The configured kernel module (currently always an ARD
        ``gpytorch.kernels.RBFKernel``).

    Raises:
        ValueError: If the kernel type is missing or unsupported, or if a
            per-dimension ``raw_lengthscale`` has the wrong length.
    """
    return create_kernel_mix(config.kernel_params, input_dim=config.input_dim)


def create_kernel_mix(kernel_params, input_dim=1):
    """Build a GPyTorch kernel from a kernel-parameter mapping.

    Only ``kernel_params['type'] == 'RBF'`` is supported; it produces an
    ``RBFKernel`` with ``ard_num_dims=input_dim`` (one lengthscale per input
    dimension).

    Args:
        kernel_params: Mapping with a required ``'type'`` key and an optional
            ``'params'`` mapping. If ``params['raw_lengthscale']`` is given it
            initializes the kernel's *raw* (pre-constraint-transform)
            lengthscale. It may be a scalar, which is shared by every
            dimension, or a sequence of exactly ``input_dim`` values. If it is
            absent, GPyTorch's default initialization is left in place.
        input_dim: Number of input dimensions, used as ``ard_num_dims``.

    Returns:
        The configured ``gpytorch.kernels.RBFKernel``.

    Raises:
        ValueError: If ``'type'`` is missing or unsupported, or if a
            per-dimension ``raw_lengthscale`` does not have ``input_dim``
            entries.
    """
    if 'type' not in kernel_params:
        raise ValueError("Kernel type must be specified in kernel_params")

    kernel_type = kernel_params['type']
    if kernel_type != 'RBF':
        raise ValueError(f"Unsupported kernel type: {kernel_type}")

    # NOTE(review): earlier code passed requires_grad=False here, but
    # GPyTorch's Kernel.__init__ discards unknown **kwargs, so it never froze
    # anything and the lengthscale has always been trainable. If freezing is
    # actually wanted, call kernel.raw_lengthscale.requires_grad_(False).
    kernel = gpytorch.kernels.RBFKernel(ard_num_dims=input_dim)

    # 'params' is optional; also tolerate an explicit None.
    params = kernel_params.get('params') or {}
    raw_lengthscale = params.get("raw_lengthscale")
    if raw_lengthscale is not None:
        # Cast to the default float dtype so integer config values do not
        # produce an int64 tensor.
        value = torch.as_tensor(raw_lengthscale, dtype=torch.get_default_dtype())
        if value.dim() == 0:
            # Scalar: share one raw lengthscale across every ARD dimension.
            value = value.repeat(input_dim)
        elif value.shape != (input_dim,):
            raise ValueError(
                f"raw_lengthscale must be a scalar or have {input_dim} entries, "
                f"got shape {tuple(value.shape)}"
            )
        # initialize() broadcasts the (input_dim,) vector onto the
        # parameter's stored shape.
        kernel.initialize(raw_lengthscale=value)

    return kernel