Spaces:
Sleeping
Sleeping
| import torch | |
# Registry mapping a model name to the class registered under it.
models = {}


def register(name):
    """Return a class decorator that records the class in ``models``.

    Args:
        name: Key under which the decorated class is stored. Registering the
            same name again replaces the earlier entry.

    Returns:
        A decorator that stores its argument in ``models[name]`` and returns
        that argument unchanged.
    """
    def _add_to_registry(target):
        models[name] = target
        return target

    return _add_to_registry
def load_sd_from_ckpt(ckpt, keys_only=None):
    """Load the state dict stored in a checkpoint, optionally filtered by key.

    Args:
        ckpt: Checkpoint location passed to ``torch.load``. Tensors are mapped
            to CPU.
        keys_only: Optional key filter. A key is kept if it equals one of the
            entries or starts with ``entry + '.'`` (i.e. the entry names a
            submodule). A single string is treated as one entry. Any iterable
            of strings is accepted, including one-shot iterators. ``None``
            keeps every key.

    Returns:
        The object stored at ``checkpoint['model']['sd']``, with non-matching
        keys removed in place.

    Raises:
        KeyError: if the checkpoint lacks ``['model']['sd']``.
    """
    # NOTE(review): torch.load unpickles data; only load trusted checkpoints.
    sd = torch.load(ckpt, map_location='cpu')['model']['sd']
    if keys_only is None:
        return sd

    # A bare string would otherwise be iterated character by character.
    if isinstance(keys_only, str):
        keys_only = (keys_only,)
    # Materialize once, so a generator is not exhausted before both the
    # exact-match set and the prefix tuple are built.
    exact_keys = set(keys_only)
    prefixes = tuple(k + '.' for k in exact_keys)

    # Remove in place instead of building a new dict, so any attributes on the
    # loaded mapping (PyTorch state dicts carry `_metadata`) are preserved.
    for k in list(sd.keys()):
        if not (k in exact_keys or k.startswith(prefixes)):
            sd.pop(k)
    return sd
def make(spec, load_sd=False):
    """Build a registered model from a spec and optionally load weights.

    Args:
        spec: Mapping with:
            - ``'name'`` (required): key of a class registered in ``models``.
            - ``'args'`` (optional): keyword arguments for the constructor;
              missing or ``None`` means no arguments.
            - ``'load_ckpt'`` (optional): checkpoint passed to
              ``load_sd_from_ckpt``; its weights are loaded non-strictly.
            - ``'load_ckpt_keys_only'`` (optional): key filter forwarded to
              ``load_sd_from_ckpt``.
            - ``'sd'``: state dict, read only when ``load_sd`` is true.
        load_sd: If true, load ``spec['sd']`` strictly after any checkpoint,
            so its values take precedence over overlapping checkpoint keys.

    Returns:
        The constructed model instance.

    Raises:
        KeyError: if ``spec['name']`` is missing or not registered, or if
            ``load_sd`` is true and ``spec['sd']`` is missing.
    """
    import logging  # local import keeps this block self-contained

    logger = logging.getLogger(__name__)

    args = spec.get('args')
    if args is None:
        args = {}

    name = spec['name']
    try:
        model_cls = models[name]
    except KeyError:
        # Same exception type callers may already catch, but name what is
        # registered so a typo in the spec is easy to diagnose.
        raise KeyError(
            f'model {name!r} is not registered; known: {list(models)}'
        ) from None
    model = model_cls(**args)
    # Log rather than print so constructing a model does not write to stdout.
    logger.debug('args %s', args)

    ckpt = spec.get('load_ckpt')
    if ckpt is not None:
        sd = load_sd_from_ckpt(ckpt, spec.get('load_ckpt_keys_only'))
        # Non-strict on purpose (the checkpoint may be filtered), but record
        # what was skipped instead of discarding the result silently.
        incompatible = model.load_state_dict(sd, strict=False)
        logger.debug('non-strict load from %s: %s', ckpt, incompatible)
    if load_sd:
        model.load_state_dict(spec['sd'])
    return model
def make_identity():
    """Return a new ``torch.nn.Identity`` module, which passes its input through unchanged."""
    identity_layer = torch.nn.Identity()
    return identity_layer