import torch
from torch import nn

from libs.utils.comm import get_world_size
from .model import Model


def build_model(cfg):
    """Construct the ``Model`` described by *cfg*.

    Args:
        cfg: Configuration object, forwarded unchanged to ``Model`` as its
            first positional argument.

    Returns:
        The newly constructed ``Model`` instance, created with
        ``nn.BatchNorm2d`` as its ``norm_layer``.
    """
    # The original code branched on ``get_world_size() == 1`` but picked
    # ``nn.BatchNorm2d`` in both branches, so the branch had no effect on the
    # result and has been collapsed into a single assignment.
    # NOTE(review): if ``nn.SyncBatchNorm`` was intended for multi-process
    # runs (world size > 1), reinstate the branch here -- confirm with the
    # training setup before changing, since it alters training behaviour.
    norm_layer = nn.BatchNorm2d

    return Model(cfg, norm_layer=norm_layer)