| from enum import Enum | |
| import yaml | |
| from easydict import EasyDict as edict | |
| import torch.nn as nn | |
| import torch | |
def load_yaml(path):
    """Parse the YAML file at ``path`` and wrap the result in an ``EasyDict``.

    Args:
        path: Location of the YAML file to read.

    Returns:
        An ``edict`` built from whatever ``yaml.safe_load`` produced for
        the file.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Open in binary mode so PyYAML performs its own encoding detection
    # (BOM, otherwise UTF-8) instead of decoding with the platform's
    # locale-dependent default text encoding.
    with open(path, 'rb') as stream:
        return edict(yaml.safe_load(stream))
def move_to_device(obj, device):
    """Recursively call ``.to(device)`` on every module/tensor inside ``obj``.

    Args:
        obj: An ``nn.Module``, a tensor, or an arbitrarily nested
            tuple/list/dict whose leaves are modules or tensors.
        device: Target passed unchanged to each ``.to(...)`` call.

    Returns:
        For a module or tensor, the result of ``obj.to(device)``.
        For a tuple or list, a new ``list`` of converted elements (tuples
        come back as lists). For a dict, a new plain ``dict`` with the same
        keys and converted values.

    Raises:
        ValueError: If ``obj`` (or any nested element) is of another type.
    """
    # Leaves: modules and tensors are both moved through ``.to``.
    if isinstance(obj, nn.Module) or torch.is_tensor(obj):
        return obj.to(device)

    # Sequences: rebuild element by element; the result is always a list.
    if isinstance(obj, (tuple, list)):
        moved_items = []
        for item in obj:
            moved_items.append(move_to_device(item, device))
        return moved_items

    # Mappings: keep the keys, convert the values.
    if isinstance(obj, dict):
        moved_map = {}
        for key, value in obj.items():
            moved_map[key] = move_to_device(value, device)
        return moved_map

    raise ValueError(f'Unexpected type {type(obj)}')
class SmallMode(Enum):
    """Enumeration of string-valued modes.

    NOTE(review): presumably selects how "small" inputs are treated
    (discarded vs. enlarged) -- confirm against the code that consumes it.
    """

    # Member whose value is the string "drop".
    DROP = "drop"
    # Member whose value is the string "upscale".
    UPSCALE = "upscale"