| import torch.nn as nn | |
class Pooler(nn.Module):
    """Reduce a sequence of hidden states to one vector per example.

    The vector at index 0 along the second axis is fed through a square
    linear layer and then a tanh non-linearity.
    """

    def __init__(self, hidden_size):
        """Create the projection and activation sub-modules.

        Args:
            hidden_size: Width of both the incoming and outgoing vectors.
        """
        super().__init__()
        # Registration order (dense, then activation) is kept so parameter
        # initialisation and state_dict layout are unchanged.
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return tanh(dense(hidden_states[:, 0])).

        Args:
            hidden_states: Tensor indexed as ``[:, 0]``; presumably shaped
                (batch, seq_len, hidden_size) -- TODO confirm with callers.

        Returns:
            The activated projection of the position-0 slice.
        """
        # Only the first position of dim 1 contributes to the pooled output.
        leading = hidden_states[:, 0]
        return self.activation(self.dense(leading))
class ITCHead(nn.Module):
    """Bias-free linear projection from ``hidden_size`` to ``out_size``.

    NOTE(review): the name suggests an image-text contrastive (ITC) head;
    that role is not visible here -- confirm with callers.
    """

    def __init__(self, hidden_size, out_size):
        """Create the projection layer.

        Args:
            hidden_size: Size of the last dimension of the input.
            out_size: Size of the last dimension of the output.
        """
        super().__init__()
        self.fc = nn.Linear(hidden_size, out_size, bias=False)

    def forward(self, x):
        """Apply the bias-free linear map to ``x`` and return the result."""
        return self.fc(x)