import torch
import torch.nn as nn
from transformers import AutoModel
class ToxicClassifier(nn.Module):
    """Multi-label toxicity classifier: a partially frozen BERT encoder
    with a single linear head over the pooled [CLS] representation."""

    def __init__(self, num_classes: int = 6, dropout: float = 0.3):
        super().__init__()
        # BERT base model - freeze the lower layers to prevent overfitting
        self.bert = AutoModel.from_pretrained('bert-base-uncased')
        # Freeze the embeddings and the first 8 encoder layers of BERT
        for param in self.bert.embeddings.parameters():
            param.requires_grad = False
        for layer in self.bert.encoder.layer[:8]:
            for param in layer.parameters():
                param.requires_grad = False
        # Simplified architecture focusing on BERT's power: dropout + one linear head
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)  # 768 for bert-base
        # Initialize the classifier weights properly
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Get BERT embeddings
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output  # [batch_size, hidden_size]
        # Apply dropout and classification
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits  # Return logits directly; BCEWithLogitsLoss applies the sigmoid
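
# --- Usage sketch (illustrative; not part of the original file). Shows how
# the classifier pairs with BCEWithLogitsLoss for multi-label training and
# that only the unfrozen parameters receive gradients. The example texts,
# placeholder labels, and max_length below are assumptions.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    model = ToxicClassifier()
    criterion = nn.BCEWithLogitsLoss()  # applies the sigmoid internally

    # Confirm the freezing took effect: far fewer trainable than total params
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable params: {trainable:,} / {total:,}")

    # Hypothetical mini-batch: two texts with 6 binary toxicity labels each
    batch = tokenizer(
        ["you are wonderful", "some toxic comment"],
        padding=True, truncation=True, max_length=128, return_tensors='pt',
    )
    labels = torch.zeros(2, 6)  # placeholder multi-label targets

    logits = model(batch['input_ids'], batch['attention_mask'])  # [2, 6]
    loss = criterion(logits, labels)
    loss.backward()  # gradients flow only to the unfrozen layers and the head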