Upload SRL_BERT_model.py
Browse files- SRL_BERT_model.py +140 -0
SRL_BERT_model.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
| 4 |
+
from transformers import AutoModel, AutoConfig
|
| 5 |
+
|
| 6 |
+
class PredicateAwareSRL(nn.Module):
|
| 7 |
+
def __init__(self,
|
| 8 |
+
bert_name: str,
|
| 9 |
+
num_labels: int,
|
| 10 |
+
use_indicator: bool = True,
|
| 11 |
+
indicator_dim: int = 10, # CHANGED: 10-dim predicate indicator
|
| 12 |
+
lstm_hidden: int = 768, # CHANGED: BiLSTM hidden size = 768 (bidirectional)
|
| 13 |
+
mlp_hidden: int = 300, # CHANGED: MLP hidden size = 300
|
| 14 |
+
dropout: float = 0.1,
|
| 15 |
+
use_distance: bool = True, # NEW: enable relative position (distance) embeddings
|
| 16 |
+
pos_dim: int = 50, # NEW: size of position embedding (random init, trainable)
|
| 17 |
+
max_distance: int = 128): # NEW: clamp |i - p| to this range for bucketing
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.config = AutoConfig.from_pretrained(bert_name)
|
| 20 |
+
self.bert = AutoModel.from_pretrained(bert_name)
|
| 21 |
+
self.use_indicator = use_indicator
|
| 22 |
+
|
| 23 |
+
# --- Input dim to BiLSTM = BERT_dim + (indicator_dim) + (pos_dim)
|
| 24 |
+
bert_dim = self.config.hidden_size
|
| 25 |
+
in_dim = bert_dim + (indicator_dim if use_indicator else 0)
|
| 26 |
+
|
| 27 |
+
# Two rows which indicate 0 not predicate 1 is predicate, so need to 2 embedding (rows)
|
| 28 |
+
# num_embeddings (int) – size of the dictionary of embeddings
|
| 29 |
+
# embedding_dim (int) – the size of each embedding vector
|
| 30 |
+
|
| 31 |
+
if use_indicator:
|
| 32 |
+
self.indicator_emb = nn.Embedding(2, indicator_dim) # 0/1 → 10-dim (random init, trainable) # CHANGED
|
| 33 |
+
|
| 34 |
+
self.use_distance = use_distance # NEW
|
| 35 |
+
self.max_distance = max_distance # NEW
|
| 36 |
+
if use_distance:
|
| 37 |
+
# Distance buckets: [-max_distance .. +max_distance] → indices [0 .. 2*max_distance]
|
| 38 |
+
self.pos_emb = nn.Embedding(2 * max_distance + 1, pos_dim) # NEW (random init, trainable)
|
| 39 |
+
in_dim += pos_dim # NEW
|
| 40 |
+
|
| 41 |
+
# BiLSTM (bidirectional): total output dim = lstm_hidden
|
| 42 |
+
self.bilstm = nn.LSTM(
|
| 43 |
+
input_size=in_dim,
|
| 44 |
+
hidden_size=lstm_hidden // 2, # bi → half per direction
|
| 45 |
+
num_layers=1,
|
| 46 |
+
dropout=0.0,
|
| 47 |
+
bidirectional=True,
|
| 48 |
+
batch_first=True
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
self.dropout = nn.Dropout(dropout)
|
| 52 |
+
|
| 53 |
+
# Classifier: concat(g_i, gp) so input dim = 2 * lstm_hidden
|
| 54 |
+
self.classifier = nn.Sequential(
|
| 55 |
+
nn.Linear(lstm_hidden * 2, mlp_hidden), # CHANGED (mlp_hidden=300)
|
| 56 |
+
nn.ReLU(),
|
| 57 |
+
nn.Dropout(dropout),
|
| 58 |
+
nn.Linear(mlp_hidden, num_labels)
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
self.pad_label_id = -100
|
| 62 |
+
|
| 63 |
+
def forward(self,
|
| 64 |
+
input_ids: torch.Tensor, # [B, L]
|
| 65 |
+
token_type_ids: torch.Tensor, # [B, L]
|
| 66 |
+
attention_mask: torch.Tensor, # [B, L]
|
| 67 |
+
word_first_wp_fullidx: torch.Tensor, # [B, max_n] (positions in full seq; -1 for pad)
|
| 68 |
+
sentence_mask: torch.Tensor, # [B, max_n] (bool)
|
| 69 |
+
sent_lens: torch.Tensor, # [B]
|
| 70 |
+
pred_word_idx: torch.Tensor, # [B]
|
| 71 |
+
indicator: torch.Tensor | None = None, # [B, max_n] 0/1
|
| 72 |
+
labels: torch.Tensor | None = None): # [B, max_n]
|
| 73 |
+
|
| 74 |
+
B, L = input_ids.size()
|
| 75 |
+
device = input_ids.device
|
| 76 |
+
|
| 77 |
+
# ---- BERT encoder
|
| 78 |
+
bert_out = self.bert(
|
| 79 |
+
input_ids=input_ids,
|
| 80 |
+
token_type_ids=token_type_ids,
|
| 81 |
+
attention_mask=attention_mask
|
| 82 |
+
)
|
| 83 |
+
H = bert_out.last_hidden_state # [B, L, D]
|
| 84 |
+
|
| 85 |
+
# ---- Subword → word pooling (first subword)
|
| 86 |
+
|
| 87 |
+
# Gather sentence word-level representations by taking FIRST subtoken hidden per word
|
| 88 |
+
# Prepare indices (replace -1 with 0 to avoid gather OOB; we'll mask later)
|
| 89 |
+
# This process is required to feed word level to predict BIO and role per word
|
| 90 |
+
#.clone is for deep copy won't change original data
|
| 91 |
+
|
| 92 |
+
gather_idx = word_first_wp_fullidx.clone()
|
| 93 |
+
gather_idx[gather_idx < 0] = 0
|
| 94 |
+
gather_idx = gather_idx.unsqueeze(-1).expand(-1, -1, H.size(-1)) # [B, max_n, D]
|
| 95 |
+
H_words = torch.gather(H, dim=1, index=gather_idx) # [B, max_n, D]
|
| 96 |
+
H_words = H_words * sentence_mask.unsqueeze(-1) # zero out pads
|
| 97 |
+
|
| 98 |
+
# ---- Concatenate predicate indicator (0/1 → emb)
|
| 99 |
+
# word_first_wp_fullidx: [1, 2, 3, -1, -1]
|
| 100 |
+
# gather_idx after clamp: [1, 2, 3, 0, 0] # 0 points to [CLS], just a placeholder
|
| 101 |
+
# H_words = gather(H, ...) # grabs vectors at positions 1,2,3,0,0
|
| 102 |
+
# sentence_mask: [1, 1, 1, 0, 0]
|
| 103 |
+
# H_words *= mask → [vec1, vec2, vec3, 0, 0] # padded slots zeroed out
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
X = H_words
|
| 107 |
+
if self.use_indicator and indicator is not None:
|
| 108 |
+
ind_emb = self.indicator_emb(indicator.clamp(0, 1)) # [B, max_n, 10] # CHANGED
|
| 109 |
+
X = torch.cat([X, ind_emb], dim=-1)
|
| 110 |
+
|
| 111 |
+
# ---- NEW: Relative position (distance-to-predicate) embeddings
|
| 112 |
+
if self.use_distance:
|
| 113 |
+
# positions: 0..max_n-1 per sentence
|
| 114 |
+
max_n = X.size(1)
|
| 115 |
+
positions = torch.arange(max_n, device=device).unsqueeze(0).expand(B, -1) # [B, max_n]
|
| 116 |
+
rel = positions - pred_word_idx.unsqueeze(1) # [B, max_n], can be <0
|
| 117 |
+
rel = rel.clamp(-self.max_distance, self.max_distance) + self.max_distance # shift to [0 .. 2*max_distance]
|
| 118 |
+
pos_feats = self.pos_emb(rel) # [B, max_n, pos_dim] # NEW
|
| 119 |
+
X = torch.cat([X, pos_feats], dim=-1) # [B, max_n, in_dim] # NEW
|
| 120 |
+
|
| 121 |
+
# ---- BiLSTM (packed)
|
| 122 |
+
lengths = sent_lens.detach().cpu()
|
| 123 |
+
packed = pack_padded_sequence(X, lengths=lengths, batch_first=True, enforce_sorted=False)
|
| 124 |
+
G_packed, _ = self.bilstm(packed)
|
| 125 |
+
G, _ = pad_packed_sequence(G_packed, batch_first=True) # [B, max_n, lstm_hidden]
|
| 126 |
+
G = self.dropout(G)
|
| 127 |
+
|
| 128 |
+
# ---- Predicate hidden (word-level) and concat to every position
|
| 129 |
+
batch_idx = torch.arange(B, device=device)
|
| 130 |
+
gp = G[batch_idx, pred_word_idx.clamp(min=0), :] # [B, lstm_hidden]
|
| 131 |
+
gp_expanded = gp.unsqueeze(1).expand(-1, G.size(1), -1) # [B, max_n, lstm_hidden]
|
| 132 |
+
|
| 133 |
+
logits = self.classifier(torch.cat([G, gp_expanded], dim=-1)) # [B, max_n, num_labels]
|
| 134 |
+
|
| 135 |
+
loss = None
|
| 136 |
+
if labels is not None:
|
| 137 |
+
loss_fct = nn.CrossEntropyLoss(ignore_index=self.pad_label_id)
|
| 138 |
+
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
|
| 139 |
+
|
| 140 |
+
return logits, loss
|