Spaces:
Build error
Build error
| import torch | |
| # Define a new classification head | |
class NewClassificationHead(torch.nn.Module):
    """Sequence-classification head applied to the first token of an encoder output.

    Pipeline (see ``forward``): first-position slice -> dropout -> dense
    (hidden_size -> hidden_size) -> ReLU -> dropout -> linear projection to
    ``num_labels`` logits.
    """

    def __init__(self, config):
        """Build the head's layers from ``config``.

        Args:
            config: Any object exposing ``hidden_size``, ``hidden_dropout_prob``
                and ``num_labels`` attributes (e.g. a model config).
        """
        super().__init__()
        width = config.hidden_size
        # Creation order (dense before out_proj) fixes the state_dict key order
        # and the order in which parameter initialisation consumes the RNG.
        self.dense = torch.nn.Linear(width, width)
        # A single Dropout module is reused on both sides of the dense layer.
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = torch.nn.Linear(width, config.num_labels)

    def forward(self, features, **kwargs):
        """Return classification logits computed from the first sequence position.

        Args:
            features: Encoder output tensor with at least three dimensions.
                Presumably shaped (batch, seq_len, hidden_size) — TODO confirm
                against the caller.
            **kwargs: Accepted for call-site compatibility; ignored.

        Returns:
            Output of ``out_proj``, whose last dimension is ``num_labels``.
        """
        # Index 0 on dim 1 is the <s> token (equivalent to BERT's [CLS]).
        pooled = features[:, 0, :]
        hidden = torch.relu(self.dense(self.dropout(pooled)))
        return self.out_proj(self.dropout(hidden))