# Source: hoguseki — model.py ("Update model.py", commit 4c89831, verified)
import torch.nn as nn
from transformers import AutoModel
class BiasClassifier(nn.Module):
    """Classifier for KoSBI (15 labels).

    Encodes the input with a pretrained transformer, then classifies the
    [CLS] token representation through dropout and a linear head. The
    encoder's attention maps can optionally be returned alongside logits.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        # "eager" attention so that attention weights can be surfaced.
        self.encoder = AutoModel.from_pretrained(model_name, attn_implementation="eager")
        self.dropout = nn.Dropout(0.2)
        self.classifier = nn.Linear(self.encoder.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, output_attentions=False, **kwargs):
        """Return (logits, attentions); attentions is None unless requested."""
        encoded = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        # Pool via the [CLS] token, then dropout -> linear head.
        cls_repr = encoded.last_hidden_state[:, 0, :]
        logits = self.classifier(self.dropout(cls_repr))
        if output_attentions:
            return logits, encoded.attentions
        return logits, None
class KMHaSClassifier(nn.Module):
    """Classifier for K-MHaS (8 labels).

    Encodes the input with a pretrained transformer and classifies the
    [CLS] token representation through a two-layer MLP head
    (hidden -> 512 -> num_labels).

    Args:
        model_name: Hugging Face model id or path for the encoder.
        num_labels: Number of output labels (default 8).
        hidden_size: Width of the encoder output feeding the head.
            If None (default), it is read from the encoder config.
            Previously this was hard-coded to 1024, which silently broke
            with encoders of any other width; pass 1024 explicitly to
            pin the old behavior.
    """

    def __init__(self, model_name, num_labels=8, hidden_size=None):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name, attn_implementation="eager")
        # Derive the head's input width from the encoder instead of
        # assuming 1024, so any encoder width works out of the box.
        if hidden_size is None:
            hidden_size = self.encoder.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_labels),
        )

    def forward(self, input_ids, attention_mask, **kwargs):
        """Return raw logits of shape (batch, num_labels)."""
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Classify on the [CLS] token representation.
        return self.classifier(out.last_hidden_state[:, 0, :])
class FramingBiasClassifier(nn.Module):
    """Binary classifier for BABE (framing bias).

    Encodes the input with a pretrained transformer and scores the
    [CLS] token representation through a two-layer MLP head
    (hidden -> 256 -> 1), producing a single logit per example.

    Args:
        model_name: Hugging Face model id or path for the encoder.
        hidden_size: Width of the encoder output feeding the head.
            If None (default), it is read from the encoder config.
            Previously this was hard-coded to 1024, which silently broke
            with encoders of any other width; pass 1024 explicitly to
            pin the old behavior.
    """

    def __init__(self, model_name, hidden_size=None):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name, attn_implementation="eager")
        # Derive the head's input width from the encoder instead of
        # assuming 1024, so any encoder width works out of the box.
        if hidden_size is None:
            hidden_size = self.encoder.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
        )

    def forward(self, input_ids, attention_mask, **kwargs):
        """Return raw logits of shape (batch, 1)."""
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Score the [CLS] token representation.
        return self.classifier(out.last_hidden_state[:, 0, :])