yeomtong committed on
Commit
c9edf6c
·
verified ·
1 Parent(s): fdbdd29

Upload SRL_BERT_model.py

Browse files
Files changed (1) hide show
  1. SRL_BERT_model.py +140 -0
SRL_BERT_model.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
4
+ from transformers import AutoModel, AutoConfig
5
+
6
+ class PredicateAwareSRL(nn.Module):
7
+ def __init__(self,
8
+ bert_name: str,
9
+ num_labels: int,
10
+ use_indicator: bool = True,
11
+ indicator_dim: int = 10, # CHANGED: 10-dim predicate indicator
12
+ lstm_hidden: int = 768, # CHANGED: BiLSTM hidden size = 768 (bidirectional)
13
+ mlp_hidden: int = 300, # CHANGED: MLP hidden size = 300
14
+ dropout: float = 0.1,
15
+ use_distance: bool = True, # NEW: enable relative position (distance) embeddings
16
+ pos_dim: int = 50, # NEW: size of position embedding (random init, trainable)
17
+ max_distance: int = 128): # NEW: clamp |i - p| to this range for bucketing
18
+ super().__init__()
19
+ self.config = AutoConfig.from_pretrained(bert_name)
20
+ self.bert = AutoModel.from_pretrained(bert_name)
21
+ self.use_indicator = use_indicator
22
+
23
+ # --- Input dim to BiLSTM = BERT_dim + (indicator_dim) + (pos_dim)
24
+ bert_dim = self.config.hidden_size
25
+ in_dim = bert_dim + (indicator_dim if use_indicator else 0)
26
+
27
+ # Two rows which indicate 0 not predicate 1 is predicate, so need to 2 embedding (rows)
28
+ # num_embeddings (int) – size of the dictionary of embeddings
29
+ # embedding_dim (int) – the size of each embedding vector
30
+
31
+ if use_indicator:
32
+ self.indicator_emb = nn.Embedding(2, indicator_dim) # 0/1 → 10-dim (random init, trainable) # CHANGED
33
+
34
+ self.use_distance = use_distance # NEW
35
+ self.max_distance = max_distance # NEW
36
+ if use_distance:
37
+ # Distance buckets: [-max_distance .. +max_distance] → indices [0 .. 2*max_distance]
38
+ self.pos_emb = nn.Embedding(2 * max_distance + 1, pos_dim) # NEW (random init, trainable)
39
+ in_dim += pos_dim # NEW
40
+
41
+ # BiLSTM (bidirectional): total output dim = lstm_hidden
42
+ self.bilstm = nn.LSTM(
43
+ input_size=in_dim,
44
+ hidden_size=lstm_hidden // 2, # bi → half per direction
45
+ num_layers=1,
46
+ dropout=0.0,
47
+ bidirectional=True,
48
+ batch_first=True
49
+ )
50
+
51
+ self.dropout = nn.Dropout(dropout)
52
+
53
+ # Classifier: concat(g_i, gp) so input dim = 2 * lstm_hidden
54
+ self.classifier = nn.Sequential(
55
+ nn.Linear(lstm_hidden * 2, mlp_hidden), # CHANGED (mlp_hidden=300)
56
+ nn.ReLU(),
57
+ nn.Dropout(dropout),
58
+ nn.Linear(mlp_hidden, num_labels)
59
+ )
60
+
61
+ self.pad_label_id = -100
62
+
63
+ def forward(self,
64
+ input_ids: torch.Tensor, # [B, L]
65
+ token_type_ids: torch.Tensor, # [B, L]
66
+ attention_mask: torch.Tensor, # [B, L]
67
+ word_first_wp_fullidx: torch.Tensor, # [B, max_n] (positions in full seq; -1 for pad)
68
+ sentence_mask: torch.Tensor, # [B, max_n] (bool)
69
+ sent_lens: torch.Tensor, # [B]
70
+ pred_word_idx: torch.Tensor, # [B]
71
+ indicator: torch.Tensor | None = None, # [B, max_n] 0/1
72
+ labels: torch.Tensor | None = None): # [B, max_n]
73
+
74
+ B, L = input_ids.size()
75
+ device = input_ids.device
76
+
77
+ # ---- BERT encoder
78
+ bert_out = self.bert(
79
+ input_ids=input_ids,
80
+ token_type_ids=token_type_ids,
81
+ attention_mask=attention_mask
82
+ )
83
+ H = bert_out.last_hidden_state # [B, L, D]
84
+
85
+ # ---- Subword → word pooling (first subword)
86
+
87
+ # Gather sentence word-level representations by taking FIRST subtoken hidden per word
88
+ # Prepare indices (replace -1 with 0 to avoid gather OOB; we'll mask later)
89
+ # This process is required to feed word level to predict BIO and role per word
90
+ #.clone is for deep copy won't change original data
91
+
92
+ gather_idx = word_first_wp_fullidx.clone()
93
+ gather_idx[gather_idx < 0] = 0
94
+ gather_idx = gather_idx.unsqueeze(-1).expand(-1, -1, H.size(-1)) # [B, max_n, D]
95
+ H_words = torch.gather(H, dim=1, index=gather_idx) # [B, max_n, D]
96
+ H_words = H_words * sentence_mask.unsqueeze(-1) # zero out pads
97
+
98
+ # ---- Concatenate predicate indicator (0/1 → emb)
99
+ # word_first_wp_fullidx: [1, 2, 3, -1, -1]
100
+ # gather_idx after clamp: [1, 2, 3, 0, 0] # 0 points to [CLS], just a placeholder
101
+ # H_words = gather(H, ...) # grabs vectors at positions 1,2,3,0,0
102
+ # sentence_mask: [1, 1, 1, 0, 0]
103
+ # H_words *= mask → [vec1, vec2, vec3, 0, 0] # padded slots zeroed out
104
+
105
+
106
+ X = H_words
107
+ if self.use_indicator and indicator is not None:
108
+ ind_emb = self.indicator_emb(indicator.clamp(0, 1)) # [B, max_n, 10] # CHANGED
109
+ X = torch.cat([X, ind_emb], dim=-1)
110
+
111
+ # ---- NEW: Relative position (distance-to-predicate) embeddings
112
+ if self.use_distance:
113
+ # positions: 0..max_n-1 per sentence
114
+ max_n = X.size(1)
115
+ positions = torch.arange(max_n, device=device).unsqueeze(0).expand(B, -1) # [B, max_n]
116
+ rel = positions - pred_word_idx.unsqueeze(1) # [B, max_n], can be <0
117
+ rel = rel.clamp(-self.max_distance, self.max_distance) + self.max_distance # shift to [0 .. 2*max_distance]
118
+ pos_feats = self.pos_emb(rel) # [B, max_n, pos_dim] # NEW
119
+ X = torch.cat([X, pos_feats], dim=-1) # [B, max_n, in_dim] # NEW
120
+
121
+ # ---- BiLSTM (packed)
122
+ lengths = sent_lens.detach().cpu()
123
+ packed = pack_padded_sequence(X, lengths=lengths, batch_first=True, enforce_sorted=False)
124
+ G_packed, _ = self.bilstm(packed)
125
+ G, _ = pad_packed_sequence(G_packed, batch_first=True) # [B, max_n, lstm_hidden]
126
+ G = self.dropout(G)
127
+
128
+ # ---- Predicate hidden (word-level) and concat to every position
129
+ batch_idx = torch.arange(B, device=device)
130
+ gp = G[batch_idx, pred_word_idx.clamp(min=0), :] # [B, lstm_hidden]
131
+ gp_expanded = gp.unsqueeze(1).expand(-1, G.size(1), -1) # [B, max_n, lstm_hidden]
132
+
133
+ logits = self.classifier(torch.cat([G, gp_expanded], dim=-1)) # [B, max_n, num_labels]
134
+
135
+ loss = None
136
+ if labels is not None:
137
+ loss_fct = nn.CrossEntropyLoss(ignore_index=self.pad_label_id)
138
+ loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
139
+
140
+ return logits, loss