```python
import torch
import torch.nn as nn
from transformers import PreTrainedModel, BertModel, BertConfig


class DualHeadDMModel(PreTrainedModel):
    config_class = BertConfig
    base_model_prefix = "encoder"

    def __init__(self, config, num_token_labels=3, num_seq_labels=2, seq_loss_weight=0.5):
        # Normalize the config *before* the parent constructor stores it,
        # so self.config and the encoder see the same BertConfig.
        if not isinstance(config, BertConfig):
            config = BertConfig.from_dict(config.to_dict())
        super().__init__(config)
        self.hidden_size = config.hidden_size
        # Stored so the training loop can weight the sequence-level loss;
        # this module itself only produces logits.
        self.seq_loss_weight = seq_loss_weight
        self.encoder = BertModel(config)
        self.dropout = nn.Dropout(0.1)
        # Two heads over one shared encoder: per-token tagging and
        # whole-sequence classification.
        self.token_classifier = nn.Linear(self.hidden_size, num_token_labels)
        self.seq_classifier = nn.Linear(self.hidden_size, num_seq_labels)
        # Standard PreTrainedModel weight initialization for the fresh heads.
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, candidate_mask=None, **kwargs):
        # candidate_mask is accepted here but not consumed; it is passed
        # through for the loss computation, which happens outside forward.
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        H = self.dropout(out.last_hidden_state)              # (batch, seq_len, hidden)
        logits_tok = self.token_classifier(H)                # (batch, seq_len, num_token_labels)
        cls = H[:, 0, :]                                     # [CLS] representation
        logits_seq = self.seq_classifier(self.dropout(cls))  # (batch, num_seq_labels)
        return {"logits": (logits_tok, logits_seq)}
```