import torch
import torch.nn as nn

from .tiny_block import TinyBlock
from transformers import MambaConfig, MambaModel
# from .conmamba import ConMamba


class CSPTinyLayer(nn.Module):
    def __init__(self, in_channels, out_channels, num_blocks, ssm=False):
        super().__init__()
        self.ssm = ssm

        # Split channels: p1 is passed through untouched, p2 is transformed
        # (CSP-style cross-stage partial connection). Assumes in_channels is even.
        self.split_channels = in_channels // 2

        if self.ssm:
            # Mamba blocks. vocab_size=0 is fine here because forward() only
            # ever feeds inputs_embeds, so the token embedding table is unused.
            configuration = MambaConfig(
                vocab_size=0,
                hidden_size=self.split_channels,
                num_hidden_layers=num_blocks,
            )
            self.mamba_blocks = MambaModel(configuration)
            # mamba_config = {
            #     'd_state': self.split_channels,
            #     'expand': 2,
            #     'd_conv': 4,
            #     'bidirectional': True,
            # }
            # self.mamba_blocks = ConMamba(
            #     num_blocks=num_blocks,
            #     channels=self.split_channels,
            #     height=8,
            #     width=8,
            #     mamba_config=mamba_config,
            # )
        else:
            # TinyBlocks: a stack of num_blocks channel-preserving conv blocks
            self.tiny_blocks = nn.Sequential(
                *[TinyBlock(self.split_channels, self.split_channels) for _ in range(num_blocks)]
            )

        # Transition layer: 1x1 conv to adjust channel dimensions after re-concatenation
        self.transition = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Split input channel-wise into two halves
        p1 = x[:, :self.split_channels, :, :]
        p2 = x[:, self.split_channels:, :, :]

        if self.ssm:
            # Flatten the spatial grid into a sequence so Mamba can scan it
            B, C, H, W = p2.shape
            p2 = p2.permute(0, 2, 3, 1)   # [B, H, W, C]
            p2 = p2.reshape(B, H * W, C)  # [B, L, C], L = H * W

            # Process p2 through the Mamba blocks (fed as embeddings, not token ids)
            p2_out = self.mamba_blocks(inputs_embeds=p2).last_hidden_state
            # p2_out = self.mamba_blocks(p2)  # ConMamba variant

            # Reshape back to a feature map
            p2_out = p2_out.reshape(B, H, W, -1)
            p2_out = p2_out.permute(0, 3, 1, 2)  # [B, C, H, W]
        else:
            # Process p2 through the TinyBlocks
            p2_out = self.tiny_blocks(p2)

        # Concatenate p1 and the processed p2
        concatenated = torch.cat((p1, p2_out), dim=1)

        # Apply the transition layer
        out = self.transition(concatenated)
        return out
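

# For reference: TinyBlock is imported from tiny_block.py and is not defined in
# this file. It is used above as a channel-preserving block with the signature
# TinyBlock(in_channels, out_channels). A minimal sketch of what such a block
# typically looks like (hypothetical, for illustration only; the real
# definition may differ):
#
#   class TinyBlock(nn.Module):
#       def __init__(self, in_channels, out_channels):
#           super().__init__()
#           self.conv = nn.Sequential(
#               nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
#               nn.BatchNorm2d(out_channels),
#               nn.ReLU(inplace=True),
#           )
#
#       def forward(self, x):
#           return self.conv(x)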


if __name__ == "__main__":
    # Note: the relative imports above mean this file must be run as a module
    # (e.g. `python -m <package>.<this_module>`), not as a plain script.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = CSPTinyLayer(32, 32, 2, True).to(device)
    print(model)

    dummy_input = torch.randn(256, 32, 8, 8).to(device)
    output = model(dummy_input)
    print(output.shape)  # expected: torch.Size([256, 32, 8, 8])
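
    # Sanity-check the convolutional (ssm=False) path as well; both branches
    # should produce the same output shape.
    model_conv = CSPTinyLayer(32, 32, 2, False).to(device)
    output_conv = model_conv(dummy_input)
    print(output_conv.shape)  # expected: torch.Size([256, 32, 8, 8])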