| | import torch |
| | from transformers import AutoModel, PreTrainedModel |
| |
|
class ElectraForMultiLabel(PreTrainedModel):
    """ELECTRA encoder with a multi-label classification head.

    Pools the encoder's first ([CLS]) token, applies dropout, and projects
    to ``config.num_labels`` independent logits.  When ``labels`` are
    supplied, the loss is BCE-with-logits, optionally weighted per label
    via ``pos_weight``.
    """

    def __init__(self, config, pos_weight=None):
        """
        Args:
            config: model config; must provide ``num_labels`` and
                ``_name_or_path`` (used here to load the encoder).
            pos_weight: optional per-label positive-class weights for the
                BCE loss.  Accepts anything convertible to a 1-D float
                tensor (tensor, list, ndarray) of length ``num_labels``.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        # NOTE(review): loading pretrained weights inside __init__ means the
        # encoder is fetched every time the model is constructed — including
        # when a fine-tuned checkpoint is restored via from_pretrained(),
        # where the checkpoint's state dict then overwrites them anyway.
        # AutoModel.from_config(config) would avoid the double load, but it
        # also changes fresh-construction behavior — confirm intent first.
        self.enc = AutoModel.from_pretrained(config._name_or_path)
        hidden_size = self.enc.config.hidden_size
        self.dropout = torch.nn.Dropout(0.2)
        self.head = torch.nn.Linear(hidden_size, self.num_labels)

        if pos_weight is not None:
            # Buffer (not Parameter): follows .to()/.cuda() and is saved in
            # the state dict, but is not trained.  as_tensor generalizes
            # accepted inputs (list/ndarray/tensor) without copying when a
            # float tensor is already given; register_buffer itself would
            # reject a plain list.
            self.register_buffer(
                "pos_weight", torch.as_tensor(pos_weight, dtype=torch.float)
            )
        else:
            self.pos_weight = None
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """Encode, pool [CLS], and classify.

        Args:
            input_ids: token ids, shape (batch, seq_len).
            attention_mask: padding mask, shape (batch, seq_len).
            labels: optional multi-hot targets, shape (batch, num_labels).
                Integer or float dtype accepted (cast to float for BCE).

        Returns:
            dict with ``loss`` (``None`` when no labels are given) and
            ``logits`` of shape (batch, num_labels).
        """
        # First-token ([CLS]) pooling of the encoder output.
        pooled = self.enc(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0]
        logits = self.head(self.dropout(pooled))

        loss = None
        if labels is not None:
            # Fix: BCE-with-logits requires float targets; integer multi-hot
            # labels (common from dataset pipelines) previously raised a
            # dtype error.  .float() is a no-op for float inputs.  The
            # functional form also avoids rebuilding a loss module on every
            # forward pass and handles pos_weight=None uniformly.
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                logits, labels.float(), pos_weight=self.pos_weight
            )

        return {"loss": loss, "logits": logits}
| |
|