| import torch |
| from torch import nn |
| from torch import Tensor, LongTensor |
|
|
| from transformers.activations import ACT2FN |
|
|
|
|
class MlpClassifier(nn.Module):
    """Simple feed-forward multilayer perceptron classifier.

    Layers: dropout -> linear(input_size, hidden_size) -> activation
    -> dropout -> linear(hidden_size, n_classes). The loss is cross-entropy,
    optionally rescaled per class.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        n_classes: int,
        activation: str,
        dropout: float,
        class_weights: "list[float] | None" = None,
    ):
        """Build the MLP and its loss function.

        Args:
            input_size: Size of the last dimension of the incoming embeddings.
            hidden_size: Width of the single hidden layer.
            n_classes: Number of output classes (last dimension of the logits).
            activation: Key into ``ACT2FN`` selecting the hidden nonlinearity.
            dropout: Dropout probability applied before each linear layer.
            class_weights: Optional per-class rescaling weights for the
                cross-entropy loss; must contain exactly ``n_classes`` values.

        Raises:
            ValueError: If ``class_weights`` does not have ``n_classes`` entries.
        """
        super().__init__()

        self.n_classes = n_classes
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_size, hidden_size),
            ACT2FN[activation],
            nn.Dropout(dropout),
            nn.Linear(hidden_size, n_classes),
        )
        self.cross_entropy = nn.CrossEntropyLoss(
            weight=self._class_weight_tensor(class_weights, n_classes)
        )

    @staticmethod
    def _class_weight_tensor(
        class_weights: "list[float] | None", n_classes: int
    ) -> "Tensor | None":
        """Convert optional class weights into a float tensor for the loss.

        Returns None when no weights are given (unweighted loss).
        Raises ValueError when the number of weights differs from n_classes.
        """
        if class_weights is None:
            return None
        if len(class_weights) != n_classes:
            raise ValueError(
                f"class_weights has {len(class_weights)} entries, "
                f"expected n_classes={n_classes}"
            )
        # Must be floating point: an integer weight tensor makes the weighted
        # cross-entropy fail at forward time and would truncate fractional
        # weights such as 0.5 to 0.
        return torch.tensor(class_weights, dtype=torch.float)

    def forward(self, embeddings: Tensor, labels: LongTensor = None) -> dict:
        """Classify embeddings and, when labels are given, compute the loss.

        Args:
            embeddings: Input features whose last dimension is ``input_size``.
            labels: Optional integer class indices; flattened and matched
                one-to-one with the logits flattened to (-1, n_classes).

        Returns:
            dict with ``'preds'`` (argmax over the last/class dimension of the
            logits) and ``'loss'`` (the cross-entropy loss, or the float
            ``0.0`` when ``labels`` is None).
        """
        logits = self.classifier(embeddings)

        loss = 0.0
        if labels is not None:
            # reshape (not view) so non-contiguous inputs do not raise.
            loss = self.cross_entropy(
                logits.reshape(-1, self.n_classes),
                labels.reshape(-1),
            )

        preds = logits.argmax(dim=-1)
        return {'preds': preds, 'loss': loss}
|
|