| | import torch |
| | import numpy as np |
| | import hashlib |
| |
|
def wrap(func, *args, unsqueeze=False):
    """
    Wrap a torch function so it can be called with NumPy arrays.
    Input and return types are seamlessly converted.

    Each positional argument that is a NumPy array is converted with
    ``torch.from_numpy`` (which shares memory with the array) before
    ``func`` is called. A tensor result, or each tensor element of a tuple
    result, is converted back to a NumPy array. Any other result is
    returned unchanged.

    Args:
        func: callable operating on torch tensors.
        *args: positional arguments for ``func``; ndarrays are converted.
        unsqueeze: if True, add a leading size-1 dimension to every
            converted input and squeeze dimension 0 of every tensor output.

    Returns:
        ``func``'s result with tensors converted to ndarrays.
    """

    def to_tensor(value):
        if not isinstance(value, np.ndarray):
            return value
        tensor = torch.from_numpy(value)
        return tensor.unsqueeze(0) if unsqueeze else tensor

    def to_numpy(value):
        if not isinstance(value, torch.Tensor):
            return value
        if unsqueeze:
            value = value.squeeze(0)
        # detach()/cpu() let results that require grad (or live on an
        # accelerator) be converted; both are no-ops for plain CPU tensors.
        return value.detach().cpu().numpy()

    result = func(*(to_tensor(arg) for arg in args))

    if isinstance(result, tuple):
        return tuple(to_numpy(res) for res in result)
    return to_numpy(result)
| | |
def deterministic_random(min_value, max_value, data):
    """Map the string ``data`` to a reproducible integer near [min_value, max_value].

    The first 4 bytes of the SHA-256 digest of ``data`` are read as an
    unsigned little-endian integer and scaled to a fraction in [0, 1]. That
    fraction is multiplied by the span, truncated with ``int()``, and offset
    by ``min_value``. The same arguments always give the same result.
    """
    leading_bytes = hashlib.sha256(data.encode()).digest()[:4]
    fraction = int.from_bytes(leading_bytes, 'little') / (2**32 - 1)
    span = max_value - min_value
    offset = int(fraction * span)
    return offset + min_value
| |
|
def load_pretrained_weights(model, checkpoint):
    """Load pretrained weights into ``model``.

    Incompatible layers (unmatched in name or size) are ignored, and a
    warning lists them. Checkpoint keys have a leading ``'module.'``
    stripped before they are compared with the model's state-dict keys.
    Matched entries overwrite the model's values; every other entry of the
    model's state dict keeps its current value.

    Args:
        model (nn.Module): network model, which must not be nn.DataParallel
            (its own keys are compared against prefix-stripped keys).
        checkpoint (dict): a state dict, or a dict holding one under the
            ``'state_dict'`` key.

    Returns:
        nn.Module: the same ``model`` object, updated in place.
    """
    import collections
    import warnings

    prefix = 'module.'
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    model_dict = model.state_dict()
    new_state_dict = collections.OrderedDict()
    matched_layers, discarded_layers = [], []
    for k, v in state_dict.items():
        # Weights saved from an nn.DataParallel-wrapped model carry a
        # 'module.' prefix on every key; strip it to match a bare model.
        if k.startswith(prefix):
            k = k[len(prefix):]
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)
    print('load_weight', len(matched_layers))

    # Surface what was skipped instead of silently dropping it.
    if not matched_layers:
        warnings.warn(
            'No pretrained weights could be loaded: none of the checkpoint '
            'keys match the model in name and size; please check the key '
            'names manually.'
        )
    elif discarded_layers:
        warnings.warn(
            'The following layers are discarded due to unmatched keys or '
            'layer size: {}'.format(discarded_layers)
        )

    return model
| |
|