# predict-ai-abstract-api / custom_model.py
# Author: jmt-r. Commit fd363e5 (verified): "Create custom_model.py" (979 bytes).
import torch
from torch import nn
from transformers import BertPreTrainedModel, BertModel
class MyBERTClassifier(BertPreTrainedModel):
    """BERT encoder topped with a two-layer MLP classification head.

    The head maps BERT's ``pooler_output`` through
    ``Linear(hidden, hidden // 2) -> ReLU -> Dropout(0.2) -> Linear(hidden // 2, num_labels)``.
    The ``nn.Sequential`` layout is intentionally unchanged so existing
    checkpoints (parameters ``classifier.0.*`` and ``classifier.3.*``) load.
    """

    def __init__(self, config):
        """Build the BERT encoder and the classification head.

        Args:
            config: BERT configuration; ``hidden_size`` sizes the head's hidden
                layer and ``num_labels`` sizes its output.
        """
        super().__init__(config)
        self.bert = BertModel(config)
        hidden_size = config.hidden_size
        num_labels = config.num_labels
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size // 2, num_labels),
        )
        # transformers hook: applies the model's weight initialisation to all
        # submodules (including the freshly created head).
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,
    ):
        """Run BERT and the classification head.

        Args:
            input_ids, attention_mask, token_type_ids: Passed straight to
                ``BertModel``.
            labels: Optional targets. When given, a loss is computed as well
                (see ``_compute_loss``).
            **kwargs: Extra keyword arguments forwarded to ``BertModel``.
                A ``return_dict`` entry is ignored because this method always
                needs the dict-style output.

        Returns:
            ``logits`` when ``labels`` is None, otherwise ``(loss, logits)``.
            Loss comes first, which is what ``transformers.Trainer`` expects
            from tuple outputs.
        """
        # Reading ``pooler_output`` requires the ModelOutput form. Drop any
        # caller-supplied ``return_dict`` so it cannot collide with the
        # explicit keyword below (that collision raised TypeError).
        kwargs.pop("return_dict", None)
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
            **kwargs,
        )
        logits = self.classifier(outputs.pooler_output)
        if labels is None:
            return logits
        return self._compute_loss(logits, labels), logits

    def _compute_loss(self, logits, labels):
        """Compute the training loss of ``logits`` against ``labels``.

        Follows the problem-type inference of transformers'
        ``BertForSequenceClassification``, but does not write the inferred
        type back into ``self.config``:

        - ``num_labels == 1``: regression (MSE).
        - integer labels: single-label classification (cross-entropy).
        - otherwise: multi-label classification (BCE with logits).

        An explicit ``config.problem_type`` overrides the inference.

        Raises:
            ValueError: if ``config.problem_type`` is not one of the three
                supported values.
        """
        num_labels = self.config.num_labels
        problem_type = getattr(self.config, "problem_type", None)
        if problem_type is None:
            if num_labels == 1:
                problem_type = "regression"
            elif labels.dtype in (torch.long, torch.int):
                problem_type = "single_label_classification"
            else:
                problem_type = "multi_label_classification"

        if problem_type == "regression":
            loss_fct = nn.MSELoss()
            # MSE needs floating targets matching the logits' dtype.
            if num_labels == 1:
                return loss_fct(logits.view(-1), labels.view(-1).to(logits.dtype))
            return loss_fct(logits, labels.to(logits.dtype))
        if problem_type == "single_label_classification":
            return nn.CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
        if problem_type == "multi_label_classification":
            # BCEWithLogitsLoss requires float targets.
            return nn.BCEWithLogitsLoss()(logits, labels.to(logits.dtype))
        raise ValueError(f"Unsupported problem_type: {problem_type!r}")