import torch
from transformers import AutoModel, PreTrainedModel


class ElectraForMultiLabel(PreTrainedModel):
    """ELECTRA-style encoder with a multi-label classification head.

    Pools the [CLS] token embedding, applies dropout, and projects to
    ``config.num_labels`` independent logits. When ``labels`` are given,
    the loss is binary cross-entropy with logits, optionally weighted by
    ``pos_weight`` to counter per-class positive/negative imbalance.
    """

    def __init__(self, config, pos_weight=None):
        """Build the model.

        Args:
            config: HF model config; must provide ``num_labels`` and
                ``_name_or_path`` (checkpoint to load the backbone from).
            pos_weight: optional 1-D tensor of shape ``(num_labels,)``
                passed to the BCE loss — presumably per-class positive
                weights; verify shape against the caller.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        # Load the pretrained backbone named in the config.
        self.enc = AutoModel.from_pretrained(config._name_or_path)
        hid = self.enc.config.hidden_size
        self.dropout = torch.nn.Dropout(0.2)
        self.head = torch.nn.Linear(hid, self.num_labels)
        # Register pos_weight as a buffer so it follows the model across
        # devices (.to() / .cuda()) automatically.
        if pos_weight is not None:
            self.register_buffer("pos_weight", pos_weight)
        else:
            self.pos_weight = None
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """Run a forward pass.

        Args:
            input_ids: token id tensor, ``(batch, seq_len)``.
            attention_mask: padding mask, same shape as ``input_ids``.
            labels: optional multi-hot target tensor of shape
                ``(batch, num_labels)``; integer or float dtype accepted.

        Returns:
            dict with ``"logits"`` of shape ``(batch, num_labels)`` and
            ``"loss"`` (``None`` when no labels were supplied).
        """
        # Use the [CLS] token embedding as the sequence representation.
        h = self.enc(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0]
        h = self.dropout(h)
        logits = self.head(h)

        loss = None
        if labels is not None:
            # BCE-with-logits requires float targets; HF datasets often
            # provide integer multi-hot labels, so cast defensively.
            # The functional form accepts pos_weight=None, so no branch
            # is needed, and no loss module is allocated per step.
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                logits, labels.float(), pos_weight=self.pos_weight
            )
        return {"loss": loss, "logits": logits}