| import torch | |
def mean_activations(tensor):
    """Average an activation tensor over dim 1 and return it on the CPU.

    The input is detached from the autograd graph and moved to the CPU,
    reduced with a mean over dim 1, and then dim 0 is squeezed (it is
    dropped only when its size is 1).

    NOTE(review): presumably the input is laid out as
    (batch, channels, ...) with a batch of one -- confirm against callers.
    """
    cpu_tensor = tensor.detach().cpu()
    averaged = cpu_tensor.mean(dim=1)
    # Squeezing dim 0 removes the leading (batch) axis when it has size 1.
    return averaged.squeeze(dim=0)
def load_weights(model, weights):
    """Load into ``model`` only the checkpoint entries whose keys it has.

    ``weights`` is passed to ``torch.load`` (mapped onto the CPU) and the
    result is iterated as a mapping of parameter name to value. Entries
    whose name is absent from ``model.state_dict()`` are ignored; the
    rest overwrite the model's current values, and the merged state dict
    is loaded back into the model. Model entries missing from the
    checkpoint keep their existing values.

    NOTE(review): depending on the PyTorch version, ``torch.load`` may
    unpickle arbitrary objects -- only use with trusted checkpoint files.
    """
    checkpoint = torch.load(weights, map_location='cpu')
    merged_state = model.state_dict()
    for name, value in checkpoint.items():
        # Skip checkpoint entries for layers this model does not define.
        if name in merged_state:
            merged_state[name] = value
    model.load_state_dict(merged_state)