import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class MyTinyBERT(nn.Module):
    """Frozen pretrained sequence classifier followed by a trainable MLP head.

    The backbone's classification logits are L2-normalized along dim 1 and fed
    through a small feed-forward head whose final layer has a single output.

    Args:
        model_name: Model id or local path passed to
            ``AutoModelForSequenceClassification.from_pretrained``.
            Defaults to ``'cointegrated/rubert-tiny-toxicity'``.
    """

    def __init__(self, model_name='cointegrated/rubert-tiny-toxicity'):
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Freeze the backbone so that only the head below receives gradients.
        for param in self.bert.parameters():
            param.requires_grad = False
        # The head's input width must equal the number of logits the backbone
        # emits; read it from the config instead of hard-coding it so the head
        # stays valid for any classification checkpoint.
        num_labels = self.bert.config.num_labels
        # Layer order/indices are kept stable (linear.0, linear.3, linear.6)
        # so previously saved state_dicts still load.
        self.linear = nn.Sequential(
            nn.Linear(num_labels, 256),
            nn.Sigmoid(),
            nn.Dropout(),
            nn.Linear(256, 512),
            nn.Sigmoid(),
            nn.Dropout(p=0.4),
            nn.Linear(512, 1),
        )

    def forward(self, x):
        """Run the frozen backbone and the trainable head.

        Args:
            x: An indexable pair; ``x[0]`` is passed as the backbone's first
                positional argument and ``x[1]`` as ``attention_mask``.
                Presumably tokenizer ``input_ids`` / ``attention_mask`` tensors
                of shape (batch, seq_len) -- TODO confirm with callers.

        Returns:
            The head's output, whose last dimension is 1.
        """
        input_ids, attention_mask = x[0], x[1]
        # NOTE(review): the frozen backbone still follows self.train()/eval(),
        # so any dropout inside it is active in train mode -- confirm intended.
        bert_out = self.bert(input_ids, attention_mask=attention_mask)
        # L2-normalize each row of logits (same as the default dim=1).
        normed_bert_out = nn.functional.normalize(bert_out.logits, dim=1)
        return self.linear(normed_bert_out)