import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, AutoConfig, PretrainedConfig


class DistilBertClassifier(PreTrainedModel):
    def __init__(self, bert_config, model_name='distilbert-base-uncased',
                 tokenizer_len=30528, freeze_bert=False):
        super().__init__(bert_config)
        D_in, H, D_out = 256, 50, 91  # note: not referenced below; the head dimensions are set explicitly

        # Load the pretrained DistilBERT backbone and resize its embedding
        # matrix to match the tokenizer length (accounts for added tokens).
        self.bert = AutoModel.from_pretrained(model_name)
        self.bert.resize_token_embeddings(tokenizer_len)

        # Classification head: GELU -> Linear -> GELU -> Dropout -> Linear,
        # mapping the backbone hidden size to 91 output classes.
        self.classifier = nn.Sequential(
            nn.GELU(),
            nn.Linear(self.bert.config.hidden_size, 300),
            nn.GELU(),
            nn.Dropout(0.05),
            nn.Linear(300, 91)
        )

        # Optionally freeze the backbone so only the classifier head is trained.
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)

        # Use the hidden state of the first ([CLS]) token as the sequence
        # representation and feed it through the classification head.
        last_hidden_state_cls = outputs[0][:, 0, :]
        logits = self.classifier(last_hidden_state_cls)
        return logits