from collections import OrderedDict

import torch


def normalize_activation(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    """Scale *x* to unit L2 norm along dimension 1.

    The norm is computed with ``keepdim=True`` so it broadcasts back against
    *x*; dimension 1 is presumably the channel axis of a feature map
    (NOTE(review): confirm against callers).

    Args:
        x: Tensor with at least two dimensions.
        eps: Small constant added to the norm (after the square root) so that
            all-zero slices divide by ``eps`` instead of zero.

    Returns:
        A tensor of the same shape as *x*.
    """
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def get_state_dict(net_type: str = "alex", version: str = "0.1") -> OrderedDict:
    """Download LPIPS weights and return them with shortened key names.

    The weights file is fetched from the ``richzhang/PerceptualSimilarity``
    GitHub repository at ``lpips/weights/v{version}/{net_type}.pth``.

    Args:
        net_type: Name of the weights file (without the ``.pth`` extension).
        version: Version string used in the ``v{version}`` directory name.

    Returns:
        An ``OrderedDict`` holding the downloaded values under renamed keys,
        in the order the downloaded state dict yields them.
    """
    # build url
    url = (
        "https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/"
        + f"master/lpips/weights/v{version}/{net_type}.pth"
    )

    # download; tensors are remapped to CPU only when CUDA is unavailable
    old_state_dict = torch.hub.load_state_dict_from_url(
        url,
        progress=True,
        map_location=None if torch.cuda.is_available() else torch.device("cpu"),
    )

    # rename keys: drop every "lin" substring first, then every "model."
    # substring (str.replace removes all occurrences, not just a prefix)
    return OrderedDict(
        (key.replace("lin", "").replace("model.", ""), val)
        for key, val in old_state_dict.items()
    )