import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel


class DistilBertClassifier(PreTrainedModel):
    """DistilBERT encoder with a small feed-forward classification head."""

    def __init__(self, bert_config, model_name='distilbert-base-uncased',
                 tokenizer_len=30528, freeze_bert=False):
        super().__init__(bert_config)
        self.bert = AutoModel.from_pretrained(model_name)
        # Resize the embedding matrix to match a tokenizer that has had extra
        # tokens added (30528 vs. the base DistilBERT vocabulary of 30522).
        self.bert.resize_token_embeddings(tokenizer_len)
        # Classification head: GELU on the [CLS] hidden state, a 300-unit
        # hidden layer, light dropout, and a projection to 91 classes.
        self.classifier = nn.Sequential(
            nn.GELU(),
            nn.Linear(self.bert.config.hidden_size, 300),
            nn.GELU(),
            nn.Dropout(0.05),
            nn.Linear(300, 91),
        )
        # Optionally freeze the encoder so that only the head is trained.
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Take the hidden state of the first ([CLS]) token as the
        # sequence-level representation for classification.
        last_hidden_state_cls = outputs[0][:, 0, :]
        logits = self.classifier(last_hidden_state_cls)
        return logits
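

# Hypothetical usage sketch (not part of the original class): builds a config
# for the default 'distilbert-base-uncased' checkpoint, tokenizes one example
# sentence, and runs a single forward pass. Assumes the `torch` and
# `transformers` packages are installed; the input text is illustrative.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoTokenizer

    config = AutoConfig.from_pretrained('distilbert-base-uncased')
    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
    model = DistilBertClassifier(config)
    model.eval()

    batch = tokenizer(["an example sentence"], return_tensors='pt')
    with torch.no_grad():
        logits = model(batch['input_ids'], batch['attention_mask'])
    print(logits.shape)  # expected: torch.Size([1, 91])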