Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
# Label groups of the Galaxy Zoo decision tree.
# group : 1-based label positions (assuming the 0th position is the id)
class_groups = {
    0: (),
    1: (1, 2, 3),
    2: (4, 5),
    3: (6, 7),
    4: (8, 9),
    5: (10, 11, 12, 13),
    6: (14, 15),
    7: (16, 17, 18),
    8: (19, 20, 21, 22, 23, 24, 25),
    9: (26, 27, 28),
    10: (29, 30, 31),
    11: (32, 33, 34, 35, 36, 37),
}

# Same groups as 0-based column indices (label position minus one).
# dtype is forced to int so that the empty group 0 also yields an integer
# array; without it np.array(()) is float64 and cannot be used as an index.
class_groups_indices = {
    group: np.array(positions, dtype=int) - 1
    for group, positions in class_groups.items()
}
# child group -> (parent group, position of the parent answer inside the
# parent group's columns). Each triple below is (child, parent, answer).
hierarchy = {
    child: (parent, answer)
    for child, parent, answer in (
        (2, 1, 1),
        (3, 2, 1),
        (4, 2, 1),
        (5, 2, 1),
        (7, 1, 0),
        (8, 6, 0),
        (9, 2, 0),
        (10, 4, 0),
        (11, 4, 0),
    )
}
def make_galaxy_labels_hierarchical(labels: torch.Tensor) -> torch.Tensor:
    """Rescale groups of galaxy label probabilities to follow the Galaxy Zoo hierarchy.

    More info on the decision tree:
    https://www.kaggle.com/c/galaxy-zoo-the-galaxy-challenge/overview/the-galaxy-zoo-decision-tree

    Groups 1 to 11 of ``class_groups_indices`` are processed in ascending
    order. Each group is first normalized so that its columns sum to 1 per
    row; every group that has an entry in ``hierarchy`` is then multiplied by
    the (already rescaled) value of its parent answer.

    Parameters
    ----------
    labels : NxL torch tensor, where N is the batch size and L is the number
        of labels. If L > 37, the 0th column is taken to be the id and all
        group indices are shifted by one. The tensor is modified in place.
        NOTE(review): the previous docstring said "all labels should be > 1",
        which looks like a typo for non-negative probabilities -- confirm.

    Return
    ------
    hierarchical_labels : NxL torch tensor, where L is the total number of
        labels (the same object as ``labels``)
    """
    ## in case the id is included at 0th position, shift indices accordingly
    shift = int(labels.shape[1] > 37)

    def columns(group):
        # column indices of `group` inside `labels`
        return class_groups_indices[group] + shift

    for group in range(1, 12):
        cols = columns(group)

        ## normalize probabilities to 1
        norm = torch.sum(labels[:, cols], dim=1, keepdim=True)
        ## add small number to prevent NaNs dividing by zero, yet keep track of gradient
        norm[norm == 0] += 1e-4
        labels[:, cols] /= norm

        ## renormalize according to hierarchical structure;
        ## groups without an entry in `hierarchy` (1 and 6) have no parent
        if group in hierarchy:
            parent_group, parent_answer = hierarchy[group]
            parent_labels = labels[:, columns(parent_group)]
            labels[:, cols] *= parent_labels[:, parent_answer].unsqueeze(-1)

    return labels