# File size: 1,412 bytes; revision 50e9111.
import torch
from torch import nn
from torch import Tensor, LongTensor
from transformers.activations import ACT2FN
class MlpClassifier(nn.Module):
    """Simple feed-forward multilayer perceptron classifier.

    Architecture: dropout -> linear(input_size, hidden_size) -> activation
    -> dropout -> linear(hidden_size, n_classes). When labels are passed to
    ``forward``, an (optionally class-weighted) cross-entropy loss is computed.

    Args:
        input_size: Size of the last dimension of the input embeddings.
        hidden_size: Width of the single hidden layer.
        n_classes: Number of output classes (width of the logits).
        activation: Key into ``transformers.activations.ACT2FN`` selecting the
            hidden-layer activation.
        dropout: Dropout probability applied before each linear layer.
        class_weights: Optional per-class rescaling weights for the
            cross-entropy loss; must contain exactly ``n_classes`` entries.

    Raises:
        ValueError: If ``class_weights`` does not have ``n_classes`` entries.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        n_classes: int,
        activation: str,
        dropout: float,
        class_weights: list[float] = None,
    ):
        super().__init__()
        self.n_classes = n_classes
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_size, hidden_size),
            ACT2FN[activation],
            nn.Dropout(dropout),
            nn.Linear(hidden_size, n_classes),
        )
        if class_weights is not None:
            if len(class_weights) != n_classes:
                raise ValueError(
                    f"class_weights has {len(class_weights)} entries, "
                    f"expected n_classes={n_classes}"
                )
            # CrossEntropyLoss needs a floating-point weight tensor: an integer
            # dtype would truncate fractional weights (0.5 -> 0) and fail at
            # loss time with a dtype mismatch against the float logits.
            class_weights = torch.as_tensor(class_weights, dtype=torch.float)
        # Passing the weight tensor to the loss module registers it as a
        # buffer, so it follows the model across .to(device) calls.
        self.cross_entropy = nn.CrossEntropyLoss(weight=class_weights)

    def forward(self, embeddings: Tensor, labels: LongTensor = None) -> dict:
        """Compute class predictions and, if labels are given, the loss.

        Args:
            embeddings: Input features whose last dimension is ``input_size``;
                any leading dimensions are kept in the predictions.
            labels: Optional integer class indices, one per row of logits
                (their element count must equal the product of the leading
                dimensions of ``embeddings``).

        Returns:
            ``{'preds': ..., 'loss': ...}``. ``preds`` is the argmax over the
            class dimension, and ``loss`` is the cross-entropy tensor (PyTorch
            default ``'mean'`` reduction) or ``0.0`` when ``labels`` is None.
        """
        logits = self.classifier(embeddings)

        loss = 0.0
        if labels is not None:
            # Flatten any leading dimensions so CrossEntropyLoss sees
            # (N, n_classes) logits and (N,) targets. reshape (not view)
            # tolerates non-contiguous label tensors.
            loss = self.cross_entropy(
                logits.reshape(-1, self.n_classes),
                labels.reshape(-1),
            )

        preds = logits.argmax(dim=-1)
        return {'preds': preds, 'loss': loss}
|