cesarali's picture
manual runtime bundle push from load_and_push.ipynb
5686f5b verified
import torch
import numpy as np
def define_mask(X, databatch):
    """
    Build a boolean mask that marks the active (process) dimensions of each row.

    Args:
        X (Tensor[B, D]): batch tensor. Only its shape (B = size 0, D = size 1)
            and its device are used; its values are ignored.
        databatch: object exposing ``process_dimension``. It is compared
            elementwise against the column indices 0..D-1. Presumably an int
            or a tensor broadcastable to [B, D] (e.g. [B, 1]); TODO confirm
            against callers.

    Returns:
        Tensor[B, D] of dtype bool: ``mask[b, d]`` is True when
        ``d < process_dimension`` (after broadcasting).
    """
    batch_size, n_columns = X.size(0), X.size(1)
    # Column indices 0..D-1 on X's device, repeated for every row -> [B, D]
    column_ids = torch.arange(n_columns, device=X.device)
    column_ids = column_ids.unsqueeze(0).expand(batch_size, n_columns)
    return column_ids < databatch.process_dimension
# Define Mesh Points
def define_mesh_points(total_points=100, n_dims=1, ranges=None) -> torch.Tensor:
    """
    Return the points of a regular grid (mesh) spanning an n-dimensional box.

    Each axis is sampled with ``round(total_points ** (1 / n_dims))`` evenly
    spaced values. The returned grid therefore holds that count raised to the
    power ``n_dims`` points, which is close to, but not necessarily equal to,
    ``total_points``.

    Args:
        total_points (int): approximate total number of grid points.
        n_dims (int): number of grid dimensions (must be non-zero).
        ranges (sequence of (low, high) pairs, optional): per-axis bounds.
            Used only when it has exactly ``n_dims`` entries. Otherwise every
            axis spans [-1.0, 1.0]. ``None`` (the default) behaves like an
            empty list.

    Returns:
        torch.Tensor: shape [number_of_points ** n_dims, n_dims], one grid
        point per row. Rows are ordered with the first axis varying slowest
        (``indexing='ij'``).
    """
    if ranges is None:
        # A None default avoids a shared mutable default; [] keeps the old fallback.
        ranges = []
    # Points per axis; rounding absorbs float error such as 1000 ** (1/3) == 9.999...
    number_of_points = int(np.round(total_points ** (1 / n_dims)))
    if len(ranges) == n_dims:
        bounds = [(axis_range[0], axis_range[1]) for axis_range in ranges]
    else:
        # NOTE(review): a non-empty ranges list of the wrong length is silently
        # ignored and the default [-1, 1] box is used instead.
        bounds = [(-1.0, 1.0)] * n_dims
    axes_grid = [torch.linspace(low, high, number_of_points) for low, high in bounds]
    # One coordinate tensor per axis, each of shape [number_of_points] * n_dims
    meshgrids = torch.meshgrid(*axes_grid, indexing="ij")
    # Stack coordinates on a trailing axis, then flatten all grid axes into rows
    return torch.stack(meshgrids, dim=-1).reshape(-1, n_dims)