GenrePrediction / genre_model.py
Stanford-TH's picture
Upload model
d5249c9 verified
raw
history blame contribute delete
716 Bytes
import torch
from transformers import PreTrainedModel
from transformers import AutoModel
from .genre_configuration import GenreConfig
class GenreModel(PreTrainedModel):
    """BERT encoder followed by dropout and a 17-way linear head.

    The submodule attribute names (``l1``, ``l2``, ``l3``) are kept exactly
    as they are because they form the parameter-name prefixes of any saved
    checkpoint for this class.
    """

    config_class = GenreConfig

    def __init__(self, config):
        """Build the encoder, the dropout layer and the linear head.

        Args:
            config: Configuration object forwarded unchanged to
                ``PreTrainedModel.__init__``; no field of it is read here.
        """
        super().__init__(config)
        # NOTE(review): this loads the 'bert-base-uncased' checkpoint every
        # time the class is instantiated, regardless of `config`. If the
        # instance is itself being restored from a checkpoint this load is
        # presumably redundant -- confirm before changing.
        self.l1 = AutoModel.from_pretrained('bert-base-uncased')
        self.l2 = torch.nn.Dropout(0.3)
        # Take the input width from the encoder itself (768 for
        # bert-base-uncased) so the head always matches the pooled output.
        # 17 is the fixed number of output units -- presumably one per genre
        # label; verify against the training setup.
        self.l3 = torch.nn.Linear(self.l1.config.hidden_size, 17)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
    ):
        """Run the encoder and return the output of the linear head.

        Args:
            input_ids: Token ids passed as the first positional argument to
                the encoder.
            attention_mask: Optional mask forwarded to the encoder. When
                ``None`` the encoder's own default is used.
            token_type_ids: Optional segment ids forwarded to the encoder.
                When ``None`` the encoder's own default is used.

        Returns:
            The raw output of the final linear layer (last dimension 17); no
            softmax or sigmoid is applied here.
        """
        encoder_out = self.l1(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled = encoder_out['pooler_output']
        dropped = self.l2(pooled)
        return self.l3(dropped)