# Repository: RussianDMRecognizer_dual — file: modeling_dual_head_dm.py
# Last commit by MariaOls (434d57d, verified):
#   Train & push dual-head DM (final model only) + README
import torch, torch.nn as nn
from transformers import PreTrainedModel, BertModel, BertConfig
class DualHeadDMModel(PreTrainedModel):
    """BERT encoder with two linear heads that share one backbone.

    * a token-level head applied to every position of the encoder output;
    * a sequence-level head applied to the encoder output at position 0
      (the ``[CLS]`` slot under standard BERT tokenization).

    No loss is computed here; ``forward`` only returns the two logit tensors.

    NOTE(review): the head sizes come from constructor arguments, not from
    ``config``. Code that re-creates the model must pass the same
    ``num_token_labels`` / ``num_seq_labels`` it was trained with — confirm
    against the loading code.
    """

    config_class = BertConfig
    base_model_prefix = "encoder"

    def __init__(
        self,
        config,
        num_token_labels: int = 3,
        num_seq_labels: int = 2,
        seq_loss_weight: float = 0.5,
    ):
        """Build the encoder and both classification heads.

        Args:
            config: Model configuration. If it is not a ``BertConfig`` it is
                converted through ``BertConfig.from_dict(config.to_dict())``
                before the encoder is built.
            num_token_labels: Output size of the token-level head.
            num_seq_labels: Output size of the sequence-level head.
            seq_loss_weight: Stored as ``self.seq_loss_weight``. Nothing in
                this class reads it; presumably the training loop uses it to
                weight the sequence loss — TODO confirm.
        """
        super().__init__(config)
        # The encoder needs BERT-specific fields, so coerce foreign configs.
        # ``self.config`` (set by the base class above) keeps the original.
        if not isinstance(config, BertConfig):
            config = BertConfig.from_dict(config.to_dict())

        self.hidden_size = config.hidden_size
        # Previously accepted and silently dropped; keep it so loss code
        # outside this class can retrieve the configured weight.
        self.seq_loss_weight = seq_loss_weight

        self.encoder = BertModel(config)
        self.dropout = nn.Dropout(0.1)
        self.token_classifier = nn.Linear(self.hidden_size, num_token_labels)
        self.seq_classifier = nn.Linear(self.hidden_size, num_seq_labels)

    def forward(self, input_ids=None, attention_mask=None, candidate_mask=None, **kwargs):
        """Run the encoder and both heads.

        Args:
            input_ids: Token ids forwarded to the BERT encoder.
            attention_mask: Attention mask forwarded to the BERT encoder.
            candidate_mask: Accepted for interface compatibility; not used
                in this method.
            **kwargs: Accepted and ignored, so extra batch fields do not
                raise. Nothing here is forwarded to the encoder.

        Returns:
            ``{"logits": (logits_tok, logits_seq)}`` where ``logits_tok`` has
            one score vector per input position (last dimension
            ``num_token_labels``) and ``logits_seq`` has one score vector per
            sequence (last dimension ``num_seq_labels``).
        """
        encoded = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden = encoded.last_hidden_state

        logits_tok = self.token_classifier(self.dropout(hidden))

        # Slice position 0 from the raw encoder output, not from the
        # already-dropped-out tensor: otherwise the sequence features get
        # dropout applied twice in training mode. Eval output is unchanged.
        first_token = hidden[:, 0, :]
        logits_seq = self.seq_classifier(self.dropout(first_token))

        return {"logits": (logits_tok, logits_seq)}