File size: 1,935 Bytes
acceea9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import torch
from torch import nn
from torch import Tensor, LongTensor
from transformers.activations import ACT2FN
class MlpClassifier(nn.Module):
    """Feed-forward MLP classifier with a frozen base and a trainable output head.

    The base classifier is ``Dropout -> Linear -> activation -> Dropout ->
    Linear`` and maps embeddings to ``n_classes`` logits. Its parameters are
    frozen at construction time. On top of it sits ``extra_output``, a
    ``Linear(n_classes, n_classes)`` layer that stays trainable and is applied
    to the base logits in ``forward``.

    Args:
        input_size: Size of the last dimension of the input embeddings.
        hidden_size: Width of the hidden layer of the base classifier.
        n_classes: Number of target classes (size of the logits).
        activation: Key into ``ACT2FN`` selecting the hidden activation.
        dropout: Dropout probability used before both base linear layers.
        class_weights: Optional per-class weights for the cross-entropy loss;
            ``None`` means unweighted loss.
        extra_hidden_size: Accepted for interface compatibility; not used by
            the current implementation.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        n_classes: int,
        activation: str,
        dropout: float,
        class_weights: list[float] = None,
        extra_hidden_size: int = None,  # reserved for the new layer; unused
    ):
        super().__init__()
        self.n_classes = n_classes
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_size, hidden_size),
            ACT2FN[activation],
            nn.Dropout(dropout),
            nn.Linear(hidden_size, n_classes),
        )

        # Additional trainable layer on top of the base logits. It starts as
        # the identity map (weight = I, bias = 0) so that, before any
        # fine-tuning, the model outputs exactly the base classifier's logits.
        self.extra_output = nn.Linear(n_classes, n_classes)
        nn.init.eye_(self.extra_output.weight)
        nn.init.zeros_(self.extra_output.bias)

        self.cross_entropy = self._build_loss(class_weights)

        # Freeze everything except `extra_output`.
        self.classifier.requires_grad_(False)
        self.extra_output.requires_grad_(True)

    @staticmethod
    def _build_loss(class_weights) -> nn.CrossEntropyLoss:
        """Build the cross-entropy criterion, optionally class-weighted."""
        weight = None
        if class_weights is not None:
            # The loss weight must be floating point: an integer dtype would
            # truncate fractional weights and mismatch the float logits.
            weight = torch.as_tensor(class_weights, dtype=torch.float)
        return nn.CrossEntropyLoss(weight=weight)

    def forward(self, embeddings: Tensor, labels: LongTensor = None) -> dict:
        """Classify embeddings and optionally compute the loss.

        Args:
            embeddings: Input tensor whose last dimension is ``input_size``.
            labels: Optional target class indices, one per predicted position.

        Returns:
            A dict with ``'preds'`` (argmax over the class dimension) and
            ``'loss'`` (cross-entropy tensor when ``labels`` is given,
            otherwise the float ``0.0``).
        """
        # The frozen base produces logits; the trainable head refines them.
        # Without this call no trainable parameter would affect the loss.
        logits = self.extra_output(self.classifier(embeddings))

        # Calculate loss.
        loss = 0.0
        if labels is not None:
            # Flatten to (N, n_classes) / (N,); reshape also accepts
            # non-contiguous inputs, unlike view.
            loss = self.cross_entropy(
                logits.reshape(-1, self.n_classes),
                labels.reshape(-1),
            )

        # Predictions.
        preds = logits.argmax(dim=-1)
        return {'preds': preds, 'loss': loss}
|