File size: 567 Bytes
ea7bb95 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | import segmentation_models_pytorch as smp
def build_model(config):
    """Build a U-Net with an encoder initialized from ``config.encoder_weights``.

    The network is set up for multi-label segmentation. Each of the
    ``config.num_classes`` output channels predicts one mask type
    independently, so each channel is meant to go through its own sigmoid
    rather than a softmax across channels. The model applies no activation
    itself and returns raw logits. Apply the sigmoid in the loss (e.g. a
    "with-logits" loss) or at inference time.

    Args:
        config: Configuration object. It must provide ``encoder_name``,
            ``encoder_weights``, ``num_classes`` and ``decoder_attention``.
            An optional ``in_channels`` attribute sets the number of input
            channels. It defaults to 3 when absent.

    Returns:
        An ``smp.Unet`` instance whose output has ``config.num_classes``
        channels of raw (un-activated) logits.
    """
    model = smp.Unet(
        encoder_name=config.encoder_name,
        encoder_weights=config.encoder_weights,
        # Previously hard-coded to 3; kept as the fallback so existing
        # configs without ``in_channels`` build the same model.
        in_channels=getattr(config, "in_channels", 3),
        classes=config.num_classes,
        # No final activation: the model emits logits and the sigmoid is
        # applied downstream (loss / inference).
        activation=None,
        decoder_attention_type=config.decoder_attention,
    )
    return model
|