| |
|
| | import torch.nn as nn
|
| |
|
class ModalityProjector(nn.Module):
    """Map vision-encoder patch embeddings into the language model's
    embedding space.

    A pixel-shuffle step first trades spatial resolution for channel
    depth (reducing sequence length by ``scale_factor**2`` while
    multiplying the embedding dim by the same amount), then a bias-free
    linear layer projects to the LM hidden size.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Pixel shuffle packs scale_factor**2 neighboring patches into a
        # single token, so the projector input dim grows accordingly.
        self.input_dim = cfg.vit_hidden_dim * (cfg.mp_pixel_shuffle_factor**2)
        self.output_dim = cfg.lm_hidden_dim
        self.scale_factor = cfg.mp_pixel_shuffle_factor

        self.proj = nn.Linear(self.input_dim, self.output_dim, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize every Linear submodule with N(0, 0.02) weights and
        zero bias (when present).

        Called once per submodule via ``self.apply``.
        """
        if isinstance(module, nn.Linear):
            # Bug fix: initialize the module being visited, not self.proj.
            # The original hard-coded self.proj.weight, which only happens
            # to be correct while self.proj is the sole Linear layer; any
            # added Linear would be left with default weights while its
            # bias got zeroed.
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def pixel_shuffle(self, x):
        """Rearrange a flattened square grid of tokens so that each
        ``scale_factor x scale_factor`` spatial block becomes one token.

        Args:
            x: tensor of shape (batch, seq, embed_dim), where ``seq`` must
               be a perfect square whose root is divisible by
               ``self.scale_factor`` (i.e. tokens form a square grid).

        Returns:
            Tensor of shape
            (batch, seq // scale_factor**2, embed_dim * scale_factor**2).
        """
        bsz, seq, embed_dim = x.size()
        seq_root = int(seq**0.5)
        # Tokens must tile a square grid that the shuffle factor divides.
        assert seq_root**2 == seq, "token count must be a perfect square"
        assert seq_root % self.scale_factor == 0, (
            "grid side must be divisible by the pixel-shuffle factor"
        )

        height = width = seq_root
        x = x.view(bsz, height, width, embed_dim)
        h_out = height // self.scale_factor
        w_out = width // self.scale_factor

        # Split each spatial axis into (out, within-block), move both
        # within-block axes next to the channel dim, then fold them in.
        x = x.reshape(bsz, h_out, self.scale_factor, w_out, self.scale_factor, embed_dim)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        x = x.reshape(bsz, h_out * w_out, embed_dim * self.scale_factor**2)

        return x

    def forward(self, x):
        """Pixel-shuffle then project: (B, S, vit_dim) -> (B, S / f**2, lm_dim)."""
        x = self.pixel_shuffle(x)
        x = self.proj(x)

        return x
|
| |
|
| | |