Spaces:
Build error
Build error
| # Modified from: | |
| # https://github.com/anibali/pytorch-stacked-hourglass | |
| # https://github.com/bearpaw/pytorch-pose | |
| import torch | |
| from torch.nn import Conv2d, ModuleList | |
def change_hg_outputs(model, indices):
    """Change the output classes of the model.

    Every convolution in ``model.score`` is replaced by one that produces
    ``len(indices)`` output channels, and every convolution in ``model.score_``
    by one that consumes ``len(indices)`` input channels. The model is modified
    in place; nothing is returned.

    Args:
        model: The model to modify. It must expose ``score`` and ``score_``
            attributes that iterate over ``Conv2d`` layers.
        indices: An array of indices describing the new model outputs. For
            example, [3, 4, None] will modify the model to have 3 outputs, the
            first two of which have parameters copied from the fourth and fifth
            outputs of the original model. ``None`` entries keep the default
            initialisation of the freshly created layer.
    """
    new_n_outputs = len(indices)
    # (new position, original position) for every output that is carried over.
    kept = [(new, old) for new, old in enumerate(indices) if old is not None]

    with torch.no_grad():
        new_score = ModuleList()
        for conv in model.score:
            new_conv = _conv_like(conv, conv.in_channels, new_n_outputs)
            for new, old in kept:
                new_conv.weight[new] = conv.weight[old]
                # A bias-free convolution has nothing to carry over.
                if conv.bias is not None:
                    new_conv.bias[new] = conv.bias[old]
            new_score.append(new_conv)

        new_score_ = ModuleList()
        for conv in model.score_:
            new_conv = _conv_like(conv, new_n_outputs, conv.out_channels)
            for new, old in kept:
                new_conv.weight[:, new] = conv.weight[:, old]
            # The output channels are unchanged, so the whole bias carries over.
            # Copy the values rather than sharing the Parameter object, so the
            # new layer does not alias state owned by the discarded layer.
            if conv.bias is not None:
                new_conv.bias.copy_(conv.bias)
            new_score_.append(new_conv)

        # Assign only after both lists were built, so an error part-way through
        # (e.g. an out-of-range index) leaves the model untouched.
        model.score = new_score
        model.score_ = new_score_


def _conv_like(conv, in_channels, out_channels):
    """Return a new Conv2d with the geometry of ``conv`` but new channel counts."""
    new_conv = Conv2d(
        in_channels,
        out_channels,
        conv.kernel_size,
        conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
    )
    return new_conv.to(conv.weight.device, conv.weight.dtype)