| import torch |
| import torch.nn as nn |
|
|
| from .modules.HF_Decoder import HiF_Decoder |
| from .modules.MiT_Encoder import MixVisionTransformer |
| from .modules.Seg_Head import SegmentationHead |
|
|
| class DepthPolyp(nn.Module): |
| def __init__( |
| self, |
| in_channels: int = 3, |
| num_classes: int = 2, |
| encoder_name: str = 'b0', |
| decoder_channels: int = 256, |
| activation: str = None, |
| upsampling: int = 4, |
| ): |
| super().__init__() |
|
|
| |
| encoder_configs = { |
| 'b0': { |
| 'embed_dims': [32, 64, 160, 256], |
| 'num_heads': [1, 2, 5, 8], |
| 'mlp_ratios': [4, 4, 4, 4], |
| 'depths': [2, 2, 2, 2], |
| 'sr_ratios': [8, 4, 2, 1], |
| }, |
| 'b1': { |
| 'embed_dims': [64, 128, 320, 512], |
| 'num_heads': [1, 2, 5, 8], |
| 'mlp_ratios': [4, 4, 4, 4], |
| 'depths': [2, 2, 2, 2], |
| 'sr_ratios': [8, 4, 2, 1], |
| }, |
| 'b2': { |
| 'embed_dims': [64, 128, 320, 512], |
| 'num_heads': [1, 2, 5, 8], |
| 'mlp_ratios': [4, 4, 4, 4], |
| 'depths': [3, 4, 6, 3], |
| 'sr_ratios': [8, 4, 2, 1], |
| }, |
| 'b3': { |
| 'embed_dims': [64, 128, 320, 512], |
| 'num_heads': [1, 2, 5, 8], |
| 'mlp_ratios': [4, 4, 4, 4], |
| 'depths': [3, 4, 18, 3], |
| 'sr_ratios': [8, 4, 2, 1], |
| }, |
| 'b4': { |
| 'embed_dims': [64, 128, 320, 512], |
| 'num_heads': [1, 2, 5, 8], |
| 'mlp_ratios': [4, 4, 4, 4], |
| 'depths': [3, 8, 27, 3], |
| 'sr_ratios': [8, 4, 2, 1], |
| }, |
| 'b5': { |
| 'embed_dims': [64, 128, 320, 512], |
| 'num_heads': [1, 2, 5, 8], |
| 'mlp_ratios': [4, 4, 4, 4], |
| 'depths': [3, 6, 40, 3], |
| 'sr_ratios': [8, 4, 2, 1], |
| }, |
| } |
|
|
| if encoder_name not in encoder_configs: |
| raise ValueError(f"encoder_name should be one of {list(encoder_configs.keys())}, got {encoder_name}") |
|
|
| config = encoder_configs[encoder_name] |
|
|
| |
| self.encoder = MixVisionTransformer( |
| in_chans=in_channels, |
| embed_dims=config['embed_dims'], |
| num_heads=config['num_heads'], |
| mlp_ratios=config['mlp_ratios'], |
| qkv_bias=True, |
| depths=config['depths'], |
| sr_ratios=config['sr_ratios'], |
| drop_rate=0.0, |
| drop_path_rate=0.1, |
| ) |
|
|
|
|
| self.decoder = HiF_Decoder( |
| encoder_channels=config['embed_dims'], |
| decoder_channels=decoder_channels, |
| ) |
|
|
| |
| self.segmentation_head = SegmentationHead( |
| in_channels=decoder_channels//4, |
| out_channels=num_classes, |
| activation=activation, |
| kernel_size=1, |
| upsampling=upsampling, |
| ) |
|
|
| self.name = f"DepthPolyp-{encoder_name}" |
|
|
| def forward(self, x): |
| """Forward pass |
| |
| Args: |
| x: Input tensor of shape (B, C, H, W) |
| |
| Returns: |
| Output tensor of shape (B, num_classes, H, W) |
| """ |
| |
| encoder_features = self.encoder(x) |
| |
| |
|
|
| fpn_features = self.decoder(encoder_features) |
| decoder_output = fpn_features |
| |
|
|
| |
| masks = self.segmentation_head(decoder_output) |
| pred_seg = torch.sigmoid(masks[:, 0:1, :, :]) |
| pred_depth = torch.sigmoid(masks[:, 1:2, :, :]) |
|
|
| return pred_seg, pred_depth |
|
|
| @torch.no_grad() |
| def predict(self, x): |
| """Inference method""" |
| if self.training: |
| self.eval() |
| return self(x) |
|
|
| def load_pretrained(self, checkpoint_path, strict=True): |
| """Load pretrained weights |
| |
| Args: |
| checkpoint_path: Path to checkpoint file |
| strict: Whether to strictly enforce key matching |
| """ |
| state_dict = torch.load(checkpoint_path, map_location='cpu') |
| |
| |
| if 'state_dict' in state_dict: |
| state_dict = state_dict['state_dict'] |
| elif 'model' in state_dict: |
| state_dict = state_dict['model'] |
| |
| |
| new_state_dict = {} |
| for k, v in state_dict.items(): |
| if k.startswith('module.'): |
| new_state_dict[k[7:]] = v |
| else: |
| new_state_dict[k] = v |
| |
| self.load_state_dict(new_state_dict, strict=strict) |
| print(f"✓ Loaded pretrained weights from {checkpoint_path}") |
|
|
|
|
| def build_depthpolyp( |
| encoder_name='b0', |
| in_channels=3, |
| num_classes=2, |
| decoder_channels=256, |
| activation=None, |
| ): |
| """ |
| Create a DepthPolyp model |
| |
| Args: |
| encoder_name: Encoder variant ('b0', 'b1', 'b2', 'b3', 'b4', 'b5') |
| in_channels: Number of input channels |
| num_classes: Number of output classes |
| decoder_channels: Number of channels in decoder |
| activation: Output activation ('sigmoid', 'softmax', or None) |
| |
| Returns: |
| DepthPolyp model |
| |
| Example: |
| >>> model = build_depthpolyp('b2', num_classes=21, activation='softmax') |
| >>> print(model) |
| """ |
|
|
| model = DepthPolyp( |
| in_channels=in_channels, |
| num_classes=num_classes, |
| encoder_name=encoder_name, |
| decoder_channels=decoder_channels, |
| activation=activation, |
| ) |
| return model |
|
|
| if __name__ == '__main__': |
| print("="*60) |
| print("Loading Model .....") |
| model = build_depthpolyp( |
| encoder_name='b0', |
| in_channels=3, |
| num_classes=2, |
| decoder_channels=256, |
| activation='sigmoid', |
| ) |
| print("="*60) |
| print("Validating Model .....") |
| print("Check the Param and Complexity(GMACs)") |
| import ptflops |
| macs, params = ptflops.get_model_complexity_info( |
| model, (3, 224, 224), as_strings=True, |
| print_per_layer_stat=False, verbose=False |
| ) |
| print(f" MACs: {macs}, Params: {params}") |
| |
| print("="*60) |
| print("Check the output .....") |
| dummy_input = torch.randn(1, 3, 224, 224) |
| output_seg, output_depth = model(dummy_input) |
| print("input_shape is:", dummy_input.shape) |
| print("output_seg shape is:", output_seg.shape) |
| print("output_depth shape is:", output_depth.shape) |
|
|