import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from transformers import AutoModel, AutoConfig


class PredicateAwareSRL(nn.Module):
    """Predicate-aware BiLSTM tagger for semantic role labeling on top of a BERT encoder."""

    def __init__(self,
                 bert_name: str,
                 num_labels: int,
                 use_indicator: bool = True,
                 indicator_dim: int = 10,
                 lstm_hidden: int = 768,
                 mlp_hidden: int = 300,
                 dropout: float = 0.1,
                 use_distance: bool = True,
                 pos_dim: int = 50,
                 max_distance: int = 128):
        super().__init__()
        self.config = AutoConfig.from_pretrained(bert_name)
        self.bert = AutoModel.from_pretrained(bert_name)

        self.use_indicator = use_indicator

        bert_dim = self.config.hidden_size
        in_dim = bert_dim + (indicator_dim if use_indicator else 0)

        # Binary embedding marking whether a word is the predicate.
        if use_indicator:
            self.indicator_emb = nn.Embedding(2, indicator_dim)

        self.use_distance = use_distance
        self.max_distance = max_distance
        if use_distance:
            # Signed word distance to the predicate, clamped to [-max_distance, max_distance]
            # and shifted into [0, 2 * max_distance].
            self.pos_emb = nn.Embedding(2 * max_distance + 1, pos_dim)
            in_dim += pos_dim

        # Word-level BiLSTM; the two directions together give lstm_hidden units per word.
        self.bilstm = nn.LSTM(
            input_size=in_dim,
            hidden_size=lstm_hidden // 2,
            num_layers=1,
            dropout=0.0,
            bidirectional=True,
            batch_first=True
        )

        self.dropout = nn.Dropout(dropout)

        # Scores each word from its own BiLSTM state concatenated with the predicate state.
        self.classifier = nn.Sequential(
            nn.Linear(lstm_hidden * 2, mlp_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, num_labels)
        )

        # Label id used for padded positions; CrossEntropyLoss ignores it.
        self.pad_label_id = -100

    def forward(self,
                input_ids: torch.Tensor,
                token_type_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                word_first_wp_fullidx: torch.Tensor,
                sentence_mask: torch.Tensor,
                sent_lens: torch.Tensor,
                pred_word_idx: torch.Tensor,
                indicator: torch.Tensor | None = None,
                labels: torch.Tensor | None = None):
        B, L = input_ids.size()
        device = input_ids.device

        # Contextual wordpiece representations from the pretrained encoder.
        bert_out = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask
        )
        H = bert_out.last_hidden_state

        # Reduce wordpieces to words by selecting the first wordpiece of each word.
        # Negative indices mark padded word slots; they are redirected to position 0
        # and then zeroed out by the sentence mask.
        gather_idx = word_first_wp_fullidx.clone()
        gather_idx[gather_idx < 0] = 0
        gather_idx = gather_idx.unsqueeze(-1).expand(-1, -1, H.size(-1))
        H_words = torch.gather(H, dim=1, index=gather_idx)
        H_words = H_words * sentence_mask.unsqueeze(-1)

        # Optionally append the predicate-indicator embedding to every word.
        X = H_words
        if self.use_indicator and indicator is not None:
            ind_emb = self.indicator_emb(indicator.clamp(0, 1))
            X = torch.cat([X, ind_emb], dim=-1)

        # Optionally append an embedding of the signed distance to the predicate.
        if self.use_distance:
            max_n = X.size(1)
            positions = torch.arange(max_n, device=device).unsqueeze(0).expand(B, -1)
            rel = positions - pred_word_idx.unsqueeze(1)
            rel = rel.clamp(-self.max_distance, self.max_distance) + self.max_distance
            pos_feats = self.pos_emb(rel)
            X = torch.cat([X, pos_feats], dim=-1)

        # Encode with the BiLSTM over the true sentence lengths; total_length keeps
        # the padded output aligned with the word-level label tensor.
        lengths = sent_lens.detach().cpu()
        packed = pack_padded_sequence(X, lengths=lengths, batch_first=True, enforce_sorted=False)
        G_packed, _ = self.bilstm(packed)
        G, _ = pad_packed_sequence(G_packed, batch_first=True, total_length=X.size(1))
        G = self.dropout(G)

        # Broadcast the predicate's BiLSTM state across the sequence.
        batch_idx = torch.arange(B, device=device)
        gp = G[batch_idx, pred_word_idx.clamp(min=0), :]
        gp_expanded = gp.unsqueeze(1).expand(-1, G.size(1), -1)

        logits = self.classifier(torch.cat([G, gp_expanded], dim=-1))

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.pad_label_id)
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return logits, loss
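

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the model). It assumes a standard
# BERT-style checkpoint such as "bert-base-uncased" and fabricates random
# inputs with plausible shapes; the word-to-wordpiece alignment below is an
# illustrative stand-in for the real preprocessing, not its actual output.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = PredicateAwareSRL(bert_name="bert-base-uncased", num_labels=10)
    model.eval()

    B, L_wp, n_words = 2, 16, 8   # batch size, wordpiece length, word length
    input_ids = torch.randint(1000, 2000, (B, L_wp))
    token_type_ids = torch.zeros(B, L_wp, dtype=torch.long)
    attention_mask = torch.ones(B, L_wp, dtype=torch.long)

    # Pretend every word starts at every second wordpiece position.
    word_first_wp_fullidx = torch.arange(1, 1 + 2 * n_words, 2).unsqueeze(0).expand(B, -1)
    sentence_mask = torch.ones(B, n_words)
    sent_lens = torch.full((B,), n_words, dtype=torch.long)
    pred_word_idx = torch.tensor([2, 5])
    indicator = torch.zeros(B, n_words, dtype=torch.long)
    indicator[torch.arange(B), pred_word_idx] = 1
    labels = torch.randint(0, 10, (B, n_words))

    with torch.no_grad():
        logits, loss = model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            word_first_wp_fullidx=word_first_wp_fullidx,
            sentence_mask=sentence_mask,
            sent_lens=sent_lens,
            pred_word_idx=pred_word_idx,
            indicator=indicator,
            labels=labels,
        )
    print("logits:", tuple(logits.shape), "loss:", float(loss))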