# koelectra-absa-steam-classifier / electra_multilabel.py
# Uploaded by Wing4 via huggingface_hub (commit 1712f62, verified)
import torch
from transformers import AutoModel, PreTrainedModel
class ElectraForMultiLabel(PreTrainedModel):
    """ELECTRA encoder with a linear head for multi-label classification.

    Pools the [CLS] token embedding, applies dropout, and projects it to
    ``config.num_labels`` logits. When ``labels`` are provided, computes a
    BCE-with-logits loss, optionally class-weighted via ``pos_weight``.
    """

    def __init__(self, config, pos_weight=None):
        """
        Args:
            config: HF model config; must carry ``num_labels`` and
                ``_name_or_path`` (used to load the base encoder).
            pos_weight: optional 1-D tensor of per-class positive weights
                passed to the BCE-with-logits loss.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        # NOTE(review): loading pretrained encoder weights in __init__ means
        # from_pretrained() on this class loads the base model twice; kept
        # as-is for checkpoint compatibility.
        self.enc = AutoModel.from_pretrained(config._name_or_path)
        hidden_size = self.enc.config.hidden_size
        self.dropout = torch.nn.Dropout(0.2)
        self.head = torch.nn.Linear(hidden_size, self.num_labels)
        # Register pos_weight as a buffer so .to(device) moves it automatically.
        if pos_weight is not None:
            self.register_buffer("pos_weight", pos_weight)
        else:
            self.pos_weight = None
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """Return a dict with ``logits`` and ``loss`` (None if no labels).

        Args:
            input_ids: token id tensor for the encoder.
            attention_mask: attention mask for the encoder.
            labels: optional multi-hot target tensor, same shape as logits;
                integer dtypes are accepted and cast to float.
        """
        # Use the [CLS] token embedding as the sequence representation.
        hidden = self.enc(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0]
        logits = self.head(self.dropout(hidden))
        loss = None
        if labels is not None:
            # Fix: BCEWithLogitsLoss requires float targets — integer label
            # tensors previously raised a runtime error. The functional form
            # accepts pos_weight=None, collapsing the duplicated branches and
            # avoiding re-instantiating the loss module every call.
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                logits, labels.float(), pos_weight=self.pos_weight
            )
        return {"loss": loss, "logits": logits}