File size: 2,979 Bytes
5ad7d0c 49d23e0 5ad7d0c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 | from abc import ABCMeta
import torch
from transformers.pytorch_utils import nn
import torch.nn.functional as F
from transformers import BertModel, BertForSequenceClassification, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import BertConfig
from transformers import PretrainedConfig
class BertAttentionConfig(PretrainedConfig):
    """Configuration for ``BertAttentionForSequenceClassification``.

    Args:
        num_classes: Number of output classes of the classification head.
        hidden_size: Width of the BERT hidden states fed to the head.
        fc_hidden: Hidden units for an FC layer. Stored on the config only.
        num_layers: Stored on the config only.
        dropout_rate: Stored on the config only.
        **kwargs: Forwarded to ``PretrainedConfig``. An ``id2label`` /
            ``label2id`` passed here (or restored from a saved config) is
            kept; otherwise a binary ``{0: "fake", 1: "true"}`` mapping is used.
    """

    model_type = "bertAttentionForSequenceClassification"

    def __init__(
        self,
        num_classes=2,
        hidden_size=768,
        fc_hidden=128,
        num_layers=12,
        dropout_rate=0.1,
        **kwargs,
    ):
        # Provide the default label maps *before* the base constructor runs, so
        # that mappings supplied by the caller -- or reloaded from config.json by
        # from_pretrained -- are not overwritten afterwards.
        if kwargs.get("id2label") is None:
            kwargs["id2label"] = {0: "fake", 1: "true"}
        if kwargs.get("label2id") is None:
            kwargs["label2id"] = {"fake": 0, "true": 1}
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.fc_hidden = fc_hidden
        self.num_layers = num_layers
        self.dropout_rate = dropout_rate
class BertAttentionForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
    """BERT encoder + parameter-free self-attention + masked mean pool + linear head.

    The encoder is loaded with ``BertModel.from_pretrained`` from
    ``config.pretrained_model_name`` when that attribute exists, otherwise from
    ``'bert-base-uncased'``. It is configured to also return every layer's
    hidden states and attentions.
    """

    config_class = BertAttentionConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes
        # Used as the sqrt(d) scale in the extra self-attention step.
        self.embed_dim = config.hidden_size
        checkpoint = getattr(config, "pretrained_model_name", "bert-base-uncased")
        self.bert = BertModel.from_pretrained(
            checkpoint, output_hidden_states=True, output_attentions=True
        )
        print("BERT Model Loaded")
        # The head and the attention scale are sized from the config; fail fast
        # if the loaded encoder produces a different width.
        if self.bert.config.hidden_size != config.hidden_size:
            raise ValueError(
                f"config.hidden_size={config.hidden_size} does not match the "
                f"encoder hidden size {self.bert.config.hidden_size}"
            )
        self.fc = nn.Linear(config.hidden_size, self.num_classes)

    def _masked_self_attention(self, hidden_states, attention_mask):
        """Scaled dot-product self-attention (Q = K = V = hidden_states) that
        gives zero weight to key positions where ``attention_mask == 0``."""
        scores = torch.matmul(hidden_states, hidden_states.transpose(1, 2))
        scores = scores / (self.embed_dim ** 0.5)
        # Broadcast the mask over the query axis. Use the most negative finite
        # value rather than -inf so an all-padding row softmaxes to a uniform
        # distribution instead of NaN.
        key_is_pad = attention_mask[:, None, :] == 0
        scores = scores.masked_fill(key_is_pad, torch.finfo(scores.dtype).min)
        probs = F.softmax(scores, dim=-1)
        return torch.matmul(probs, hidden_states)

    @staticmethod
    def _masked_mean_pool(token_states, attention_mask):
        """Average ``token_states`` over dim 1, counting only positions where
        ``attention_mask != 0``."""
        weights = (attention_mask != 0).unsqueeze(-1).to(token_states.dtype)
        summed = (token_states * weights).sum(dim=1)
        # clamp avoids division by zero for a sequence with no real tokens.
        counts = weights.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        """Classify a batch of token sequences.

        Args:
            input_ids: Token ids, forwarded to BERT.
            attention_mask: 1 for real tokens, 0 for padding. Assumed shape is
                (batch, seq_len) -- confirm against callers. ``None`` treats
                every position as a real token.
            token_type_ids: Segment ids forwarded to BERT. Optional.
            labels: Class indices. When given, a cross-entropy loss is returned.

        Returns:
            SequenceClassifierOutput containing ``loss`` (``None`` without
            labels), ``logits`` from the linear head, and BERT's per-layer
            ``hidden_states`` and ``attentions``.
        """
        bert_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        hidden_states = bert_output.last_hidden_state
        if attention_mask is None:
            # Match BERT's own default: no padding.
            attention_mask = torch.ones(
                hidden_states.shape[:2], dtype=torch.long, device=hidden_states.device
            )
        # Padding must be excluded both as attention keys and from the pooled
        # average; otherwise predictions depend on how much padding was added.
        attention_output = self._masked_self_attention(hidden_states, attention_mask)
        pooled_output = self._masked_mean_pool(attention_output, attention_mask)
        logits = self.fc(pooled_output)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_output.hidden_states,
            attentions=bert_output.attentions,
        )
|