| import torch |
| import torch.nn as nn |
|
|
| from typing import List |
| from dataclasses import dataclass |
|
|
| from transformers import PreTrainedModel |
| from transformers.file_utils import ModelOutput |
|
|
| from configuration_textcnn import TextCNNConfig |
|
|
|
|
@dataclass
class TextCNNModelOutput(ModelOutput):
    """Result container returned by the bare Text-CNN encoder.

    Both fields default to ``None``.

    Attributes:
        last_hidden_states: The per-convolution pooled features joined along
            dim 1 (the encoder builds it with ``torch.cat(pools, 1)``).
        ngram_feature_maps: The list of per-convolution pooled tensors, one
            entry per filter size, i.e. the pieces that were concatenated
            into ``last_hidden_states``.
    """

    # Field order matters: integer indexing on the output follows it.
    last_hidden_states: torch.FloatTensor = None
    ngram_feature_maps: List[torch.FloatTensor] = None
| |
| |
@dataclass
class TextCNNSequenceClassificerOutput(ModelOutput):
    """Result container returned by the Text-CNN sequence classifier.

    The class name keeps its historical spelling ("Classificer") because it
    is part of the public interface. Every field defaults to ``None``.

    Attributes:
        loss: Cross-entropy loss; left as ``None`` when the classifier is
            called without ``labels``.
        logits: Output of the classification head.
        last_hidden_states: Encoder features, copied from the encoder output.
        ngram_feature_maps: Per-convolution pooled features, copied from the
            encoder output.
    """

    # Field order matters: integer indexing on the output follows it.
    loss: torch.FloatTensor = None
    logits: torch.FloatTensor = None
    last_hidden_states: torch.FloatTensor = None
    ngram_feature_maps: List[torch.FloatTensor] = None
| |
|
|
class TextCNNPreTrainedModel(PreTrainedModel):
    """Shared base class that plugs the Text-CNN models into ``PreTrainedModel``.

    It binds the models to ``TextCNNConfig`` and declares ``"textcnn"`` as the
    attribute name under which head models store the bare encoder.
    """

    config_class = TextCNNConfig
    base_model_prefix = "textcnn"

    def _init_weights(self, module):
        """Weight-initialisation hook; intentionally does nothing.

        No custom initialisation scheme is defined for this architecture, so
        every submodule keeps the values its own constructor assigned.

        The hook must stay callable without raising: library code may invoke
        it on individual submodules (presumably while loading checkpoints or
        resizing embeddings — confirm against the installed transformers
        version), and an exception here would abort those code paths.

        Args:
            module: The submodule offered for initialisation (unused).

        Returns:
            None.
        """
        return None

    @property
    def dummy_inputs(self):
        """Build a tiny two-sequence batch placed on the model's device.

        The second sequence ends with ``config.pad_token_id`` so that the
        returned mask contains at least one ``False`` entry.

        NOTE(review): the ``forward`` methods defined in this file accept only
        ``input_ids`` (plus ``labels`` on the classifier), so the
        ``attention_mask`` entry cannot be passed to them as a keyword —
        confirm which consumer needs it.

        Returns:
            dict with two entries:
                ``"attention_mask"``: boolean tensor, ``True`` wherever
                ``input_ids`` differs from the pad token id.
                ``"input_ids"``: integer tensor of shape ``(2, 5)``.
        """
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor(
            [[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]],
            device=self.device,
        )
        return {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
        }
| |
| |
class TextCNNModel(TextCNNPreTrainedModel):
    """Text-CNN encoder: embedding, parallel n-gram convolutions, max-over-time pooling."""

    def __init__(self, config):
        """Create the embedding table and one convolution per filter size.

        Args:
            config: Model configuration. Reads ``vocab_size``, ``embed_dim``,
                ``num_filters`` and ``filter_sizes``; the last two are paired
                element-wise with ``zip``.
        """
        super().__init__(config)
        embed_dim = config.embed_dim
        self.embeder = nn.Embedding(config.vocab_size, embed_dim)

        # Each kernel spans `width` consecutive tokens and the full embedding
        # width, so every convolution acts as an n-gram detector.
        branches = []
        for out_channels, width in zip(config.num_filters, config.filter_sizes):
            branches.append(nn.Conv2d(1, out_channels, (width, embed_dim)))
        self.convs = nn.ModuleList(branches)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embeder

    def set_input_embeddings(self, value):
        """Replace the token embedding module with ``value``."""
        self.embeder = value

    def forward(self, input_ids):
        """Encode a batch of token ids into pooled n-gram features.

        Args:
            input_ids: Token id tensor; assumed to be ``(batch, seq_len)`` as
                in ``dummy_inputs`` — TODO confirm with callers.

        Returns:
            TextCNNModelOutput whose ``last_hidden_states`` is the pooled
            features of all convolutions concatenated along dim 1, and whose
            ``ngram_feature_maps`` is the list of those pooled tensors.
        """
        # Insert a singleton channel axis so Conv2d sees a one-channel "image".
        embedded = self.embeder(input_ids).unsqueeze(1)

        pooled = []
        for conv in self.convs:
            # The kernel covers the whole embedding width, leaving a size-1
            # trailing axis that is dropped here.
            activated = torch.relu(conv(embedded)).squeeze(3)
            # Max over every remaining position (kernel size == full length).
            strongest = torch.max_pool1d(activated, activated.size(2)).squeeze(2)
            pooled.append(strongest)

        return TextCNNModelOutput(
            last_hidden_states=torch.cat(pooled, 1),
            ngram_feature_maps=pooled,
        )
| |
| |
class TextCNNForSequenceClassification(TextCNNPreTrainedModel):
    """Text-CNN encoder topped with a two-layer classification head."""

    def __init__(self, config):
        """Build the encoder and the classification head.

        Args:
            config: Model configuration. Reads ``num_filters``, ``dropout``
                and ``num_labels`` here; the encoder reads its own fields.
        """
        super().__init__(config)
        # Encoder output width: one feature per filter across all branches.
        self.feature_dim = sum(config.num_filters)
        self.textcnn = TextCNNModel(config)

        hidden_dim = int(self.feature_dim / 2)
        self.fc = nn.Sequential(
            nn.Dropout(config.dropout),
            nn.Linear(self.feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, config.num_labels),
        )

    def forward(self, input_ids, labels=None):
        """Classify a batch of token id sequences.

        Args:
            input_ids: Token id tensor forwarded unchanged to the encoder.
            labels: Optional class-index tensor; when given, a cross-entropy
                loss over the flattened logits and labels is computed.

        Returns:
            TextCNNSequenceClassificerOutput carrying ``loss`` (``None``
            without labels), ``logits`` and the encoder's two output fields.
        """
        encoded = self.textcnn(input_ids)
        logits = self.fc(encoded.last_hidden_states)

        if labels is None:
            loss = None
        else:
            criterion = nn.CrossEntropyLoss()
            flat_logits = logits.view(-1, self.config.num_labels)
            loss = criterion(flat_logits, labels.view(-1))

        return TextCNNSequenceClassificerOutput(
            loss=loss,
            logits=logits,
            last_hidden_states=encoded.last_hidden_states,
            ngram_feature_maps=encoded.ngram_feature_maps,
        )