from typing import Optional, Sequence

import numpy as np
import torch


def define_mask(X, databatch):
    """Return a boolean mask selecting the active (process) dimensions of ``X``.

    Entry ``[b, d]`` of the mask is ``True`` when ``d < databatch.process_dimension``.

    Args:
        X (Tensor[B, D]): batch whose second axis indexes dimensions. Only its
            first two sizes and its device are used.
        databatch: object exposing a ``process_dimension`` attribute. Presumably
            an int, or a tensor broadcastable against ``[B, D]`` (e.g. shape
            ``[B, 1]``) — TODO confirm against callers.

    Returns:
        Tensor[B, D] of dtype bool, on ``X.device``.
    """
    B, D = X.size(0), X.size(1)
    # Column indices 0..D-1 repeated for every row (a view, no copy), then
    # compared against the cutoff to switch off the unused trailing dimensions.
    column_index = torch.arange(D, device=X.device).expand(B, -1)
    return column_index < databatch.process_dimension


def define_mesh_points(
    total_points: int = 100,
    n_dims: int = 1,
    ranges: Optional[Sequence[Sequence[float]]] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Return the points of a regular n-dimensional grid (mesh).

    Each axis gets ``round(total_points ** (1 / n_dims))`` evenly spaced
    points, so the returned count is that value raised to ``n_dims``. This may
    differ from ``total_points`` when it is not a perfect power.

    Args:
        total_points: approximate total number of grid points.
        n_dims: number of dimensions of the grid; must be >= 1.
        ranges: one ``(low, high)`` pair per dimension. If ``None``, or if its
            length does not equal ``n_dims``, every axis falls back to
            ``[-1.0, 1.0]``.
        device: device on which to create the points (default: torch's
            default device).

    Returns:
        Tensor of shape ``[points_per_axis ** n_dims, n_dims]``. Rows follow
        ``'ij'`` (matrix) ordering, so the last coordinate varies fastest.

    Raises:
        ValueError: if ``n_dims`` is smaller than 1.
    """
    if n_dims < 1:
        raise ValueError(f"n_dims must be >= 1, got {n_dims}")

    # Points per axis so that the full grid holds roughly total_points points.
    number_of_points = int(np.round(total_points ** (1 / n_dims)))

    if ranges is not None and len(ranges) == n_dims:
        bounds = [(r[0], r[1]) for r in ranges]
    else:
        # Backward-compatible fallback: ignore missing/mismatched ranges.
        bounds = [(-1.0, 1.0)] * n_dims

    axes_grid = [
        torch.linspace(low, high, number_of_points, device=device)
        for low, high in bounds
    ]
    meshgrids = torch.meshgrid(*axes_grid, indexing="ij")
    # Stack coordinates on a trailing axis, then flatten the grid into rows.
    return torch.stack(meshgrids, dim=-1).reshape(-1, n_dims)