import torch
from torch import Tensor, LongTensor, nn

from transformers.activations import ACT2FN


class MlpClassifier(nn.Module):
    """Simple feed-forward multilayer perceptron classifier."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        n_classes: int,
        activation: str,
        dropout: float,
        class_weights: list[float] | None = None,
    ):
        super().__init__()

        self.n_classes = n_classes

        # Two-layer MLP head; ACT2FN maps an activation name (e.g. "gelu")
        # to the corresponding nn.Module instance.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_size, hidden_size),
            ACT2FN[activation],
            nn.Dropout(dropout),
            nn.Linear(hidden_size, n_classes),
        )

        # Trainable output head applied on top of the frozen classifier.
        self.extra_output = nn.Linear(n_classes, n_classes)

        # CrossEntropyLoss expects float (not long) class weights; build the
        # loss unconditionally so it also exists when no weights are given.
        weight = None
        if class_weights is not None:
            weight = torch.tensor(class_weights, dtype=torch.float)
        self.cross_entropy = nn.CrossEntropyLoss(weight=weight)

        # Freeze the MLP and train only the extra output head.
        for param in self.classifier.parameters():
            param.requires_grad = False
        for param in self.extra_output.parameters():
            param.requires_grad = True

    def forward(self, embeddings: Tensor, labels: LongTensor | None = None) -> dict:
        # Route the frozen classifier's logits through the trainable head;
        # without this, extra_output (the only trainable part) never sees data.
        logits = self.extra_output(self.classifier(embeddings))

        loss = 0.0
        if labels is not None:
            loss = self.cross_entropy(
                logits.view(-1, self.n_classes),
                labels.view(-1),
            )

        preds = logits.argmax(dim=-1)
        return {'preds': preds, 'loss': loss}
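

# A minimal usage sketch (illustrative, not part of the original module):
# assumes 10-dim input embeddings, 3 classes, and the "gelu" entry in
# transformers' ACT2FN registry; all shapes and hyperparameters below are
# arbitrary example values.
if __name__ == "__main__":
    model = MlpClassifier(
        input_size=10,
        hidden_size=32,
        n_classes=3,
        activation="gelu",
        dropout=0.1,
        class_weights=[1.0, 2.0, 0.5],
    )
    embeddings = torch.randn(4, 10)     # batch of 4 example embedding vectors
    labels = torch.randint(0, 3, (4,))  # random gold labels for the batch
    out = model(embeddings, labels)
    print(out['preds'].shape, float(out['loss']))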