| | import torch.nn as nn |
| | from transformers import PreTrainedModel, AutoModelForSequenceClassification |
| | from transformers.modeling_outputs import SequenceClassifierOutput |
| | from .configuration_mlp import MLPConfig |
| |
|
| |
|
class MLP(PreTrainedModel):
    r"""
    A simple MLP model that takes a 3D input [batch_size, seq_length, embedding_size]
    and performs multi-label classification using BCE loss.

    The input is flattened to [batch_size, seq_length * embedding_size] and passed
    through ``config.num_hidden_layers`` blocks of ``Linear -> ReLU -> Dropout``.
    A final ``Linear`` layer then projects to ``config.num_labels`` logits.

    The first layer's input width is fixed at construction time to
    ``config.embedding_size * config.sequence_length``, so inputs must flatten to
    exactly that many features per example.
    """

    config_class = MLPConfig

    def __init__(self, config: MLPConfig):
        # PreTrainedModel.__init__ already stores the config on ``self.config``;
        # re-assigning it here would be redundant.
        super().__init__(config)

        # NOTE: the order of modules in ``layers`` determines the state_dict keys
        # (mlp.0.weight, mlp.3.weight, ...); keep it stable so saved checkpoints load.
        layers = []
        input_dim = config.embedding_size * config.sequence_length
        for _ in range(config.num_hidden_layers):
            layers.append(nn.Linear(input_dim, config.hidden_size))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(config.dropout))
            input_dim = config.hidden_size

        # Output projection to one logit per label (sigmoid is applied inside the loss).
        layers.append(nn.Linear(input_dim, config.num_labels))

        self.mlp = nn.Sequential(*layers)

        # Standard Hugging Face hook: weight initialization and final setup.
        self.post_init()

    def forward(
        self,
        inputs_embeds=None,
        labels=None,
        **kwargs
    ):
        """
        Forward pass of the MLP.

        Args:
            inputs_embeds (torch.FloatTensor):
                A 3D tensor of shape [batch_size, seq_length, embedding_size].
                An already-flattened 2D tensor of shape
                [batch_size, seq_length * embedding_size] is also accepted.
                In both cases the per-example feature count must equal
                config.sequence_length * config.embedding_size.
            labels (torch.FloatTensor, optional):
                Multi-hot labels for multi-label classification, shape
                [batch_size, num_labels]. When given, a BCE-with-logits loss is
                computed.
            **kwargs:
                Any other keyword arguments are accepted and ignored.

        Returns:
            SequenceClassifierOutput with fields:
                - loss (None unless labels is given)
                - logits, shape [batch_size, num_labels]
                - hidden_states (None)
                - attentions (None)

        Raises:
            ValueError: if inputs_embeds is missing, is not 2D or 3D, or its
                flattened feature size does not match the first linear layer.
        """
        # This model has no embedding layer, so it can only consume pre-computed
        # embeddings; fail clearly instead of with an AttributeError on None.
        if inputs_embeds is None:
            raise ValueError(
                "MLP.forward requires `inputs_embeds` "
                "([batch_size, seq_length, embedding_size]); "
                "it has no embedding layer, so `input_ids` are not supported."
            )
        if inputs_embeds.dim() not in (2, 3):
            raise ValueError(
                "Expected `inputs_embeds` of shape [batch_size, seq_length, embedding_size] "
                "(or flattened [batch_size, seq_length * embedding_size]), "
                f"got shape {tuple(inputs_embeds.shape)}."
            )

        # [B, L, E] -> [B, L * E]; a 2D input is left unchanged.
        x = inputs_embeds.flatten(start_dim=1)

        # Check the width up front so a mismatched sequence length or embedding
        # size is reported clearly instead of as a matmul shape error.
        expected_dim = self.config.sequence_length * self.config.embedding_size
        if x.size(-1) != expected_dim:
            raise ValueError(
                f"Flattened `inputs_embeds` has {x.size(-1)} features per example, "
                "but the model was built for sequence_length * embedding_size = "
                f"{self.config.sequence_length} * {self.config.embedding_size} = {expected_dim}."
            )

        logits = self.mlp(x)

        loss = None
        if labels is not None:
            # Multi-label objective: independent sigmoid + binary cross-entropy per
            # label; BCEWithLogitsLoss needs floating-point targets.
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels.float())

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None
        )
| |
|
| |
|
# Map MLPConfig -> MLP so that AutoModelForSequenceClassification can build this
# model when given an MLPConfig.
# NOTE(review): loading by model_type string via from_pretrained presumably also
# needs MLPConfig registered with AutoConfig; that is not visible here — confirm.
AutoModelForSequenceClassification.register(config_class=MLPConfig, model_class=MLP)
| |
|