DepthPolyp / model /depthpolyp.py
ReaganWZY's picture
Upload DepthPolyp model artifacts
5acc7ae verified
import torch
import torch.nn as nn
from .modules.HF_Decoder import HiF_Decoder
from .modules.MiT_Encoder import MixVisionTransformer
from .modules.Seg_Head import SegmentationHead
class DepthPolyp(nn.Module):
def __init__(
self,
in_channels: int = 3,
num_classes: int = 2, # 1 for seg, 1 for depth
encoder_name: str = 'b0',
decoder_channels: int = 256,
activation: str = None,
upsampling: int = 4,
):
super().__init__()
# Encoder configurations
encoder_configs = {
'b0': {
'embed_dims': [32, 64, 160, 256],
'num_heads': [1, 2, 5, 8],
'mlp_ratios': [4, 4, 4, 4],
'depths': [2, 2, 2, 2],
'sr_ratios': [8, 4, 2, 1],
},
'b1': {
'embed_dims': [64, 128, 320, 512],
'num_heads': [1, 2, 5, 8],
'mlp_ratios': [4, 4, 4, 4],
'depths': [2, 2, 2, 2],
'sr_ratios': [8, 4, 2, 1],
},
'b2': {
'embed_dims': [64, 128, 320, 512],
'num_heads': [1, 2, 5, 8],
'mlp_ratios': [4, 4, 4, 4],
'depths': [3, 4, 6, 3],
'sr_ratios': [8, 4, 2, 1],
},
'b3': {
'embed_dims': [64, 128, 320, 512],
'num_heads': [1, 2, 5, 8],
'mlp_ratios': [4, 4, 4, 4],
'depths': [3, 4, 18, 3],
'sr_ratios': [8, 4, 2, 1],
},
'b4': {
'embed_dims': [64, 128, 320, 512],
'num_heads': [1, 2, 5, 8],
'mlp_ratios': [4, 4, 4, 4],
'depths': [3, 8, 27, 3],
'sr_ratios': [8, 4, 2, 1],
},
'b5': {
'embed_dims': [64, 128, 320, 512],
'num_heads': [1, 2, 5, 8],
'mlp_ratios': [4, 4, 4, 4],
'depths': [3, 6, 40, 3],
'sr_ratios': [8, 4, 2, 1],
},
}
if encoder_name not in encoder_configs:
raise ValueError(f"encoder_name should be one of {list(encoder_configs.keys())}, got {encoder_name}")
config = encoder_configs[encoder_name]
# Build encoder
self.encoder = MixVisionTransformer(
in_chans=in_channels,
embed_dims=config['embed_dims'],
num_heads=config['num_heads'],
mlp_ratios=config['mlp_ratios'],
qkv_bias=True,
depths=config['depths'],
sr_ratios=config['sr_ratios'],
drop_rate=0.0,
drop_path_rate=0.1,
)
self.decoder = HiF_Decoder(
encoder_channels=config['embed_dims'],
decoder_channels=decoder_channels,
)
# Build segmentation head (nn.Sequential style)
self.segmentation_head = SegmentationHead(
in_channels=decoder_channels//4,
out_channels=num_classes,
activation=activation,
kernel_size=1,
upsampling=upsampling,
)
self.name = f"DepthPolyp-{encoder_name}"
def forward(self, x):
"""Forward pass
Args:
x: Input tensor of shape (B, C, H, W)
Returns:
Output tensor of shape (B, num_classes, H, W)
"""
# Encoder - returns features at [H/4, H/8, H/16, H/32]
encoder_features = self.encoder(x)
# Decoder - returns features at H/4
fpn_features = self.decoder(encoder_features)
decoder_output = fpn_features
# print(f"Decoder output shape: {decoder_output.shape}")
# Segmentation head - upsample to original size
masks = self.segmentation_head(decoder_output)
pred_seg = torch.sigmoid(masks[:, 0:1, :, :]) # segmentation 通道
pred_depth = torch.sigmoid(masks[:, 1:2, :, :]) # depth 通道,通常是回归,不做激活
return pred_seg, pred_depth
@torch.no_grad()
def predict(self, x):
"""Inference method"""
if self.training:
self.eval()
return self(x)
def load_pretrained(self, checkpoint_path, strict=True):
"""Load pretrained weights
Args:
checkpoint_path: Path to checkpoint file
strict: Whether to strictly enforce key matching
"""
state_dict = torch.load(checkpoint_path, map_location='cpu')
# Handle different checkpoint formats
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
elif 'model' in state_dict:
state_dict = state_dict['model']
# Remove module. prefix if present (from DataParallel)
new_state_dict = {}
for k, v in state_dict.items():
if k.startswith('module.'):
new_state_dict[k[7:]] = v
else:
new_state_dict[k] = v
self.load_state_dict(new_state_dict, strict=strict)
print(f"✓ Loaded pretrained weights from {checkpoint_path}")
def build_depthpolyp(
encoder_name='b0',
in_channels=3,
num_classes=2,
decoder_channels=256,
activation=None,
):
"""
Create a DepthPolyp model
Args:
encoder_name: Encoder variant ('b0', 'b1', 'b2', 'b3', 'b4', 'b5')
in_channels: Number of input channels
num_classes: Number of output classes
decoder_channels: Number of channels in decoder
activation: Output activation ('sigmoid', 'softmax', or None)
Returns:
DepthPolyp model
Example:
>>> model = build_depthpolyp('b2', num_classes=21, activation='softmax')
>>> print(model)
"""
model = DepthPolyp(
in_channels=in_channels,
num_classes=num_classes,
encoder_name=encoder_name,
decoder_channels=decoder_channels,
activation=activation,
)
return model
if __name__ == '__main__':
print("="*60)
print("Loading Model .....")
model = build_depthpolyp(
encoder_name='b0',
in_channels=3, # Input channels
num_classes=2, # Total 2. 1 for seg, 1 for depth
decoder_channels=256,
activation='sigmoid',
)
print("="*60)
print("Validating Model .....")
print("Check the Param and Complexity(GMACs)")
import ptflops
macs, params = ptflops.get_model_complexity_info(
model, (3, 224, 224), as_strings=True,
print_per_layer_stat=False, verbose=False
)
print(f" MACs: {macs}, Params: {params}")
# output is MACs: 862.17 MMac, Params: 3.57 M
print("="*60)
print("Check the output .....")
dummy_input = torch.randn(1, 3, 224, 224) # B, C, H, W, single RGB image
output_seg, output_depth = model(dummy_input)
print("input_shape is:", dummy_input.shape)
print("output_seg shape is:", output_seg.shape)
print("output_depth shape is:", output_depth.shape)