training_sem/libs/model/__init__.py
kai-2054's picture
Initial commit: add code
cb0ad2d
raw
history blame
323 Bytes
from .model import Model
import torch
from torch import nn
from libs.utils.comm import get_world_size
def build_model(cfg):
    """Construct the ``Model`` for the given configuration.

    Args:
        cfg: Configuration object, forwarded unchanged to ``Model``.

    Returns:
        A ``Model`` instance created with ``nn.BatchNorm2d`` as its
        ``norm_layer``.
    """
    # This function used to branch on ``get_world_size() == 1`` but selected
    # ``nn.BatchNorm2d`` in both branches, so the check could not change the
    # outcome; it is collapsed into a single assignment here.
    # NOTE(review): if the multi-process case was meant to use a synchronized
    # norm (e.g. ``nn.SyncBatchNorm``), reinstate the branch -- confirm the
    # intent with the author before changing the norm layer.
    norm_layer = nn.BatchNorm2d
    return Model(cfg, norm_layer=norm_layer)