Spaces:
Running
Running
| import torch | |
class IdentityModel(torch.nn.Module):
    """Baseline model that passes its input vectors through, unit-normalized.

    For each row of ``data.input_vectors`` the output is that row divided by
    its Euclidean (L2) norm, followed by a constant ``beta`` column of ones.
    The model has no learnable parameters.

    Args:
        n_out_coords: Stored on the instance but not used by ``forward``;
            the output width is always ``input_vectors.shape[1] + 1``.
    """

    def __init__(self, n_out_coords=3):
        super().__init__()
        self.n_out_coords = n_out_coords

    def forward(self, data):
        """Return unit-normalized input vectors with an appended beta column.

        Args:
            data: Object exposing an ``input_vectors`` tensor. Presumably an
                ``EventBatch`` of four-momenta — TODO confirm; the code below
                assumes a 2-D tensor of shape ``(N, D)``.

        Returns:
            Tensor of shape ``(N, D + 1)``. Columns ``0..D-1`` hold each row
            divided by its L2 norm; column ``D`` is all ones. All-zero input
            rows map to zeros (not NaN).
        """
        inputs_v = data.input_vectors
        # F.normalize divides by max(||x||_2, eps), so all-zero rows (e.g.
        # padding) become zeros instead of NaN from 0 / 0; nonzero rows are
        # identical to x / ||x||_2.
        unit_v = torch.nn.functional.normalize(inputs_v, p=2.0, dim=1)
        # Allocate directly on the input's device and dtype: no host->device
        # copy and no dtype promotion inside torch.cat.
        betas = torch.ones(
            inputs_v.shape[0], 1, device=inputs_v.device, dtype=inputs_v.dtype
        )
        return torch.cat([unit_v, betas], dim=1)
def get_model(args):
    """Construct and return an ``IdentityModel`` with default settings.

    ``args`` is accepted but not read by this factory.
    """
    model = IdentityModel()
    return model