import math
import re

import torch
import torch.nn as nn

# Module-level alias: rebind this (e.g. to nn.Identity) to swap the
# normalization used by the *WOnorm variants below.
BatchNorm2d = nn.BatchNorm2d


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3,
                     stride=stride, padding=1, bias=False)


class BasicBlockLayerNorm(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, norm_shape, stride=1, downsample=None, dcn=None):
        super(BasicBlockLayerNorm, self).__init__()
        self.with_dcn = dcn is not None
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.LayerNorm(norm_shape)
        self.relu = nn.ReLU(inplace=True)
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = dcn.get('modulated', False)
        self.conv2 = conv3x3(planes, planes)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        else:
            # Deformable conv (and the conv2_offset branch used in forward)
            # is not implemented in this file.
            raise NotImplementedError
        self.bn2 = nn.LayerNorm(norm_shape)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        if not self.with_dcn:
            out = self.conv2(out)
        elif self.with_modulated_dcn:
            offset_mask = self.conv2_offset(out)
            offset = offset_mask[:, :18, :, :]
            mask = offset_mask[:, -9:, :, :].sigmoid()
            out = self.conv2(out, offset, mask)
        else:
            offset = self.conv2_offset(out)
            out = self.conv2(out, offset)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
        super(BasicBlock, self).__init__()
        self.with_dcn = dcn is not None
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = dcn.get('modulated', False)
        self.conv2 = conv3x3(planes, planes)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        else:
            raise NotImplementedError
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        if not self.with_dcn:
            out = self.conv2(out)
        elif self.with_modulated_dcn:
            offset_mask = self.conv2_offset(out)
            offset = offset_mask[:, :18, :, :]
            mask = offset_mask[:, -9:, :, :].sigmoid()
            out = self.conv2(out, offset, mask)
        else:
            offset = self.conv2_offset(out)
            out = self.conv2(out, offset)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out
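
# A minimal shape check for BasicBlock (illustrative only, not called anywhere;
# the 256-channel, 64x64 sizes are assumptions matching the __main__ demo below).
def _basic_block_example():
    block = BasicBlock(inplanes=256, planes=256)
    x = torch.randn(2, 256, 64, 64)
    out = block(x)  # conv-bn-relu-conv-bn plus the identity residual
    assert out.shape == (2, 256, 64, 64)
    return out
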
class BasicBlockWOnorm(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
        super(BasicBlockWOnorm, self).__init__()
        self.with_dcn = dcn is not None
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = dcn.get('modulated', False)
        self.conv2 = conv3x3(planes, planes)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        else:
            raise NotImplementedError
        self.bn2 = BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        if not self.with_dcn:
            out = self.conv2(out)
        elif self.with_modulated_dcn:
            offset_mask = self.conv2_offset(out)
            offset = offset_mask[:, :18, :, :]
            mask = offset_mask[:, -9:, :, :].sigmoid()
            out = self.conv2(out, offset, mask)
        else:
            offset = self.conv2_offset(out)
            out = self.conv2(out, offset)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out


class ResNetWOnorm(nn.Module):
    def __init__(self, dcn=None, out_dim=4096):
        super(ResNetWOnorm, self).__init__()
        # NOTE: as written, this variant still normalizes through the
        # module-level BatchNorm2d alias; rebind that alias to actually
        # run without a norm.
        print('using resnet without batchnorm')
        self.dcn = dcn
        self.inplanes = 256
        self.layer1 = self._make_layer(BasicBlockWOnorm, 1024, 1, stride=2, dcn=dcn)
        self.layer2 = self._make_layer(BasicBlockWOnorm, 4096, 1, stride=2, dcn=dcn)
        self.fc = nn.Linear(4096, out_dim)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample, dcn=dcn)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dcn=dcn))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        # Flatten the spatial grid into a token axis: (B, C, H, W) -> (B, H*W, C).
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)
        x = self.fc(x)
        return x
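
# Shape walk-through for ResNetWOnorm (illustrative; the 64x64 input size is
# an assumption matching the __main__ demo below):
#   (B, 256, 64, 64) -> layer1 -> (B, 1024, 32, 32) -> layer2 -> (B, 4096, 16, 16)
#   -> reshape/permute -> (B, 256, 4096) -> fc -> (B, 256, out_dim)
def _resnet_wonorm_example():
    net = ResNetWOnorm(out_dim=4096)
    tokens = net(torch.randn(2, 256, 64, 64))
    assert tokens.shape == (2, 256, 4096)
    return tokens
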
class ResNetLayerNorm(nn.Module):
    def __init__(self, dcn=None, out_dim=4096):
        super(ResNetLayerNorm, self).__init__()
        print('using resnet with layernorm')
        self.dcn = dcn
        self.inplanes = 256
        # The LayerNorm shapes below assume a 64x64 input feature map,
        # halved by each stride-2 layer to 32x32 and then 16x16.
        self.layer1 = self._make_layer(
            BasicBlockLayerNorm, 1024, 1, [1024, 32, 32], stride=2, dcn=dcn)
        self.layer2 = self._make_layer(
            BasicBlockLayerNorm, 4096, 1, [4096, 16, 16], stride=2, dcn=dcn)
        self.fc = nn.Linear(4096, out_dim)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, norm_shape, stride=1, dcn=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.LayerNorm(norm_shape),
            )
        layers = [block(self.inplanes, planes, norm_shape, stride, downsample, dcn=dcn)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            # norm_shape must be forwarded here as well; the block takes it
            # as a required positional argument.
            layers.append(block(self.inplanes, planes, norm_shape, dcn=dcn))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)
        x = self.fc(x)
        return x


class ResNet(nn.Module):
    def __init__(self, dcn=None, out_dim=4096):
        super(ResNet, self).__init__()
        self.dcn = dcn
        self.inplanes = 256
        self.layer1 = self._make_layer(BasicBlock, 1024, 1, stride=2, dcn=dcn)
        self.layer2 = self._make_layer(BasicBlock, 4096, 1, stride=2, dcn=dcn)
        self.fc = nn.Linear(4096, out_dim)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample, dcn=dcn)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dcn=dcn))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)
        x = self.fc(x)
        return x
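
# Sketch: ResNetLayerNorm hard-codes its LayerNorm shapes ([1024, 32, 32] and
# [4096, 16, 16]), so as written it only accepts (B, 256, 64, 64) inputs;
# any other spatial size fails inside bn1. Illustrative check:
def _resnet_layernorm_example():
    net = ResNetLayerNorm(out_dim=4096)
    tokens = net(torch.randn(2, 256, 64, 64))
    assert tokens.shape == (2, 256, 4096)
    return tokens
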
class ResNetSwin(nn.Module):
    def __init__(self, dcn=None, input_dim=1024, out_dim=4096):
        super(ResNetSwin, self).__init__()
        self.dcn = dcn
        self.inplanes = input_dim
        self.layer1 = self._make_layer(BasicBlock, 2048, 1, stride=2, dcn=dcn)
        self.fc = nn.Linear(2048, out_dim)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample, dcn=dcn)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dcn=dcn))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.layer1(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)
        x = self.fc(x)
        return x


class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'identity'}


class SimpleResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)
        self.proj = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


def build_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, 'mm_projector_type', 'linear')
    print("projector_type:", projector_type)  # debug

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    if projector_type == 'conv':
        with_norm = getattr(config, 'with_norm', True)
        with_layernorm = getattr(config, 'with_layernorm', True)
        out_dim = getattr(config, 'projector_outdim', 4096)
        if with_layernorm:
            return ResNetLayerNorm(out_dim=out_dim)
        if with_norm:
            return ResNet(out_dim=out_dim)
        else:
            return ResNetWOnorm(out_dim=out_dim)

    if projector_type == 'swin_conv':
        out_dim = getattr(config, 'projector_outdim', 4096)
        input_dim = getattr(config, 'mm_input_embeds', 1024)
        return ResNetSwin(input_dim=input_dim, out_dim=out_dim)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    print("mlp_gelu_match:", mlp_gelu_match)  # debug
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')


if __name__ == '__main__':
    class Config:
        def __init__(self):
            self.mm_projector_type = 'conv'
            self.with_layernorm = True
            self.with_norm = False

    config = Config()
    net = build_vision_projector(config)
    image = torch.randn((4, 256, 64, 64))
    print(net)
    print(net(image).shape)
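
# A second smoke test (sketch): exercises the regex-matched 'mlpNx_gelu'
# projector path. The attribute names mirror build_vision_projector above;
# the concrete sizes here are illustrative assumptions.
if __name__ == '__main__':
    class MLPConfig:
        mm_projector_type = 'mlp2x_gelu'
        mm_hidden_size = 1024
        hidden_size = 4096

    mlp = build_vision_projector(MLPConfig())
    feats = torch.randn(2, 256, 1024)  # (batch, tokens, mm_hidden_size)
    print(mlp(feats).shape)            # expected: torch.Size([2, 256, 4096])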