# citclass_small_91 / classifier.py
# Upload DistilBertClassifier (commit 4037ba7)
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, AutoConfig, PretrainedConfig
import transformers
class DistilBertClassifier(PreTrainedModel):
    """DistilBERT encoder with a small MLP head for multi-class classification.

    Wraps a pretrained DistilBERT (loaded via ``AutoModel``) and classifies the
    hidden state at the first token position through a two-layer GELU MLP.
    """

    def __init__(self, bert_config, model_name='distilbert-base-uncased',
                 tokenizer_len=30528, freeze_bert=False,
                 num_labels=91, classifier_hidden=300, dropout=0.05):
        """Build the encoder + classifier head.

        Args:
            bert_config: ``PretrainedConfig`` passed to ``PreTrainedModel``.
            model_name: HF hub id of the encoder checkpoint to load.
            tokenizer_len: vocabulary size to resize the embeddings to
                (the tokenizer presumably had tokens added — confirm it
                matches ``len(tokenizer)`` used at training time).
            freeze_bert: if True, encoder weights are excluded from training.
            num_labels: output dimension of the classifier head.
            classifier_hidden: width of the hidden layer in the head.
            dropout: dropout probability applied before the output layer.
        """
        super().__init__(bert_config)
        self.bert = AutoModel.from_pretrained(model_name)
        # Resize embeddings so added tokenizer tokens have embedding rows.
        self.bert.resize_token_embeddings(tokenizer_len)
        # Head: GELU on the raw [CLS]-position state, then Linear-GELU-Dropout-Linear.
        self.classifier = nn.Sequential(
            nn.GELU(),
            nn.Linear(self.bert.config.hidden_size, classifier_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(classifier_hidden, num_labels),
        )
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        """Return classification logits of shape ``(batch, num_labels)``.

        Args:
            input_ids: token id tensor, shape ``(batch, seq_len)``.
            attention_mask: attention mask tensor, shape ``(batch, seq_len)``.
        """
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)
        # First-position hidden state serves as the sequence representation.
        cls_state = outputs[0][:, 0, :]
        return self.classifier(cls_state)