Spaces:
Sleeping
Sleeping
File size: 1,019 Bytes
37163a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import torch
# Registry mapping a model name to the callable (class or factory function)
# that builds it; filled in by ``register`` and looked up by ``make``.
models = {}
def register(name):
    """Return a decorator that records its target in ``models`` under *name*.

    The decorated class/function is stored as ``models[name]`` (overwriting any
    earlier entry with the same name) and handed back unchanged, so decorating
    an object does not alter it.
    """
    def _add_to_registry(obj):
        models[name] = obj
        return obj

    return _add_to_registry
def load_sd_from_ckpt(ckpt, keys_only=None):
    """Load the state dict stored at ``['model']['sd']`` inside checkpoint *ckpt*.

    Args:
        ckpt: path (or file-like object) accepted by ``torch.load``; tensors
            are mapped onto the CPU.
        keys_only: optional filter. Either a single key string or an iterable
            of key strings. A state-dict entry ``k`` is kept when ``k`` equals
            one of them, or starts with one of them followed by ``'.'`` (so
            ``'enc'`` keeps ``'enc'`` and ``'enc.w'`` but not ``'encoder.w'``).
            ``None`` keeps everything; an empty iterable keeps nothing.

    Returns:
        The loaded state-dict mapping, filtered in place when *keys_only* is given.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    sd = torch.load(ckpt, map_location='cpu')['model']['sd']
    if keys_only is not None:
        # A bare string would otherwise be iterated character by character,
        # silently filtering out almost every key.
        if isinstance(keys_only, str):
            keys_only = (keys_only,)
        # Materialize once so one-shot iterables (generators) feed both checks.
        exact = set(keys_only)
        prefixes = tuple(key + '.' for key in exact)
        # Remove entries in place (instead of building a new dict) so the
        # returned object is the same mapping torch.load produced, keeping any
        # attributes attached to it.
        for k in list(sd):
            if k not in exact and not k.startswith(prefixes):
                del sd[k]
    return sd
def make(spec, load_sd=False):
    """Instantiate a registered model described by the config mapping *spec*.

    Keys read from *spec*:
        ``'name'`` -- key into the ``models`` registry (required).
        ``'args'`` -- keyword arguments for the constructor (optional; ``None``
            or absent means no arguments).
        ``'load_ckpt'`` -- checkpoint passed to ``load_sd_from_ckpt`` and
            loaded with ``strict=False`` (optional).
        ``'load_ckpt_keys_only'`` -- key filter forwarded to
            ``load_sd_from_ckpt`` (optional).
        ``'sd'`` -- state dict loaded (strictly) when *load_sd* is true.

    Args:
        spec: the model config mapping described above.
        load_sd: if true, load ``spec['sd']`` after any ``'load_ckpt'``
            weights, so it takes precedence over them.

    Returns:
        The constructed model.

    Raises:
        KeyError: if ``'name'`` is missing from *spec* or is not registered.
    """
    import logging  # local import keeps this block self-contained
    logger = logging.getLogger(__name__)

    args = spec.get('args')
    if args is None:
        args = dict()

    name = spec['name']
    try:
        model_cls = models[name]
    except KeyError:
        raise KeyError(
            f'unknown model {name!r}; registered: {list(models)}'
        ) from None
    model = model_cls(**args)
    # Log rather than print so library callers control the verbosity.
    logger.debug('args %s', args)

    if spec.get('load_ckpt') is not None:
        sd = load_sd_from_ckpt(spec['load_ckpt'], spec.get('load_ckpt_keys_only'))
        # strict=False tolerates missing/unexpected keys (expected when the
        # checkpoint is filtered via load_ckpt_keys_only); record what was
        # skipped so mismatches are not completely silent.
        result = model.load_state_dict(sd, strict=False)
        logger.debug('load_ckpt %s: %s', spec['load_ckpt'], result)
    if load_sd:
        model.load_state_dict(spec['sd'])
    return model
@register('identity')
def make_identity():
    """Build a ``torch.nn.Identity`` module; registered under ``'identity'``."""
    identity_module = torch.nn.Identity()
    return identity_module
|