import torch
import torch.nn as nn


class _MLPVectorProjector(nn.Module):
    """Projects a feature vector into `width` pseudo-tokens in the language
    model's hidden space, using one independent MLP per token."""

    def __init__(
        self,
        input_hidden_size: int = 512,
        lm_hidden_size: int = 2560,
        num_layers: int = 1,
        width: int = 4,
    ):
        super().__init__()
        self.mlps = nn.ModuleList()
        for _ in range(width):
            # First layer maps the input into the LM hidden size; each extra
            # layer adds a GELU + Linear block within the LM hidden size.
            mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
            for _ in range(1, num_layers):
                mlp.append(nn.GELU())
                mlp.append(nn.Linear(lm_hidden_size, lm_hidden_size, bias=False))
            self.mlps.append(nn.Sequential(*mlp))

    def forward(self, x):
        # Stack the per-MLP outputs along the token (second-to-last) dimension.
        return torch.cat([mlp(x) for mlp in self.mlps], dim=-2)


def build_mlp_vector_projector(
    input_hidden_size: int = 512,
    lm_hidden_size: int = 2560,
    num_layers: int = 1,
    num_tokens: int = 4,
):
    return _MLPVectorProjector(
        input_hidden_size, lm_hidden_size, num_layers, num_tokens
    )
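# Usage sketch: a minimal, illustrative call, not part of the original file.
# The batch size and the singleton token dimension on the input are
# assumptions; the singleton dim matters because the cat along dim=-2 in
# forward() stacks the per-MLP outputs into `num_tokens` soft tokens rather
# than along the batch axis.
if __name__ == "__main__":
    projector = build_mlp_vector_projector(
        input_hidden_size=512, lm_hidden_size=2560, num_layers=2, num_tokens=4
    )
    x = torch.randn(2, 1, 512)  # (batch, 1, input_hidden_size)
    out = projector(x)          # (batch, num_tokens, lm_hidden_size)
    print(out.shape)            # torch.Size([2, 4, 2560])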