ROBERTA_MODEL / ROBERTA.py
jungseok's picture
add model
d4e2a17
raw
history blame contribute delete
615 Bytes
from transformers import PreTrainedModel
import torch.nn as nn
class CustomModel(PreTrainedModel):
    """Sequence classifier: a pretrained encoder followed by a linear head
    and a log-softmax over labels.

    Args:
        config: the backbone's configuration object; passed to
            ``PreTrainedModel.__init__``. If it exposes ``hidden_size``,
            that is used to size the classification head.
        base_model: the pretrained encoder (e.g. a RoBERTa model) whose
            output is classified.
        n_label: number of output classes.
    """

    def __init__(self, config, base_model, n_label):
        super().__init__(config)
        self.model = base_model
        # Classification head. Prefer the backbone's actual hidden size from
        # the config; fall back to 1024 (roberta-large) to preserve the
        # original hard-coded behavior when the config lacks the attribute.
        hidden_size = getattr(config, "hidden_size", 1024)
        self.linear = nn.Linear(hidden_size, n_label)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, X):
        """Run the encoder on keyword inputs ``X`` and return per-class
        log-probabilities of shape ``(batch_size, n_label)``.

        Args:
            X: dict of encoder inputs (e.g. ``input_ids``,
                ``attention_mask``), unpacked into the base model.
        """
        bert_output = self.model(**X)
        # bert_output[1] is the POOLED output, shape (batch_size, hidden_size)
        # — not the per-token sequence output (that is bert_output[0]).
        # NOTE(review): requires the backbone to return a pooled output at
        # index 1 (e.g. add_pooling_layer=True) — confirm against the caller.
        pooled = bert_output[1]
        logits = self.linear(pooled)
        return self.log_softmax(logits)