import torch.nn as nn
from transformers import AutoModel
class BiasClassifier(nn.Module):
    """Social-bias classifier for the KoSBI dataset (15 labels).

    Encodes the input with a pretrained transformer and classifies the
    [CLS]-token representation; attention maps can optionally be returned.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        # "eager" attention so attention weights can be surfaced at inference.
        self.encoder = AutoModel.from_pretrained(model_name, attn_implementation="eager")
        self.dropout = nn.Dropout(0.2)
        self.classifier = nn.Linear(self.encoder.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, output_attentions=False, **kwargs):
        encoded = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        # Classify from the [CLS] (first) token embedding.
        cls_repr = encoded.last_hidden_state[:, 0, :]
        logits = self.classifier(self.dropout(cls_repr))
        if output_attentions:
            return logits, encoded.attentions
        return logits, None
class KMHaSClassifier(nn.Module):
    """Hate-speech classifier for the K-MHaS dataset (8 labels).

    A two-layer MLP head on top of the [CLS]-token embedding of a
    pretrained transformer encoder.
    """

    def __init__(self, model_name, num_labels=8, hidden_size=1024):
        super().__init__()
        # "eager" attention implementation, matching the other classifiers.
        self.encoder = AutoModel.from_pretrained(model_name, attn_implementation="eager")
        head = [
            nn.Linear(hidden_size, 512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_labels),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, input_ids, attention_mask, **kwargs):
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Classify from the [CLS] (first) token embedding.
        cls_repr = encoded.last_hidden_state[:, 0, :]
        return self.classifier(cls_repr)
class FramingBiasClassifier(nn.Module):
    """Binary framing-bias classifier for the BABE dataset.

    A small MLP head produces a single logit from the [CLS]-token
    embedding of a pretrained transformer encoder.
    """

    def __init__(self, model_name, hidden_size=1024):
        super().__init__()
        # "eager" attention implementation, matching the other classifiers.
        self.encoder = AutoModel.from_pretrained(model_name, attn_implementation="eager")
        head = [
            nn.Linear(hidden_size, 256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, input_ids, attention_mask, **kwargs):
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Single-logit output from the [CLS] (first) token embedding.
        return self.classifier(encoded.last_hidden_state[:, 0, :])