yeomtong commited on
Commit
a774733
·
verified ·
1 Parent(s): bc11ff7

Upload 5 files

Browse files
Files changed (5) hide show
  1. data_prep.py +266 -0
  2. model.py +140 -0
  3. predicator.py +141 -0
  4. testing.py +80 -0
  5. training.py +182 -0
data_prep.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional
2
+ from torch.utils.data import Dataset
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
5
+ from torch.utils.data import Dataset, DataLoader
6
+ import torch.nn as nn
7
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
8
+ from sklearn.metrics import f1_score
9
+ import json
10
+
11
+
12
+
13
+ #Create data instance, words: tokenized word list, predicte_word_idx: index for predicte, labels: Semantic roles
14
+ !@dataclass
15
+ class SRLSample():
16
+ def __init__(self, words: List[str], predicate_word_idx: int, labels: List[str], predicate_form: Optional[str] = None):
17
+ self.words = words
18
+ self.predicate_word_idx = predicate_word_idx
19
+ self.labels = labels
20
+ self.predicate_form = predicate_form
21
+
22
+
23
+ #To Leah: SRL Sample is object for each dataset so we need another code for each instance(words, predicate_word_idx, labels) into list of SRLSample objects
24
+
25
+ def create_srl_samples(data_path):
26
+ samples = []
27
+ with open(data_path, 'r', encoding='utf-8') as f:
28
+ for line in f:
29
+ data = json.loads(line)
30
+ samples.append(SRLSample(**data))
31
+
32
+ return samples
33
+
34
+
35
+ #Example
36
+
37
+ #if __name__ == '__main__'
38
+
39
+ # data_class_train = create_srl_samples('/content/drive/MyDrive/Dissertation/srl_synthetic_100.jsonl')
40
+
41
+ # data_class_dev = create_srl_samples('/content/drive/MyDrive/Dissertation/srl_synthetic_dev_10.jsonl')
42
+
43
+ # data_class_test = create_srl_samples('/content/drive/MyDrive/Dissertation/srl_synthetic_test_10.jsonl')
44
+
45
+
46
+ class SRLDataset(Dataset):
47
+ """
48
+ Expects samples at WORD-level. We build BERT inputs as:
49
+ [CLS] <sentence (wordpiece)> [SEP] <predicate (wordpiece)> [SEP]
50
+ We keep:
51
+ - wordpiece indices for each word's FIRST subtoken (to pool BERT to word level)
52
+ - sentence lengths
53
+ - predicate's WORD index within the sentence (for gp from BiLSTM outputs)
54
+ """
55
+ def __init__(self, samples: List[SRLSample], tokenizer: AutoTokenizer, label2id: Dict[str, int], max_length: int = 256, debug_print= False):
56
+ self.samples = samples
57
+ self.tokenizer = tokenizer
58
+ self.label2id = label2id
59
+ self.id2label = {v: k for k, v in label2id.items()}
60
+ self.max_length = max_length
61
+ self.debug_print = debug_print
62
+
63
+ def __len__(self):
64
+ return len(self.samples)
65
+
66
+ def _tokenize_sentence(self, words: List[str]):
67
+ # Tokenize sentence as split words to preserve word boundaries
68
+ enc_sent = self.tokenizer(
69
+ words,
70
+ is_split_into_words=True,
71
+ add_special_tokens=False,
72
+ return_attention_mask=False,
73
+ return_token_type_ids=False
74
+ )
75
+ return enc_sent # dict with 'input_ids'
76
+
77
+ def _tokenize_predicate(self, form: str):
78
+ enc_pred = self.tokenizer(
79
+ form,
80
+ add_special_tokens=False,
81
+ return_attention_mask=False,
82
+ return_token_type_ids=False
83
+ )
84
+ return enc_pred
85
+
86
+ def __getitem__(self, idx):
87
+
88
+ instance = self.samples[idx]
89
+ words = instance.words
90
+ n_words = len(words)
91
+ assert 0 <= instance.predicate_word_idx < n_words, "Bad predicate index."
92
+
93
+ pred_form = instance.predicate_form if instance.predicate_form is not None else words[instance.predicate_word_idx]
94
+
95
+ # Tokenize sentence and predicate separately (Text -> numeric value)
96
+ enc_sent = self._tokenize_sentence(words)
97
+ enc_pred = self._tokenize_predicate(pred_form)
98
+
99
+ # print("This is enc_sent {}, this is enc_prec {} \n".format(enc_sent, enc_pred))
100
+
101
+
102
+ sent_wp_ids = enc_sent["input_ids"] # list[int]
103
+ pred_wp_ids = enc_pred["input_ids"] # list[int]
104
+
105
+ # Build final input ids and token type ids Here we added SEP for predicates create new input ids
106
+ # segment A (0): [CLS] sentence [SEP]
107
+ # segment B (1): predicate [SEP]
108
+ # [CLS] sentence [SEP] predicte [SEP]
109
+ # [CLS] sentence [SEP] ARG0_token [SEP] ARG1_token [SEP] ARG2_token [SEP] -> Model for emotion, formality and politeness
110
+ input_ids = [self.tokenizer.cls_token_id] + sent_wp_ids + [self.tokenizer.sep_token_id] \
111
+ + pred_wp_ids + [self.tokenizer.sep_token_id]
112
+
113
+ # token_type_ids: 0 for [CLS] + sentence + [SEP], 1 for predicate + [SEP]
114
+ ttids = [0] * (1 + len(sent_wp_ids) + 1) + [1] * (len(pred_wp_ids) + 1)
115
+
116
+ # Build mapping: each WORD -> index of its FIRST wordpiece inside the FULL sequence
117
+ # We iterate tokenizer.word_ids() by re-tokenizing with special tokens for alignment
118
+ # Simpler: reconstruct with pre-known structure:
119
+ # [CLS] at 0; sentence starts at 1; we need mapping from word index to its FIRST wordpiece offset in `sent_wp_ids`.
120
+ # We'll align by re-tokenizing sentence with is_split_into_words and reading the mapping.
121
+ # HuggingFace trick: get word_ids requires encoding with add_special_tokens=True, so let's do that quickly:
122
+ tmp = self.tokenizer(words, is_split_into_words=True, return_offsets_mapping=False)
123
+ word_ids = tmp.word_ids()
124
+ # print("This is tmp {}, word_ids {}\n".format(tmp, word_ids))
125
+ # Now, tmp.input_ids == [CLS] + sent_wp + [SEP]; positions:
126
+ # 0: CLS, 1..1+len(sent_wp_ids)-1: sentence, 1+len(sent_wp_ids): SEP
127
+ # We need FIRST position per word_id in this tmp encoding.
128
+ first_wp_pos_in_full = []
129
+ seen = set()
130
+ for pos, wid in enumerate(word_ids):
131
+ if wid is None:
132
+ continue
133
+ if wid not in seen:
134
+ seen.add(wid)
135
+ first_wp_pos_in_full.append(pos) # pos in tmp sequence
136
+ # Sort by wid order to align [0..n_words-1]
137
+ # word_ids may produce first_wp_pos_in_full in increasing pos order, but ensure length correctness:
138
+ # print("This is first_wp_posin_full {}\n".format(first_wp_pos_in_full))
139
+ first_wp_pos_in_full_sorted = [None] * n_words
140
+ # Build first index per wid:
141
+ first_pos_by_wid = {}
142
+ for pos, wid in enumerate(word_ids):
143
+ if wid is not None and wid not in first_pos_by_wid:
144
+ first_pos_by_wid[wid] = pos
145
+ for wid in range(n_words):
146
+ first_wp_pos_in_full_sorted[wid] = first_pos_by_wid[wid]
147
+
148
+ #first_wp_pos_in_full_sorted is the indices without special tokens (e.g., CLS, SEP)
149
+
150
+ # Convert those positions (which refer to tmp with specials) to positions in our final input (with extra predicate segment).
151
+ # In tmp: [CLS] sentence_wp [SEP]
152
+ # In final: [CLS] sentence_wp [SEP] predicate_wp [SEP]
153
+ # So for any position 'pos' inside tmp, it points to the SAME index in final, since the prefix is identical up to first [SEP].
154
+ word_first_wp_fullidx = first_wp_pos_in_full_sorted # list[int] length = n_words
155
+
156
+ # Labels to IDs
157
+ label_ids = [self.label2id[lbl] for lbl in instance.labels]
158
+ assert len(label_ids) == n_words
159
+
160
+ # Predicate indicator at word level (0/1)
161
+ indicator = [0] * n_words
162
+ indicator[instance.predicate_word_idx] = 1
163
+
164
+ # [0,0,0,0,0] -> [0,0,1,0,0]
165
+
166
+ # Attention mask for the full input
167
+ attention_mask = [1] * len(input_ids)
168
+
169
+ # Truncate if needed (rare for normal SRL sentences but keep safe)
170
+ if len(input_ids) > self.max_length:
171
+ # We only truncate the predicate side if absolutely necessary; for simplicity, just clip tail.
172
+ input_ids = input_ids[:self.max_length]
173
+ ttids = ttids[:self.max_length]
174
+ attention_mask = attention_mask[:self.max_length]
175
+ # NOTE: word_first_wp_fullidx could reference beyond max_length in pathological cases.
176
+ max_pos = self.max_length - 1
177
+ word_first_wp_fullidx = [min(p, max_pos) for p in word_first_wp_fullidx]
178
+
179
+ if self.debug_print:
180
+ toks_debug = self.tokenizer.convert_ids_to_tokens(input_ids, skip_special_tokens=False)
181
+ print("[DeBug idx = {}]".format(idx)+" ".join(toks_debug))
182
+
183
+ return {
184
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
185
+ "token_type_ids": torch.tensor(ttids, dtype=torch.long),
186
+ "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
187
+ "word_first_wp_fullidx": torch.tensor(word_first_wp_fullidx, dtype=torch.long), # [n_words]
188
+ "labels": torch.tensor(label_ids, dtype=torch.long), # [n_words]
189
+ "indicator": torch.tensor(indicator, dtype=torch.long), # [n_words]
190
+ "sent_len": torch.tensor(len(words), dtype=torch.long),
191
+ "pred_word_idx": torch.tensor(instance.predicate_word_idx, dtype=torch.long)
192
+ }
193
+
194
+
195
+ def srl_collate(batch: List[Dict], pad_token_id: int, pad_label_id: int = -100):
196
+ """
197
+ Pads full BERT inputs to same length; also pads word-level tensors to max sentence length.
198
+ Returns tensors ready for the model.
199
+ """
200
+ B = len(batch)
201
+ # Full sequence padding
202
+ max_L = max(item["input_ids"].size(0) for item in batch)
203
+ # print("This is B {}, max_L {}".format(B,max_L))
204
+ #make tensor B rows and Max_L columns
205
+ input_ids = torch.full((B, max_L), pad_token_id, dtype=torch.long)
206
+ token_type_ids = torch.zeros((B, max_L), dtype=torch.long)
207
+ attention_mask = torch.zeros((B, max_L), dtype=torch.long)
208
+
209
+ # Word-level padding
210
+ max_n = max(int(item["sent_len"]) for item in batch)
211
+ word_first_wp_fullidx = torch.full((B, max_n), -1, dtype=torch.long)
212
+ labels = torch.full((B, max_n), pad_label_id, dtype=torch.long)
213
+ indicator = torch.zeros((B, max_n), dtype=torch.long)
214
+ sent_lens = torch.zeros((B,), dtype=torch.long)
215
+ pred_word_idx = torch.zeros((B,), dtype=torch.long)
216
+ sentence_mask = torch.zeros((B, max_n), dtype=torch.bool)
217
+
218
+ for i, item in enumerate(batch):
219
+ # print("This is item {}".format(item))
220
+ L = item["input_ids"].size(0)
221
+ input_ids[i, :L] = item["input_ids"]
222
+ token_type_ids[i, :L] = item["token_type_ids"]
223
+ attention_mask[i, :L] = item["attention_mask"]
224
+
225
+ n = int(item["sent_len"])
226
+ word_first_wp_fullidx[i, :n] = item["word_first_wp_fullidx"]
227
+ labels[i, :n] = item["labels"]
228
+ indicator[i, :n] = item["indicator"]
229
+ sent_lens[i] = n
230
+ pred_word_idx[i] = item["pred_word_idx"]
231
+ sentence_mask[i, :n] = True
232
+
233
+ return {
234
+ "input_ids": input_ids,
235
+ "token_type_ids": token_type_ids,
236
+ "attention_mask": attention_mask,
237
+ "word_first_wp_fullidx": word_first_wp_fullidx, # [B, max_n] (full-seq positions; -1 for pad)
238
+ "sentence_mask": sentence_mask, # [B, max_n] (bool mask for valid words)
239
+ "labels": labels, # [B, max_n] (pad_label_id for pad)
240
+ "indicator": indicator, # [B, max_n] 0/1
241
+ "sent_lens": sent_lens, # [B]
242
+ "pred_word_idx": pred_word_idx # [B]
243
+ }
244
+
245
+
246
+ def data_processing_for_loader(train_dev_test: List[SRLSample], train_sample: List[SRLSample], dev_sample: List[SRLSample], test_sample: List[SRLSample], tokenizer):
247
+
248
+ '''
249
+ train_dev_test is an appended list of Train/Dev/Test SRLSamples
250
+ train_sample is a list of SRLSample
251
+ dev_sample is a list of SRLSample
252
+ test_sample is a list of SRLSample
253
+ '''
254
+
255
+ label2id = {}
256
+ for s in train_dev_test:
257
+ for l in s.labels:
258
+ label2id.setdefault(l, len(label2id))
259
+ id2label = {v: k for k, v in label2id.items()}
260
+
261
+ #train before loader
262
+ train_bf_loader = SRLDataset(train_sample, tokenizer, label2id, max_length = 128, debug_print = False)
263
+ dev_bf_loader = SRLDataset(dev_sample, tokenizer, label2id, max_length = 128, debug_print = False)
264
+ test_bf_loader = SRLDataset(test_sample, tokenizer, label2id, max_length = 128, debug_print = False)
265
+
266
+ return train_bf_loader, dev_bf_loader, test_bf_loader, label2id, id2label
model.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
4
+ from transformers import AutoModel, AutoConfig
5
+
6
+ class PredicateAwareSRL(nn.Module):
7
+ def __init__(self,
8
+ bert_name: str,
9
+ num_labels: int,
10
+ use_indicator: bool = True,
11
+ indicator_dim: int = 10, # CHANGED: 10-dim predicate indicator
12
+ lstm_hidden: int = 768, # CHANGED: BiLSTM hidden size = 768 (bidirectional)
13
+ mlp_hidden: int = 300, # CHANGED: MLP hidden size = 300
14
+ dropout: float = 0.1,
15
+ use_distance: bool = True, # NEW: enable relative position (distance) embeddings
16
+ pos_dim: int = 50, # NEW: size of position embedding (random init, trainable)
17
+ max_distance: int = 128): # NEW: clamp |i - p| to this range for bucketing
18
+ super().__init__()
19
+ self.config = AutoConfig.from_pretrained(bert_name)
20
+ self.bert = AutoModel.from_pretrained(bert_name)
21
+ self.use_indicator = use_indicator
22
+
23
+ # --- Input dim to BiLSTM = BERT_dim + (indicator_dim) + (pos_dim)
24
+ bert_dim = self.config.hidden_size
25
+ in_dim = bert_dim + (indicator_dim if use_indicator else 0)
26
+
27
+ # Two rows which indicate 0 not predicate 1 is predicate, so need to 2 embedding (rows)
28
+ # num_embeddings (int) – size of the dictionary of embeddings
29
+ # embedding_dim (int) – the size of each embedding vector
30
+
31
+ if use_indicator:
32
+ self.indicator_emb = nn.Embedding(2, indicator_dim) # 0/1 → 10-dim (random init, trainable) # CHANGED
33
+
34
+ self.use_distance = use_distance # NEW
35
+ self.max_distance = max_distance # NEW
36
+ if use_distance:
37
+ # Distance buckets: [-max_distance .. +max_distance] → indices [0 .. 2*max_distance]
38
+ self.pos_emb = nn.Embedding(2 * max_distance + 1, pos_dim) # NEW (random init, trainable)
39
+ in_dim += pos_dim # NEW
40
+
41
+ # BiLSTM (bidirectional): total output dim = lstm_hidden
42
+ self.bilstm = nn.LSTM(
43
+ input_size=in_dim,
44
+ hidden_size=lstm_hidden // 2, # bi → half per direction
45
+ num_layers=1,
46
+ dropout=0.0,
47
+ bidirectional=True,
48
+ batch_first=True
49
+ )
50
+
51
+ self.dropout = nn.Dropout(dropout)
52
+
53
+ # Classifier: concat(g_i, gp) so input dim = 2 * lstm_hidden
54
+ self.classifier = nn.Sequential(
55
+ nn.Linear(lstm_hidden * 2, mlp_hidden), # CHANGED (mlp_hidden=300)
56
+ nn.ReLU(),
57
+ nn.Dropout(dropout),
58
+ nn.Linear(mlp_hidden, num_labels)
59
+ )
60
+
61
+ self.pad_label_id = -100
62
+
63
+ def forward(self,
64
+ input_ids: torch.Tensor, # [B, L]
65
+ token_type_ids: torch.Tensor, # [B, L]
66
+ attention_mask: torch.Tensor, # [B, L]
67
+ word_first_wp_fullidx: torch.Tensor, # [B, max_n] (positions in full seq; -1 for pad)
68
+ sentence_mask: torch.Tensor, # [B, max_n] (bool)
69
+ sent_lens: torch.Tensor, # [B]
70
+ pred_word_idx: torch.Tensor, # [B]
71
+ indicator: torch.Tensor | None = None, # [B, max_n] 0/1
72
+ labels: torch.Tensor | None = None): # [B, max_n]
73
+
74
+ B, L = input_ids.size()
75
+ device = input_ids.device
76
+
77
+ # ---- BERT encoder
78
+ bert_out = self.bert(
79
+ input_ids=input_ids,
80
+ token_type_ids=token_type_ids,
81
+ attention_mask=attention_mask
82
+ )
83
+ H = bert_out.last_hidden_state # [B, L, D]
84
+
85
+ # ---- Subword → word pooling (first subword)
86
+
87
+ # Gather sentence word-level representations by taking FIRST subtoken hidden per word
88
+ # Prepare indices (replace -1 with 0 to avoid gather OOB; we'll mask later)
89
+ # This process is required to feed word level to predict BIO and role per word
90
+ #.clone is for deep copy won't change original data
91
+
92
+ gather_idx = word_first_wp_fullidx.clone()
93
+ gather_idx[gather_idx < 0] = 0
94
+ gather_idx = gather_idx.unsqueeze(-1).expand(-1, -1, H.size(-1)) # [B, max_n, D]
95
+ H_words = torch.gather(H, dim=1, index=gather_idx) # [B, max_n, D]
96
+ H_words = H_words * sentence_mask.unsqueeze(-1) # zero out pads
97
+
98
+ # ---- Concatenate predicate indicator (0/1 → emb)
99
+ # word_first_wp_fullidx: [1, 2, 3, -1, -1]
100
+ # gather_idx after clamp: [1, 2, 3, 0, 0] # 0 points to [CLS], just a placeholder
101
+ # H_words = gather(H, ...) # grabs vectors at positions 1,2,3,0,0
102
+ # sentence_mask: [1, 1, 1, 0, 0]
103
+ # H_words *= mask → [vec1, vec2, vec3, 0, 0] # padded slots zeroed out
104
+
105
+
106
+ X = H_words
107
+ if self.use_indicator and indicator is not None:
108
+ ind_emb = self.indicator_emb(indicator.clamp(0, 1)) # [B, max_n, 10] # CHANGED
109
+ X = torch.cat([X, ind_emb], dim=-1)
110
+
111
+ # ---- NEW: Relative position (distance-to-predicate) embeddings
112
+ if self.use_distance:
113
+ # positions: 0..max_n-1 per sentence
114
+ max_n = X.size(1)
115
+ positions = torch.arange(max_n, device=device).unsqueeze(0).expand(B, -1) # [B, max_n]
116
+ rel = positions - pred_word_idx.unsqueeze(1) # [B, max_n], can be <0
117
+ rel = rel.clamp(-self.max_distance, self.max_distance) + self.max_distance # shift to [0 .. 2*max_distance]
118
+ pos_feats = self.pos_emb(rel) # [B, max_n, pos_dim] # NEW
119
+ X = torch.cat([X, pos_feats], dim=-1) # [B, max_n, in_dim] # NEW
120
+
121
+ # ---- BiLSTM (packed)
122
+ lengths = sent_lens.detach().cpu()
123
+ packed = pack_padded_sequence(X, lengths=lengths, batch_first=True, enforce_sorted=False)
124
+ G_packed, _ = self.bilstm(packed)
125
+ G, _ = pad_packed_sequence(G_packed, batch_first=True) # [B, max_n, lstm_hidden]
126
+ G = self.dropout(G)
127
+
128
+ # ---- Predicate hidden (word-level) and concat to every position
129
+ batch_idx = torch.arange(B, device=device)
130
+ gp = G[batch_idx, pred_word_idx.clamp(min=0), :] # [B, lstm_hidden]
131
+ gp_expanded = gp.unsqueeze(1).expand(-1, G.size(1), -1) # [B, max_n, lstm_hidden]
132
+
133
+ logits = self.classifier(torch.cat([G, gp_expanded], dim=-1)) # [B, max_n, num_labels]
134
+
135
+ loss = None
136
+ if labels is not None:
137
+ loss_fct = nn.CrossEntropyLoss(ignore_index=self.pad_label_id)
138
+ loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
139
+
140
+ return logits, loss
predicator.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## This is testing
2
+
3
+ import torch
4
+
5
+ @torch.no_grad()
6
+ def predict_srl_single(model, tokenizer, words, predicate_word_idx, id2label, device="cuda"):
7
+ # tokenize sentence (no specials)
8
+ sent_enc = tokenizer(
9
+ words, is_split_into_words=True, add_special_tokens=False,
10
+ return_attention_mask=False, return_token_type_ids=False
11
+ )
12
+ sent_wp_ids = sent_enc["input_ids"]
13
+ sent_word_ids = sent_enc.word_ids()
14
+
15
+ # first-subword position per word in the FULL sequence: [CLS] sent [SEP] pred [SEP]
16
+ first_pos_by_wid = {}
17
+ for pos, wid in enumerate(sent_word_ids):
18
+ if wid is not None and wid not in first_pos_by_wid:
19
+ first_pos_by_wid[wid] = pos + 1 # +1 for [CLS]
20
+ n_words = len(words)
21
+ word_first_wp_fullidx = torch.tensor(
22
+ [first_pos_by_wid[i] for i in range(n_words)], dtype=torch.long
23
+ ).unsqueeze(0)
24
+
25
+ # predicate segment = surface form of the predicate word
26
+ pred_enc = tokenizer(
27
+ [words[predicate_word_idx]], is_split_into_words=True, add_special_tokens=False,
28
+ return_attention_mask=False, return_token_type_ids=False
29
+ )
30
+ pred_wp_ids = pred_enc["input_ids"]
31
+
32
+ # assemble full input
33
+ cls_id, sep_id = tokenizer.cls_token_id, tokenizer.sep_token_id
34
+ input_ids = [cls_id] + sent_wp_ids + [sep_id] + pred_wp_ids + [sep_id]
35
+ token_type_ids = [0] * (1 + len(sent_wp_ids) + 1) + [1] * (len(pred_wp_ids) + 1)
36
+ attention_mask = [1] * len(input_ids)
37
+
38
+ # tensors
39
+ input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)
40
+ token_type_ids= torch.tensor(token_type_ids).unsqueeze(0).to(device)
41
+ attention_mask= torch.tensor(attention_mask).unsqueeze(0).to(device)
42
+
43
+ sent_len = torch.tensor([n_words], dtype=torch.long).to(device)
44
+ sentence_mask = torch.ones(1, n_words, dtype=torch.bool).to(device)
45
+ pred_word_idx = torch.tensor([predicate_word_idx], dtype=torch.long).to(device)
46
+ indicator = torch.zeros(1, n_words, dtype=torch.long).to(device)
47
+ indicator[0, predicate_word_idx] = 1
48
+ word_first_wp_fullidx = word_first_wp_fullidx.to(device)
49
+
50
+ # forward
51
+ logits, _ = model(
52
+ input_ids=input_ids,
53
+ token_type_ids=token_type_ids,
54
+ attention_mask=attention_mask,
55
+ word_first_wp_fullidx=word_first_wp_fullidx,
56
+ sentence_mask=sentence_mask,
57
+ sent_lens=sent_len,
58
+ pred_word_idx=pred_word_idx,
59
+ indicator=indicator,
60
+ labels=None
61
+ )
62
+
63
+ pred_ids = logits.argmax(-1).squeeze(0).tolist()
64
+ tags = [id2label[i] for i in pred_ids]
65
+ return tags, logits.squeeze(0).cpu() # [L_word, num_labels]
66
+
67
+ def bio_to_spans(tags):
68
+ spans = []
69
+ i = 0
70
+ while i < len(tags):
71
+ t = tags[i]
72
+ if t == "O" or t.endswith("-V"):
73
+ i += 1
74
+ continue
75
+ if t.startswith("B-"):
76
+ role = t[2:]
77
+ j = i + 1
78
+ while j < len(tags) and tags[j] == f"I-{role}":
79
+ j += 1
80
+ spans.append((role, i, j-1))
81
+ i = j
82
+ else:
83
+ i += 1
84
+ return spans
85
+
86
+ @torch.no_grad()
87
+ def predict_srl_all_predicates(model, tokenizer, sentence, id2label, device="cuda", prob_threshold=0.50):
88
+ words = sentence.split()
89
+ # find the numeric id for "B-V"
90
+ b_v_id = None
91
+ for k, v in id2label.items():
92
+ if v == "B-V":
93
+ b_v_id = k
94
+ break
95
+ if b_v_id is None:
96
+ raise ValueError("Label set has no 'B-V' tag.")
97
+
98
+ results = []
99
+ for p in range(len(words)):
100
+ tags, logits = predict_srl_single(model, tokenizer, words, p, id2label, device=device)
101
+ # check predicate decision at position p
102
+ pred_id_at_p = logits.argmax(-1)[p].item()
103
+ keep = (pred_id_at_p == b_v_id)
104
+
105
+ # optional confidence gate
106
+ if prob_threshold is not None:
107
+ probs = torch.softmax(logits[p], dim=-1)
108
+ keep = keep and (probs[b_v_id].item() >= prob_threshold)
109
+
110
+ if keep:
111
+ spans = bio_to_spans(tags)
112
+ results.append({
113
+ "predicate_index": p,
114
+ "predicate": words[p],
115
+ "tags": tags,
116
+ "spans": spans
117
+ })
118
+ return words, results
119
+
120
+
121
+
122
+ # words, preds = predict_srl_all_predicates(model, tokenizer, sentence, id2label, device=device)
123
+
124
+
125
+ def predicator_srl(sentence):
126
+ words, preds = predict_srl_all_predicates(model, tokenizer, sentence, id2label, device=device)
127
+
128
+ return words, preds
129
+
130
+ if __name__ == "__main__":
131
+ sentence = "Hojeong decide to go to the school"
132
+ words, preds = predicator_srl(sentence)
133
+ print(words)
134
+ for r in preds:
135
+ print(f"Predicate: {r['predicate']} (idx {r['predicate_index']})")
136
+ print("Tags:", list(zip(words, r["tags"])))
137
+ print("Spans:", r["spans"]) # (ROLE, start, end) indices over words
138
+ print("-" * 60)
139
+
140
+
141
+
testing.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from SRL_model import SRL_BERT_model
2
+ from collections import Counter
3
+ import torch
4
+
5
+ def bio_to_spans(tags):
6
+ spans = []
7
+ i = 0
8
+ while i < len(tags):
9
+ t = tags[i]
10
+ if t == "O" or t.endswith("-V"):
11
+ i += 1; continue
12
+ if t.startswith("B-"):
13
+ role = t[2:]; j = i + 1
14
+ while j < len(tags) and tags[j] == f"I-{role}":
15
+ j += 1
16
+ spans.append((role, i, j-1))
17
+ i = j
18
+ else:
19
+ i += 1
20
+ return spans
21
+
22
+ @torch.no_grad()
23
+ def eval_span_f1(model, dataloader, id2label, device="cuda"):
24
+ model.eval()
25
+ tp = fp = fn = 0
26
+ for batch in dataloader:
27
+ gold = batch["labels"] # [B, Lw]
28
+ mask = (gold != -100)
29
+
30
+ batch = {k:(v.to(device) if torch.is_tensor(v) else v) for k,v in batch.items()}
31
+ logits, _ = model(**batch)
32
+ pred = logits.argmax(-1).cpu() # [B, Lw]
33
+ print(pred)
34
+ for g_seq, p_seq, m in zip(gold, pred, mask):
35
+ gl = [id2label[int(i)] for i in g_seq[m].tolist()]
36
+ pl = [id2label[int(i)] for i in p_seq[m].tolist()]
37
+ G = Counter(bio_to_spans(gl))
38
+ P = Counter(bio_to_spans(pl))
39
+ # micro counts
40
+ common = G & P
41
+ tp += sum(common.values())
42
+ fp += sum(P.values()) - sum(common.values())
43
+ fn += sum(G.values()) - sum(common.values())
44
+
45
+ prec = tp / (tp + fp + 1e-12)
46
+ rec = tp / (tp + fn + 1e-12)
47
+ f1 = 2 * prec * rec / (prec + rec + 1e-12)
48
+ return prec, rec, f1
49
+
50
+
51
+ if __name__ =="__main__":
52
+
53
+ device = "cuda" if torch.cuda.is_available() else "cpu"
54
+ ckpt_path = "/blue/bonniejdorr/youms/SRL-Aware_Model/model/best_srl_Sep_29.ckpt" # <-- change if needed
55
+ ckpt = torch.load(ckpt_path, map_location=device)
56
+ hp = ckpt["hparams"]
57
+
58
+ model = SRL_BERT_model.PredicateAwareSRL(**hp).to(device)
59
+ model.load_state_dict(ckpt["state_dict"])
60
+ model.eval()
61
+
62
+ label2id = ckpt["label2id"]
63
+ id2label = {v: k for k, v in label2id.items()}
64
+
65
+ h = ckpt.get("hparams", {
66
+ "bert_name": "bert-base-cased",
67
+ "num_labels": len(label2id),
68
+ "use_indicator": True,
69
+ "use_distance": True,
70
+ "indicator_dim": 10,
71
+ "lstm_hidden": 768,
72
+ "mlp_hidden": 300,
73
+ "pos_dim": 50,
74
+ "max_distance": 128,
75
+ "dropout": 0.1,
76
+ })
77
+
78
+ #test_loader from SRL_BERT_model
79
+ prec, rec, span_f1 = eval_span_f1(model, test_loader, id2label, device=device)
80
+ print(f"[TEST-SPAN] P={prec:.3f} R={rec:.3f} F1={span_f1:.3f}")
training.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from SRL_MODEL import data_prep, SRL_BERT_model
2
+ import torch
3
+ from transformers import AutoTokenizer, get_linear_schedule_with_warmup
4
+ from sklearn.metrics import f1_score
5
+ import pickle
6
+
7
+ def save_pkl(tgt_list, svg_path):
8
+ with open(svg_path, "wb") as f:
9
+ pickle.dump(tgt_list, f)
10
+
11
+ def load_pkl(path) :
12
+ with open(path, "rb") as f:
13
+ data = pickle.load(f)
14
+ return data
15
+
16
+
17
+ def train_one_epoch(
18
+ model,
19
+ dataloader,
20
+ optimizer,
21
+ device="cuda",
22
+ scheduler=None,
23
+ grad_accum_steps=1,
24
+ amp=True,
25
+ max_grad_norm=1.0,
26
+ ):
27
+ model.train()
28
+ total_loss, n_steps = 0.0, 0
29
+
30
+ use_amp = amp and torch.cuda.is_available()
31
+ scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
32
+
33
+ optimizer.zero_grad(set_to_none=True)
34
+
35
+ for step, batch in enumerate(dataloader, 1):
36
+ batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
37
+
38
+ with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.float16):
39
+ _, loss = model(**batch) # model must return (logits, loss)
40
+
41
+ total_loss += float(loss.detach().item())
42
+ n_steps += 1
43
+
44
+ loss = loss / grad_accum_steps # for accumulation
45
+
46
+ if use_amp:
47
+ scaler.scale(loss).backward()
48
+ else:
49
+ loss.backward()
50
+
51
+ if step % grad_accum_steps == 0:
52
+ if use_amp:
53
+ scaler.unscale_(optimizer)
54
+ nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
55
+
56
+ if use_amp:
57
+ scaler.step(optimizer)
58
+ scaler.update()
59
+ else:
60
+ optimizer.step()
61
+
62
+ optimizer.zero_grad(set_to_none=True)
63
+
64
+ if scheduler is not None:
65
+ scheduler.step()
66
+
67
+ return total_loss / max(1, n_steps)
68
+
69
+ #This is Validation
70
+ @torch.no_grad()
71
+ def eval_loss_and_token_f1(model, dataloader, id2label=None, device="cuda", average="micro"):
72
+
73
+ model.eval()
74
+ total_loss, n_batches = 0.0, 0
75
+ all_preds, all_golds = [], []
76
+
77
+ for batch in dataloader:
78
+ gold = batch["labels"] # keep on CPU for masking
79
+ mask = (gold != -100)
80
+
81
+ batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
82
+ logits, loss = model(**batch) # loss computed once here
83
+ total_loss += float(loss.item()); n_batches += 1
84
+
85
+ preds = logits.argmax(-1).cpu()
86
+ all_preds.extend(preds[mask].tolist())
87
+ all_golds.extend(gold[mask].tolist())
88
+
89
+ f1 = f1_score(all_golds, all_preds, average=average)
90
+ return total_loss / max(1, n_batches), f1
91
+
92
+
93
+ if __name__ =='__main__':
94
+ bert_name = "bert-base-cased"
95
+ tokenizer = AutoTokenizer.from_pretrained(bert_name)
96
+
97
+ device = "cuda" if torch.cuda.is_available() else "cpu"
98
+ # tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
99
+
100
+ #data_class_train/dev/test from data_prep
101
+ train_dev_test_data = data_class_train + data_class_dev + data_class_test
102
+ train_bf_loader, dev_bf_loader,test_bf_loader, label2id, id2label = data_prep.data_processing_for_loader(train_dev_test_data, data_class_train, data_class_dev, data_class_test, tokenizer)
103
+
104
+ pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
105
+ collate = lambda b: data_prep.srl_collate(b, pad_token_id=pad_token_id, pad_label_id=-100)
106
+
107
+ train_loader = data_prep.DataLoader(train_bf_loader, batch_size=16, shuffle=True, collate_fn=collate)
108
+ dev_loader = data_prep.DataLoader(dev_bf_loader, batch_size=16, shuffle=False, collate_fn=collate)
109
+ test_loader = data_prep.DataLoader(test_bf_loader, batch_size=16, shuffle=False, collate_fn=collate)
110
+
111
+ # bert_name = "bert-base-cased"
112
+ # tokenizer = AutoTokenizer.from_pretrained(bert_name)
113
+
114
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
115
+
116
+ model = SRL_BERT_model.PredicateAwareSRL(
117
+ bert_name=bert_name,
118
+ num_labels=len(label2id),
119
+ use_indicator=True,
120
+ use_distance =True,
121
+ indicator_dim= 10,
122
+ lstm_hidden=768,
123
+ mlp_hidden=300,
124
+ pos_dim= 50,
125
+ max_distance = 128,
126
+ dropout=0.1
127
+ ).to(device)
128
+
129
+ # Optimizer (you may want to use AdamW with weight decay and a scheduler)
130
+ num_epochs = 12
131
+ grad_accum_steps = 1
132
+ optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
133
+
134
+ # # Train a couple of epochs (on toy data this is just to check shapes run)
135
+ # for epoch in range(3):
136
+ # tr_loss = train_one_epoch(model, train_loader, optimizer, device=device)
137
+ # f1 = evaluate_token_f1(model, dev_loader, id2label=id2label, device=device)
138
+ # print(f"Epoch {epoch+1} | loss={tr_loss:.4f} | token-F1={f1:.4f}")
139
+
140
+ total_steps = len(train_loader) * num_epochs // max(1, grad_accum_steps)
141
+ warmup_steps = int(0.1 * total_steps)
142
+
143
+ scheduler = get_linear_schedule_with_warmup(
144
+ optimizer,
145
+ num_warmup_steps=warmup_steps,
146
+ num_training_steps=total_steps
147
+ )
148
+
149
+ history = {"epoch": [], "train_loss": [], "dev_loss": [], "dev_f1": []}
150
+
151
+ best_dev, best_path = -1.0, "best_srl.ckpt"
152
+ for epoch in range(num_epochs):
153
+ tr_loss = train_one_epoch(
154
+ model, train_loader, optimizer, device=device,
155
+ scheduler=scheduler, grad_accum_steps=grad_accum_steps, amp=True, max_grad_norm=1.0
156
+ )
157
+ dev_loss, dev_f1 = eval_loss_and_token_f1(model, dev_loader, id2label, device=device)
158
+
159
+
160
+ history["epoch"].append(epoch + 1)
161
+ history["train_loss"].append(tr_loss)
162
+ history["dev_loss"].append(dev_loss)
163
+ history["dev_f1"].append(dev_f1)
164
+
165
+ print(f"Epoch {epoch+1}: train_loss={tr_loss:.4f} dev_loss={dev_loss:.4f} dev_F1={dev_f1:.4f}")
166
+
167
+ if dev_f1 > best_dev:
168
+ best_dev = dev_f1
169
+ torch.save({"model_state": model.state_dict(), "label2id": label2id}, best_path)
170
+ print(" ↳ new best dev; saved.")
171
+
172
+ save_pkl(history, #save_path_for_loss)
173
+
174
+ # best_dev, best_path = -1.0, "best_srl.ckpt"
175
+ # for epoch in range(num_epochs):
176
+ # tr_loss = train_one_epoch(model, train_loader, optimizer, device=device)
177
+ # dev_loss, dev_f1 = eval_loss_and_token_f1(model, dev_loader, id2label, device=device)
178
+ # print(f"Epoch {epoch+1}: train_loss={tr_loss:.4f} dev_loss={dev_loss:.4f} dev_F1={dev_f1:.4f}")
179
+ # if dev_f1 > best_dev:
180
+ # best_dev = dev_f1
181
+ # torch.save({"model_state": model.state_dict(), "label2id": label2id}, best_path)
182
+ # print(" ↳ new best dev; saved.")