| from models.segmentation_models.cen import ChannelExchangingNetwork | |
| from models.segmentation_models.deeplabv3p import DeepLabV3p_r101, DeepLabV3p_r18, DeepLabV3p_r50 | |
| from models.segmentation_models.linearfuse.segformer import WeTrLinearFusion | |
| from models.segmentation_models.linearfusebothmask.segformer import LinearFusionBothMask | |
| from models.segmentation_models.linearfusecons.segformer import LinearFusionConsistency | |
| from models.segmentation_models.linearfusemaemaskedcons.segformer import LinearFusionMAEMaskedConsistency | |
| from models.segmentation_models.linearfusemaskedcons.segformer import LinearFusionMaskedConsistency | |
| from models.segmentation_models.linearfusemaskedconsmixbatch.segformer import LinearFusionMaskedConsistencyMixBatch | |
| from models.segmentation_models.linearfusesepdecodermaskedcons.segformer import LinearFusionSepDecoderMaskedConsistency | |
| from models.segmentation_models.linearfusetokenmix.segformer import LinearFusionTokenMix | |
| from models.segmentation_models.randomexchangecons.segformer import RandomExchangePredConsistency | |
| from models.segmentation_models.randomfusion.segformer import WeTrRandomFusion | |
| from models.segmentation_models.randomfusiondmlp.segformer import WeTrRandomFusionDMLP | |
| from models.segmentation_models.refinenet import MyRefineNet | |
| from models.segmentation_models.segformer.segformer import SegFormer | |
| from models.segmentation_models.tokenfusion.segformer import WeTr | |
| from models.segmentation_models.tokenfusionbothmask.segformer import TokenFusionBothMask | |
| from models.segmentation_models.tokenfusionmaemaskedconsistency.segformer import TokenFusionMAEMaskedConsistency | |
| from models.segmentation_models.tokenfusionmaskedconsistency.segformer import TokenFusionMaskedConsistency | |
| from models.segmentation_models.tokenfusionmaskedconsistencymixbatch.segformer import TokenFusionMaskedConsistencyMixBatch | |
| from models.segmentation_models.unifiedrepresentation.segformer import UnifiedRepresentationNetwork | |
| from models.segmentation_models.unifiedrepresentationmoddrop.segformer import UnifiedRepresentationNetworkModDrop | |
# Backbone name -> depth passed as `num_layers` to the architectures that take it.
_RESNET_DEPTHS = {"r18": 18, "r50": 50, "r101": 101}


def _resnet_depth(args):
    """Return the `num_layers` value for ``args.base_model``.

    Raises:
        ValueError: if ``args.base_model`` is not one of ``_RESNET_DEPTHS``.
    """
    try:
        return _RESNET_DEPTHS[args.base_model]
    except KeyError:
        raise ValueError(f"{args.base_model} not configured") from None


def _build_deeplabv3p(args):
    """Build the DeepLabV3+ variant selected by ``args.base_model``."""
    if args.base_model == "r18":
        model_cls = DeepLabV3p_r18
    elif args.base_model == "r50":
        model_cls = DeepLabV3p_r50
    elif args.base_model == "r101":
        model_cls = DeepLabV3p_r101
    else:
        raise ValueError(f"{args.base_model} not configured")
    return model_cls(args.num_classes, args)


def _build_refinenet(args):
    """Build ``MyRefineNet`` with the depth selected by ``args.base_model``."""
    # Resolve the depth before touching the model class so an unsupported
    # backbone raises instead of silently falling through (previously -> None).
    depth = _resnet_depth(args)
    return MyRefineNet(num_layers=depth, num_classes=args.num_classes)


def _build_cen(args):
    """Build a two-branch ``ChannelExchangingNetwork`` for ``args.base_model``."""
    depth = _resnet_depth(args)
    return ChannelExchangingNetwork(
        num_layers=depth,
        num_classes=args.num_classes,
        num_parallel=2,
        l1_lambda=args.l1_lambda,
        bn_threshold=args.exchange_threshold,
    )


def _build_with_pretrained_flag(model_cls, args):
    """Build ``model_cls`` passing the optional ``args.pretrained_init`` flag.

    The flag defaults to ``True`` when ``args`` does not define it, and the
    chosen value is printed before construction.
    """
    pretrained = getattr(args, "pretrained_init", True)
    print("Using pretrained SegFormer? ", pretrained)
    return model_cls(args.base_model, args, num_classes=args.num_classes, pretrained=pretrained)


def get_model(args, **kwargs):
    """Construct the segmentation model named by ``args.seg_model``.

    Args:
        args: configuration object read via attribute access. Always uses
            ``seg_model``; most models also read ``base_model`` and
            ``num_classes``, and some read ``l1_lambda``, ``cons_lambda``,
            ``exchange_threshold``, ``exchange_percent`` or the optional
            ``pretrained_init``. The whole object is also forwarded to most
            model constructors.
        **kwargs: extra keyword arguments, forwarded only to
            ``TokenFusionBothMask`` (``seg_model == "tokenfusionbothmask"``).

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``args.seg_model`` is unknown, or if ``args.base_model``
            is unsupported for ``dlv3p`` / ``refinenet`` / ``cen``.
    """
    # Lambdas keep construction (and attribute access on `args`) lazy, so only
    # the selected model is built and only its required fields are read.
    builders = {
        "dlv3p": lambda: _build_deeplabv3p(args),
        "refinenet": lambda: _build_refinenet(args),
        "cen": lambda: _build_cen(args),
        "segformer": lambda: SegFormer(args.base_model, args, num_classes=args.num_classes),
        "tokenfusion": lambda: WeTr(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes
        ),
        "randomfusion": lambda: WeTrRandomFusion(args.base_model, args, num_classes=args.num_classes),
        "randomfusiondmlp": lambda: WeTrRandomFusionDMLP(
            args.base_model, args, num_classes=args.num_classes
        ),
        "randomexchangepredconsistency": lambda: RandomExchangePredConsistency(
            args.base_model, args, cons_lambda=args.cons_lambda, num_classes=args.num_classes
        ),
        "linearfusion": lambda: _build_with_pretrained_flag(WeTrLinearFusion, args),
        "linearfusionconsistency": lambda: LinearFusionConsistency(
            args.base_model, args, cons_lambda=args.cons_lambda, num_classes=args.num_classes
        ),
        "linearfusionmaskedcons": lambda: _build_with_pretrained_flag(
            LinearFusionMaskedConsistency, args
        ),
        "linearfusionmaskedconsmixbatch": lambda: LinearFusionMaskedConsistencyMixBatch(
            args.base_model, args, num_classes=args.num_classes
        ),
        "linearfusionsepdecodermaskedcons": lambda: LinearFusionSepDecoderMaskedConsistency(
            args.base_model, args, num_classes=args.num_classes
        ),
        "linearfusionmaemaskedcons": lambda: LinearFusionMAEMaskedConsistency(
            args.base_model, args, num_classes=args.num_classes
        ),
        "tokenfusionmaskedcons": lambda: TokenFusionMaskedConsistency(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes
        ),
        "tokenfusionmaskedconsmixbatch": lambda: TokenFusionMaskedConsistencyMixBatch(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes
        ),
        "tokenfusionbothmask": lambda: TokenFusionBothMask(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes, **kwargs
        ),
        "linearfusebothmask": lambda: LinearFusionBothMask(
            args.base_model, args, num_classes=args.num_classes
        ),
        "linearfusiontokenmix": lambda: LinearFusionTokenMix(
            args.base_model, args, num_classes=args.num_classes, exchange_percent=args.exchange_percent
        ),
        "tokenfusionmaemaskedcons": lambda: TokenFusionMAEMaskedConsistency(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes
        ),
        "unifiedrepresentationnetwork": lambda: UnifiedRepresentationNetwork(
            args.base_model, args, num_classes=args.num_classes
        ),
        "unifiedrepresentationnetworkmoddrop": lambda: UnifiedRepresentationNetworkModDrop(
            args.base_model, args, num_classes=args.num_classes
        ),
    }
    builder = builders.get(args.seg_model)
    if builder is None:
        raise ValueError(f"{args.seg_model} not configured")
    return builder()