| import torch.nn as nn | |
| from pathlib import Path | |
| from ..source import ResnetEncoder, MultiHeadDecoder, DenseMTL, DenseReg, Vanilla | |
def replace_batchnorm_(module: nn.Module):
    """Swap every ``nn.BatchNorm2d`` beneath ``module`` for ``nn.InstanceNorm2d``, in place.

    The tree of sub-modules is walked recursively. Each replacement layer is
    built from the batch-norm layer's ``num_features`` only, so it is a fresh
    ``nn.InstanceNorm2d`` with that class's default settings; nothing else is
    carried over from the layer it replaces.

    Args:
        module: Root of the module tree to rewrite. ``module`` itself is never
            replaced, only its descendants.

    Returns:
        None. The trailing underscore signals in-place mutation.
    """
    for attr, sub in module.named_children():
        if not isinstance(sub, nn.BatchNorm2d):
            # Not a batch-norm layer: descend and look for nested ones.
            replace_batchnorm_(sub)
            continue
        # Re-register under the same attribute name so the parent keeps its layout.
        setattr(module, attr, nn.InstanceNorm2d(sub.num_features))
def get_model(archi):
    """Build the DenseMTL network: a ``ResnetEncoder`` feeding a ``MultiHeadDecoder``.

    Args:
        archi: Architecture identifier. Only ``'densemtl'`` is supported.

    Returns:
        An ``nn.Sequential(encoder, decoder)`` in which every
        ``nn.BatchNorm2d`` has been replaced by ``nn.InstanceNorm2d``.

    Raises:
        ValueError: If ``archi`` is anything other than ``'densemtl'``.
    """
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would let an unsupported name silently build DenseMTL.
    if archi != 'densemtl':
        raise ValueError(
            f"unsupported architecture {archi!r}; only 'densemtl' is available"
        )

    encoder = ResnetEncoder(num_layers=101, pretrained=True, in_channels=3)
    decoder = MultiHeadDecoder(
        num_ch_enc=encoder.num_ch_enc,
        # One entry per task; the ints are presumably output channels per
        # head -- TODO confirm against MultiHeadDecoder.
        tasks=dict(albedo=3, roughness=1, normals=2),
        return_feats=False,
        use_skips=True,
    )

    model = nn.Sequential(encoder, decoder)
    # Applied after construction, so it also rewrites the norm layers of the
    # encoder that was created with pretrained=True.
    replace_batchnorm_(model)
    return model
def get_module(args):
    """Create the ``Vanilla`` module, optionally restored from a checkpoint.

    Args:
        args: Configuration object exposing
            ``loss`` (keyword arguments for ``DenseReg``),
            ``archi`` (architecture name forwarded to ``get_model``),
            ``routine`` (extra keyword arguments for ``Vanilla``), and
            ``load_weights_from`` (checkpoint path as ``str`` or ``Path``;
            any falsy value means "start without a checkpoint").

    Returns:
        A ``Vanilla`` instance, either freshly constructed or loaded through
        ``Vanilla.load_from_checkpoint`` with ``strict=False``.

    Raises:
        FileNotFoundError: If ``load_weights_from`` is set but is not an
            existing file.
    """
    weights = args.load_weights_from
    # Truthiness is tested before coercion: Path('') would become Path('.'),
    # which is truthy. Coercing also lets plain string paths work.
    weights_path = Path(weights) if weights else None

    # Validate the checkpoint before building anything, and with a real
    # exception rather than an `assert` that `python -O` would strip.
    if weights_path is not None and not weights_path.is_file():
        raise FileNotFoundError(f"checkpoint file not found: {weights_path}")

    loss = DenseReg(**args.loss)
    model = get_model(args.archi)

    if weights_path is None:
        return Vanilla(model, loss, **args.routine)

    return Vanilla.load_from_checkpoint(
        str(weights_path), model=model, loss=loss, strict=False, **args.routine
    )
def get_inference_module(pt):
    """Load a ``Vanilla`` module from a checkpoint for inference.

    The network is always the ``'densemtl'`` architecture and no loss is
    passed to the module.

    Args:
        pt: Path to the checkpoint, as ``str`` or ``Path``.

    Returns:
        The ``Vanilla`` instance returned by ``Vanilla.load_from_checkpoint``
        (called with ``strict=False``).

    Raises:
        FileNotFoundError: If ``pt`` does not exist.
    """
    # Explicit raise instead of `assert`, so the check is not stripped under
    # `python -O`; it runs first so no model is built for a bad path.
    if not Path(pt).exists():
        raise FileNotFoundError(f"checkpoint not found: {pt}")

    model = get_model('densemtl')
    return Vanilla.load_from_checkpoint(str(pt), model=model, strict=False)