import torch
import torch.nn as nn
from transformers import PreTrainedModel, BertModel, BertConfig


class DualHeadDMModel(PreTrainedModel):
    """BERT encoder with a token-level head and a sequence-level head."""

    config_class = BertConfig
    base_model_prefix = "encoder"

    def __init__(self, config, num_token_labels=3, num_seq_labels=2, seq_loss_weight=0.5):
        # Coerce foreign config types into a BertConfig *before* handing it to
        # PreTrainedModel, so self.config and the encoder are built from the same object.
        if not isinstance(config, BertConfig):
            config = BertConfig.from_dict(config.to_dict())
        super().__init__(config)
        self.hidden_size = config.hidden_size
        self.seq_loss_weight = seq_loss_weight  # stored for loss weighting by the training loop
        self.encoder = BertModel(config)
        self.dropout = nn.Dropout(0.1)
        # Per-token logits: one label per input position.
        self.token_classifier = nn.Linear(self.hidden_size, num_token_labels)
        # Sequence-level logits from the [CLS] position.
        self.seq_classifier = nn.Linear(self.hidden_size, num_seq_labels)

    def forward(self, input_ids=None, attention_mask=None, candidate_mask=None, **kwargs):
        # candidate_mask is accepted for interface compatibility but unused here.
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        hidden = out.last_hidden_state                             # (batch, seq_len, hidden)
        # Apply dropout once per head (the original dropped the [CLS] vector twice).
        logits_tok = self.token_classifier(self.dropout(hidden))  # (batch, seq_len, num_token_labels)
        cls = hidden[:, 0, :]                                      # [CLS] representation
        logits_seq = self.seq_classifier(self.dropout(cls))       # (batch, num_seq_labels)
        return {"logits": (logits_tok, logits_seq)}
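

# --- Usage sketch (an assumption, not from the original source): a smoke-test
# forward pass with random token ids and a default-sized, randomly initialized
# BertConfig; the batch/sequence sizes below are illustrative.
if __name__ == "__main__":
    config = BertConfig()  # bert-base dimensions, random weights
    model = DualHeadDMModel(config, num_token_labels=3, num_seq_labels=2)
    model.eval()
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        logits_tok, logits_seq = model(input_ids=input_ids,
                                       attention_mask=attention_mask)["logits"]
    print(logits_tok.shape)  # torch.Size([2, 16, 3])
    print(logits_seq.shape)  # torch.Size([2, 2])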