File size: 1,302 Bytes
04c78c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch.nn as nn

from pathlib import Path

from ..source import ResnetEncoder, MultiHeadDecoder, DenseMTL, DenseReg, Vanilla


def replace_batchnorm_(module: nn.Module):
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.InstanceNorm2d(child.num_features))
        else:
            replace_batchnorm_(child)

def get_model(archi):
    assert archi == 'densemtl'

    encoder = ResnetEncoder(num_layers=101, pretrained=True, in_channels=3)
    decoder = MultiHeadDecoder(
        num_ch_enc=encoder.num_ch_enc,
        tasks=dict(albedo=3, roughness=1, normals=2),
        return_feats=False,
        use_skips=True)

    model = nn.Sequential(encoder, decoder)
    replace_batchnorm_(model)
    return model

def get_module(args):
    loss = DenseReg(**args.loss)
    model = get_model(args.archi)

    weights = args.load_weights_from
    if weights:
        assert weights.is_file()
        return Vanilla.load_from_checkpoint(str(weights), model=model, loss=loss, strict=False, **args.routine)

    return Vanilla(model, loss, **args.routine)

def get_inference_module(pt):
    assert Path(pt).exists()
    model = get_model('densemtl')
    return Vanilla.load_from_checkpoint(str(pt), model=model, strict=False)