File size: 1,019 Bytes
37163a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import warnings

import torch


# Registry mapping a model name to its constructor; populated by ``register``.
models = {}


def register(name):
    """Return a decorator that records its target in ``models`` under ``name``.

    The decorated object is handed back unchanged, so decoration is
    transparent. Registering a second object under the same ``name``
    replaces the earlier entry.
    """
    def _record(target):
        models[name] = target
        return target

    return _record


def load_sd_from_ckpt(ckpt, keys_only=None):
    """Load the state dict stored at ``['model']['sd']`` of a checkpoint.

    Args:
        ckpt: Anything ``torch.load`` accepts (typically a file path). The
            loaded object must be indexable as ``obj['model']['sd']``.
        keys_only: Optional key name, or iterable of key names, to keep.
            An entry ``'enc'`` keeps the exact key ``'enc'`` and every key
            starting with ``'enc.'`` (e.g. ``'enc.weight'``), but not
            ``'encoder.weight'``. ``None`` keeps every key; an empty
            iterable keeps none.

    Returns:
        The (possibly filtered) state dict, with tensors mapped to the CPU.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources (see torch.load's ``weights_only`` option).
    sd = torch.load(ckpt, map_location='cpu')['model']['sd']
    if keys_only is not None:
        if isinstance(keys_only, str):
            # A bare string would otherwise be iterated character by character.
            keys_only = [keys_only]
        # Materialize once so one-shot iterables (generators) feed both the
        # exact-match set and the prefix tuple.
        keys_only = set(keys_only)
        prefixes = tuple(key + '.' for key in keys_only)
        # Pop in place instead of rebuilding the dict so the original mapping
        # type and any attributes attached to it are preserved.
        for key in list(sd):
            if key not in keys_only and not key.startswith(prefixes):
                sd.pop(key)
    return sd


def make(spec, load_sd=False):
    """Instantiate a registered model from a spec mapping.

    Args:
        spec: Mapping with keys
            ``'name'`` - key into ``models`` (required);
            ``'args'`` - keyword arguments for the constructor (optional;
            missing or ``None`` means no arguments);
            ``'load_ckpt'`` - optional checkpoint passed to
            ``load_sd_from_ckpt`` and loaded with ``strict=False``;
            ``'load_ckpt_keys_only'`` - optional key filter for that checkpoint;
            ``'sd'`` - state dict loaded when ``load_sd`` is true.
        load_sd: If true, load ``spec['sd']`` with ``load_state_dict``'s
            default strictness, after any ``'load_ckpt'`` weights.

    Returns:
        The constructed model.

    Raises:
        KeyError: If ``spec['name']`` is missing or not registered, or if
            ``load_sd`` is true and ``spec['sd']`` is missing.
    """
    name = spec['name']
    try:
        model_cls = models[name]
    except KeyError:
        raise KeyError(
            f'unknown model {name!r}; registered: {list(models)}') from None

    args = spec.get('args')
    if args is None:
        args = dict()
    model = model_cls(**args)
    # Echo constructor arguments (kept for existing log output).
    print('args', args)

    ckpt = spec.get('load_ckpt')
    if ckpt is not None:
        sd = load_sd_from_ckpt(ckpt, spec.get('load_ckpt_keys_only'))
        # Non-strict on purpose: a filtered checkpoint may cover only part of
        # the model. But if not a single key was applied (empty filter result,
        # or no key matches the model) the load is a silent no-op, so warn.
        result = model.load_state_dict(sd, strict=False)
        unexpected = getattr(result, 'unexpected_keys', None)
        if unexpected is not None and len(unexpected) == len(sd):
            warnings.warn(
                f'no keys from checkpoint {ckpt!r} matched model {name!r}; '
                f'nothing was loaded',
                stacklevel=2,
            )

    if load_sd:
        model.load_state_dict(spec['sd'])

    return model


@register('identity')
def make_identity():
    """Build and return a ``torch.nn.Identity`` module (registered as 'identity')."""
    identity_module = torch.nn.Identity()
    return identity_module