| | from collections import OrderedDict |
| |
|
| | import torch |
| |
|
| |
|
def normalize_activation(x, eps=1e-10, dim=1):
    """Scale ``x`` to unit L2 norm along dimension ``dim``.

    Args:
        x: Input tensor. With the default ``dim=1`` the norm is taken over
            axis 1 (presumably the channel axis of an (N, C, ...) feature
            map — TODO confirm against callers).
        eps: Small constant added to the norm before dividing, so an
            all-zero slice yields zeros instead of NaN.
        dim: Dimension along which the L2 norm is computed. Defaults to 1,
            which matches the previous hard-coded behaviour.

    Returns:
        Tensor with the same shape as ``x``. Every slice along ``dim`` is
        divided by its L2 norm plus ``eps``.
    """
    # keepdim=True keeps the reduced axis as size 1 so the division broadcasts.
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=dim, keepdim=True))
    return x / (norm_factor + eps)
| |
|
| |
|
def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
    """Download pretrained weights from the PerceptualSimilarity repo and rename their keys.

    The checkpoint is fetched from
    ``.../PerceptualSimilarity/master/lpips/weights/v{version}/{net_type}.pth``
    via ``torch.hub.load_state_dict_from_url``.

    Args:
        net_type: Backbone name substituted into the checkpoint file name.
        version: Weight version substituted into the ``v{version}`` directory.

    Returns:
        An ``OrderedDict`` with the downloaded tensors in their original
        order. Every occurrence of ``'lin'`` and then ``'model.'`` is removed
        from each key. NOTE(review): presumably this makes the keys match a
        module that indexes its linear layers directly; verify against the
        loading module.
    """
    repo_root = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/'
    url = repo_root + f'master/lpips/weights/v{version}/{net_type}.pth'

    # Keep the checkpoint's own device placement when CUDA exists,
    # otherwise force every tensor onto the CPU.
    if torch.cuda.is_available():
        location = None
    else:
        location = torch.device('cpu')

    downloaded = torch.hub.load_state_dict_from_url(
        url, progress=True, map_location=location
    )

    # Strip the 'lin' and 'model.' fragments from every key.
    return OrderedDict(
        (name.replace('lin', '').replace('model.', ''), tensor)
        for name, tensor in downloaded.items()
    )
| |
|