"""Custom model definition for boilerplate detection."""

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput


class BoilerplateConfig(PretrainedConfig):
    """Configuration for the boilerplate detector.

    Stores the name of the frozen sentence-transformer backbone and the
    layer sizes of the small MLP classification head trained on top of it.
    """

    model_type = "boilerplate"

    def __init__(
        self,
        base_model_name="sentence-transformers/all-mpnet-base-v2",
        num_labels=2,
        hidden_size=768,
        classifier_dims=[16, 8],
        dropout=0.05,
        **kwargs,
    ):
        super().__init__(num_labels=num_labels, **kwargs)
        self.base_model_name = base_model_name
        self.hidden_size = hidden_size
        self.classifier_dims = classifier_dims
        self.dropout = dropout
        self.id2label = {0: "NOT_BOILERPLATE", 1: "BOILERPLATE"}
        self.label2id = {"NOT_BOILERPLATE": 0, "BOILERPLATE": 1}


class BoilerplateDetector(PreTrainedModel):
    """Binary sequence classifier: a frozen transformer encoder with a
    small trainable MLP head on mean-pooled sentence embeddings."""

    config_class = BoilerplateConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Frozen backbone: only the classification head below is trained.
        self.transformer = AutoModel.from_pretrained(config.base_model_name)
        for param in self.transformer.parameters():
            param.requires_grad = False

        # MLP head, e.g. 768 -> 16 -> 8 -> num_labels with the defaults.
        self.dropout = nn.Dropout(config.dropout)
        self.fc1 = nn.Linear(config.hidden_size, config.classifier_dims[0])
        self.fc2 = nn.Linear(config.classifier_dims[0], config.classifier_dims[1])
        self.fc3 = nn.Linear(config.classifier_dims[1], config.num_labels)

        self.init_weights()

    def mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings, masking out padding positions."""
        token_embeddings = model_output[0]  # last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=None,
        **kwargs,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            **kwargs,
        )

        # Pool token embeddings into one fixed-size vector per input.
        sentence_embeddings = self.mean_pooling(outputs, attention_mask)

        # Classification head. nn.Dropout is already a no-op in eval mode,
        # so no explicit self.training check is needed around it.
        x = torch.nn.functional.relu(self.fc1(sentence_embeddings))
        x = self.dropout(x)
        x = torch.nn.functional.relu(self.fc2(x))
        x = self.dropout(x)
        logits = self.fc3(x)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
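

# Minimal usage sketch (illustrative only, not part of the model definition):
# builds the model from its default config, tokenizes one hypothetical example
# sentence with the backbone's own tokenizer, and runs a single forward pass.
# The classification head is freshly initialized here, so the printed label
# is arbitrary until the head has been trained.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = BoilerplateConfig()
    model = BoilerplateDetector(config)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)
    batch = tokenizer(
        ["All rights reserved. Unsubscribe at any time."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        result = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )

    prediction = result.logits.argmax(dim=-1).item()
    print(config.id2label[prediction])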