# vrevar
# Add application file
# 04c78c7
import torch.nn as nn
from pathlib import Path
from ..source import ResnetEncoder, MultiHeadDecoder, DenseMTL, DenseReg, Vanilla
def replace_batchnorm_(module: nn.Module):
    """Swap every ``nn.BatchNorm2d`` beneath *module* for ``nn.InstanceNorm2d``.

    The module tree is walked recursively and modified in place (hence the
    trailing underscore in the name); nothing is returned.

    Each replacement layer is created from the batch-norm layer's
    ``num_features`` only, so it uses the ``nn.InstanceNorm2d`` defaults for
    every other setting and does not inherit any state from the layer it
    replaces.

    Args:
        module: Root of the module tree to rewrite. If *module* itself is a
            ``BatchNorm2d`` it is left as is; only its descendants are
            inspected.
    """
    for child_name, submodule in list(module.named_children()):
        if not isinstance(submodule, nn.BatchNorm2d):
            # Not a batch-norm layer: look for batch-norm layers further down.
            replace_batchnorm_(submodule)
            continue
        instance_norm = nn.InstanceNorm2d(submodule.num_features)
        # Re-registering under the same name replaces the child in place.
        setattr(module, child_name, instance_norm)
def get_model(archi):
    """Build the encoder/decoder network for the requested architecture.

    Only ``'densemtl'`` is supported. The network is a ``ResnetEncoder``
    (``num_layers=101``, ``pretrained=True``, 3 input channels) followed by a
    ``MultiHeadDecoder`` configured with the tasks ``albedo=3``,
    ``roughness=1`` and ``normals=2`` (presumably the number of output
    channels per task -- confirm against ``MultiHeadDecoder``).

    Args:
        archi: Architecture name; must equal ``'densemtl'``.

    Returns:
        ``nn.Sequential(encoder, decoder)`` after ``replace_batchnorm_`` has
        swapped its ``nn.BatchNorm2d`` layers for ``nn.InstanceNorm2d``.

    Raises:
        ValueError: If *archi* is anything other than ``'densemtl'``.
    """
    # Explicit raise instead of ``assert``: assertions are stripped under
    # ``python -O``, which would let any name silently build this network.
    if archi != 'densemtl':
        raise ValueError(
            f"unsupported architecture {archi!r}; expected 'densemtl'")

    encoder = ResnetEncoder(num_layers=101, pretrained=True, in_channels=3)
    decoder = MultiHeadDecoder(
        num_ch_enc=encoder.num_ch_enc,  # decoder is sized from the encoder
        tasks=dict(albedo=3, roughness=1, normals=2),
        return_feats=False,
        use_skips=True)
    model = nn.Sequential(encoder, decoder)
    replace_batchnorm_(model)
    return model
def get_module(args):
    """Create the ``Vanilla`` training module described by *args*.

    Args:
        args: Configuration object providing the attributes
            ``load_weights_from`` (checkpoint path as ``str`` or ``Path``,
            or a falsy value to start without a checkpoint), ``loss``
            (keyword arguments for ``DenseReg``), ``archi`` (passed to
            ``get_model``) and ``routine`` (extra keyword arguments for
            ``Vanilla``).

    Returns:
        A new ``Vanilla(model, loss, **args.routine)`` when no checkpoint is
        configured; otherwise the result of ``Vanilla.load_from_checkpoint``
        called with ``strict=False`` on the configured checkpoint.

    Raises:
        FileNotFoundError: If ``args.load_weights_from`` is set but does not
            point to an existing file.
    """
    weights = args.load_weights_from
    checkpoint = None
    if weights:
        # Accept plain strings as well as Path objects, and validate before
        # the loss and model are constructed so a bad path fails early.
        checkpoint = Path(weights)
        if not checkpoint.is_file():
            raise FileNotFoundError(
                f"checkpoint file not found: {checkpoint}")

    loss = DenseReg(**args.loss)
    model = get_model(args.archi)

    if checkpoint is None:
        return Vanilla(model, loss, **args.routine)
    return Vanilla.load_from_checkpoint(
        str(checkpoint), model=model, loss=loss, strict=False, **args.routine)
def get_inference_module(pt):
    """Load a ``Vanilla`` module from a checkpoint for inference.

    A fresh ``'densemtl'`` model is built with ``get_model`` and handed to
    ``Vanilla.load_from_checkpoint`` with ``strict=False``; no loss is
    passed.

    Args:
        pt: Path to the checkpoint, as ``str`` or ``Path``.

    Returns:
        The object returned by ``Vanilla.load_from_checkpoint``.

    Raises:
        FileNotFoundError: If *pt* does not exist.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not Path(pt).exists():
        raise FileNotFoundError(f"checkpoint not found: {pt}")
    model = get_model('densemtl')
    return Vanilla.load_from_checkpoint(str(pt), model=model, strict=False)