"""Classification heads on top of a Hugging Face ``AutoModel`` encoder.

Every model pools the encoder output by taking the hidden state at sequence
position 0 and passes it to a task-specific head:

* ``BiasClassifier``        -- KoSBI (15-label setup): dropout + one linear layer.
* ``KMHaSClassifier``       -- K-MHaS (8 labels): two-layer MLP head.
* ``FramingBiasClassifier`` -- BABE (binary): two-layer MLP head, one logit.
"""
import torch.nn as nn
from transformers import AutoModel


def _resolve_hidden_size(encoder, hidden_size=None):
    """Return the input width to use for a classification head.

    Args:
        encoder: A loaded ``AutoModel``; its ``config.hidden_size`` is consulted.
        hidden_size: Explicit width, or None to read it from the encoder config.

    Returns:
        The hidden size as an int.

    Raises:
        ValueError: If ``hidden_size`` is None and the config has no
            ``hidden_size`` attribute, or if an explicit ``hidden_size``
            disagrees with the config value.
    """
    config_size = getattr(encoder.config, "hidden_size", None)
    if hidden_size is None:
        if config_size is None:
            raise ValueError(
                "Cannot infer hidden_size: encoder.config has no 'hidden_size'; "
                "pass hidden_size explicitly."
            )
        return config_size
    if config_size is not None and config_size != hidden_size:
        # For standard encoders the last hidden state is config.hidden_size
        # wide, so a mismatched head would only fail later inside forward().
        raise ValueError(
            f"hidden_size={hidden_size} does not match the encoder's "
            f"config.hidden_size={config_size}."
        )
    return hidden_size


class BiasClassifier(nn.Module):
    """KoSBI-specific classifier (15-label setup).

    Applies dropout and a single linear layer to the encoder's position-0
    hidden state, and can optionally return the encoder's attention maps.

    Args:
        model_name: Name or path passed to ``AutoModel.from_pretrained``.
        num_labels: Number of output logits.
        dropout: Dropout probability applied before the linear layer.
    """

    def __init__(self, model_name, num_labels, dropout=0.2):
        super().__init__()
        # NOTE(review): "eager" attention is presumably requested so that
        # output_attentions=True can return attention weights (fused attention
        # kernels may not support it) -- confirm against the transformers version.
        self.encoder = AutoModel.from_pretrained(model_name, attn_implementation="eager")
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(_resolve_hidden_size(self.encoder), num_labels)

    def forward(self, input_ids, attention_mask, output_attentions=False, **kwargs):
        """Compute classification logits and, optionally, attentions.

        Extra keyword arguments (e.g. other fields of a batch dict) are
        accepted and ignored.

        Returns:
            A tuple ``(logits, attentions)``:

            * ``logits`` is ``(batch, num_labels)``, assuming
              ``last_hidden_state`` is ``(batch, seq, hidden)``.
            * ``attentions`` is the encoder's ``out.attentions`` when
              ``output_attentions`` is True, else None.
        """
        out = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        pooled = out.last_hidden_state[:, 0, :]  # hidden state at position 0
        logits = self.classifier(self.dropout(pooled))
        attentions = out.attentions if output_attentions else None
        return logits, attentions


class KMHaSClassifier(nn.Module):
    """K-MHaS-specific classifier (8 labels by default).

    Feeds the encoder's position-0 hidden state through an MLP head:
    Linear(hidden, 512) -> GELU -> Dropout -> Linear(512, num_labels).

    Args:
        model_name: Name or path passed to ``AutoModel.from_pretrained``.
        num_labels: Number of output logits.
        hidden_size: Encoder output width; None (default) reads it from
            ``encoder.config.hidden_size``. An explicit value must match it.
        dropout: Dropout probability inside the MLP head.
    """

    def __init__(self, model_name, num_labels=8, hidden_size=None, dropout=0.1):
        super().__init__()
        # NOTE(review): eager attention is requested here too, although this
        # model never asks the encoder for attention weights.
        self.encoder = AutoModel.from_pretrained(model_name, attn_implementation="eager")
        in_features = _resolve_hidden_size(self.encoder, hidden_size)
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_labels),
        )

    def forward(self, input_ids, attention_mask, **kwargs):
        """Return logits of shape ``(batch, num_labels)``; extra kwargs are ignored."""
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(out.last_hidden_state[:, 0, :])


class FramingBiasClassifier(nn.Module):
    """BABE-specific binary classifier producing a single logit per example.

    Feeds the encoder's position-0 hidden state through an MLP head:
    Linear(hidden, 256) -> GELU -> Dropout -> Linear(256, 1).

    Args:
        model_name: Name or path passed to ``AutoModel.from_pretrained``.
        hidden_size: Encoder output width; None (default) reads it from
            ``encoder.config.hidden_size``. An explicit value must match it.
        dropout: Dropout probability inside the MLP head.
    """

    def __init__(self, model_name, hidden_size=None, dropout=0.1):
        super().__init__()
        # NOTE(review): eager attention is requested here too, although this
        # model never asks the encoder for attention weights.
        self.encoder = AutoModel.from_pretrained(model_name, attn_implementation="eager")
        in_features = _resolve_hidden_size(self.encoder, hidden_size)
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 1),
        )

    def forward(self, input_ids, attention_mask, **kwargs):
        """Return raw logits of shape ``(batch, 1)``; extra kwargs are ignored.

        No sigmoid is applied here -- presumably the caller uses a
        logits-based loss such as ``BCEWithLogitsLoss`` (confirm).
        """
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(out.last_hidden_state[:, 0, :])