yeomtong commited on
Commit
fdbdd29
·
verified ·
1 Parent(s): 0bdcf24

Upload predicator.py

Browse files
Files changed (1) hide show
  1. predicator.py +226 -42
predicator.py CHANGED
@@ -1,5 +1,7 @@
1
- ## This is testing
2
 
 
 
 
3
  import torch
4
 
5
  @torch.no_grad()
@@ -64,6 +66,58 @@ def predict_srl_single(model, tokenizer, words, predicate_word_idx, id2label, de
64
  tags = [id2label[i] for i in pred_ids]
65
  return tags, logits.squeeze(0).cpu() # [L_word, num_labels]
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def bio_to_spans(tags):
68
  spans = []
69
  i = 0
@@ -84,58 +138,188 @@ def bio_to_spans(tags):
84
  return spans
85
 
86
  @torch.no_grad()
87
- def predict_srl_all_predicates(model, tokenizer, sentence, id2label, device="cuda", prob_threshold=0.50):
88
- words = sentence.split()
89
- # find the numeric id for "B-V"
90
- b_v_id = None
91
- for k, v in id2label.items():
92
- if v == "B-V":
93
- b_v_id = k
94
- break
95
- if b_v_id is None:
96
- raise ValueError("Label set has no 'B-V' tag.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
 
 
 
98
  results = []
99
- for p in range(len(words)):
100
- tags, logits = predict_srl_single(model, tokenizer, words, p, id2label, device=device)
101
- # check predicate decision at position p
102
- pred_id_at_p = logits.argmax(-1)[p].item()
103
- keep = (pred_id_at_p == b_v_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- # optional confidence gate
106
- if prob_threshold is not None:
107
- probs = torch.softmax(logits[p], dim=-1)
108
- keep = keep and (probs[b_v_id].item() >= prob_threshold)
 
 
 
 
 
109
 
110
- if keep:
111
- spans = bio_to_spans(tags)
112
- results.append({
113
- "predicate_index": p,
114
- "predicate": words[p],
115
- "tags": tags,
116
- "spans": spans
117
- })
118
- return words, results
 
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
 
 
 
 
 
 
121
 
122
- # words, preds = predict_srl_all_predicates(model, tokenizer, sentence, id2label, device=device)
123
 
 
 
 
 
 
 
 
124
 
125
- def predicator_srl(sentence):
126
- words, preds = predict_srl_all_predicates(model, tokenizer, sentence, id2label, device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- return words, preds
 
 
 
 
 
 
129
 
130
- if __name__ == "__main__":
131
- sentence = "Hojeong decide to go to the school"
132
- words, preds = predicator_srl(sentence)
133
- print(words)
134
- for r in preds:
135
- print(f"Predicate: {r['predicate']} (idx {r['predicate_index']})")
136
- print("Tags:", list(zip(words, r["tags"])))
137
- print("Spans:", r["spans"]) # (ROLE, start, end) indices over words
138
- print("-" * 60)
139
 
 
 
 
 
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
 
 
 
 
 
1
 
2
+ from SRL_model import SRL_BERT_model
3
+ from transformers import AutoTokenizer
4
+ import spacy
5
  import torch
6
 
7
  @torch.no_grad()
 
66
  tags = [id2label[i] for i in pred_ids]
67
  return tags, logits.squeeze(0).cpu() # [L_word, num_labels]
68
 
69
+
70
+
71
+ def spacy_verb_indices(nlp, sentence: str,word_spans):
72
+ """
73
+ Map spaCy POS to each tokenizer word span by max-overlap token,
74
+ and return indices that are verbs (VB*, or POS in {VERB, AUX}).
75
+ """
76
+ doc = nlp(sentence)
77
+ verb_idxs = []
78
+ for i, (wb, we) in enumerate(word_spans):
79
+ # find spaCy token with maximum overlap
80
+ best_tok, best_olap = None, 0
81
+ for tok in doc:
82
+ tb, te = tok.idx, tok.idx + len(tok)
83
+ olap = max(0, min(we, te) - max(wb, tb))
84
+ if olap > best_olap:
85
+ best_olap = olap
86
+ best_tok = tok
87
+ if best_tok is None:
88
+ continue
89
+ is_verb = best_tok.tag_.startswith("VB") or best_tok.pos_ in {"VERB", "AUX"}
90
+ if is_verb:
91
+ verb_idxs.append(i)
92
+
93
+
94
+ return verb_idxs
95
+
96
+ def words_and_spans_from_tokenizer(sentence: str, nlp, tokenizer):
97
+ """
98
+ Returns:
99
+ words : list[str] (tokenizer-aligned words)
100
+ spans : list[(start_char, end_char)] for each word
101
+ """
102
+ enc = tokenizer(sentence, add_special_tokens=False, return_offsets_mapping=True)
103
+ word_ids = enc.word_ids()
104
+ offsets = enc["offset_mapping"]
105
+ # print("Offsets:", offsets)
106
+
107
+ words, spans, seen = [], [], set()
108
+ for wid, (b, e) in zip(word_ids, offsets):
109
+ if wid is None or wid in seen:
110
+ continue
111
+ seen.add(wid)
112
+ words.append(sentence[b:e])
113
+ spans.append((b, e))
114
+
115
+
116
+ doc = nlp(sentence)
117
+ words_final = [token for token in doc]
118
+
119
+ return words, spans, words_final
120
+
121
  def bio_to_spans(tags):
122
  spans = []
123
  i = 0
 
138
  return spans
139
 
140
  @torch.no_grad()
141
+ def predict_srl_allennlp_like_spacy(
142
+ model, tokenizer, nlp, sentence, id2label,
143
+ device="cuda"):
144
+
145
+ device = "cuda" if torch.cuda.is_available() else "cpu"
146
+ model.eval()
147
+
148
+ # 0) tokenizer-aligned words (and spans for spaCy alignment)
149
+ words, spans, words_final = words_and_spans_from_tokenizer(sentence, nlp,tokenizer)
150
+ # words = words_and_spans_from_tokenizer(sentence, nlp)
151
+ # words = sentence.split(' ')
152
+ # print(spans)
153
+ n = len(words)
154
+ if n == 0:
155
+ return [], []
156
+
157
+ # 1) verb candidates via spaCy POS (aligned to tokenizer words by overlap)
158
+ verb_idxs = spacy_verb_indices(nlp, sentence, spans)
159
+ # verb_idxs = spacy_verb_indices(nlp, sentence)
160
+ if not verb_idxs:
161
+ # no verbs → either return empty or consider all tokens as a fallback
162
+ return words, []
163
+
164
+ # 2) find the predicate label id ('B-V' or 'V')
165
+ pred_ids = [i for i, t in id2label.items() if t in ("B-V", "V")]
166
+ if not pred_ids:
167
+ raise ValueError("Label set has no predicate tag ('B-V' or 'V').")
168
+ b_v_id = pred_ids[0]
169
 
170
+ keep = verb_idxs
171
+
172
+ # 4) PASS 2: final tagging with indicator ON
173
  results = []
174
+ for p in keep:
175
+ tags, logits = predict_srl_single(
176
+ model, tokenizer, words, p, id2label, device=device
177
+ # , use_indicator=True
178
+ )
179
+ # if require_core_arg and not has_core_argument(tags):
180
+ # continue
181
+ p_bv = torch.softmax(logits[p], dim=-1)[b_v_id].item()
182
+ spans_out = bio_to_spans(tags)
183
+ results.append({
184
+ "predicate_index": p,
185
+ "predicate": words[p],
186
+ "p_bv": p_bv,
187
+ "tags": tags,
188
+ "spans": spans_out
189
+ })
190
+
191
+ return words_final, results
192
+
193
+
194
+ def create_description(words, tag_list):
195
+ desc_list = []
196
+ for tok, tag in zip(words, tag_list):
197
+ desc_list.append("["+tag+": "+tok+"]")
198
+
199
+ return (' ').join(desc_list)
200
 
201
+ def print_srl_frames_pretty(words, frames, show_grid=True, color=False):
202
+ """
203
+ Pretty-print SRL frames.
204
+ - Description: Token+Labels
205
+ - Frames: Predicate/Roles
206
+ - show_grid: also print a token/label grid aligned by column
207
+ - color: add simple ANSI colors per role (terminal only)
208
+ """
209
+ import itertools
210
 
211
+ # tiny colorizer (terminal); safe no-op if color=False
212
+ ANSI = {
213
+ "ARG0": "\033[38;5;34m", "ARG1": "\033[38;5;33m", "ARG2": "\033[38;5;129m",
214
+ "ARG3": "\033[38;5;172m", "ARG4": "\033[38;5;166m", "ARGM": "\033[38;5;244m",
215
+ "V": "\033[1;37m", "RESET": "\033[0m"
216
+ }
217
+ def paint(txt, role):
218
+ if not color: return txt
219
+ key = "ARGM" if role.startswith("ARGM") else ("V" if role.endswith("V") or role=="V" else role)
220
+ return f"{ANSI.get(key, '')}{txt}{ANSI['RESET']}"
221
 
222
+ def spans_from_bio(tags):
223
+ spans = []
224
+ i = 0
225
+ while i < len(tags):
226
+ t = tags[i]
227
+ if t == "O":
228
+ i += 1; continue
229
+ if t.endswith("-V"): # you can include/exclude the V span as you like
230
+ spans.append(("V", i, i))
231
+ i += 1; continue
232
+ if t.startswith("B-"):
233
+ role = t[2:]
234
+ j = i + 1
235
+ while j < len(tags) and tags[j] == f"I-{role}":
236
+ j += 1
237
+ spans.append((role, i, j-1))
238
+ i = j
239
+ else:
240
+ i += 1
241
+ return spans
242
+
243
+ words = [word.text for word in words]
244
+ print("Sentence:", " ".join(words))
245
+ if not frames:
246
+ print(" (no predicates detected)")
247
+ return
248
 
249
+ for k, fr in enumerate(frames, 1):
250
+ tags = fr["tags"]
251
+ spans = fr.get("spans") or spans_from_bio(tags)
252
+ pred_idx = fr["predicate_index"]
253
+ pred = fr["predicate"]
254
+ p_bv = fr.get("p_bv", None)
255
 
256
+ print("\n" + "—"*60)
257
 
258
+ print(create_description(words, tags))
259
+
260
+ # Aggregate phrases per role for a clean summary
261
+ by_role = {}
262
+ for role, s, e in spans:
263
+ phrase = " ".join(words[s:e+1])
264
+ by_role.setdefault(role, []).append(phrase)
265
 
266
+ # Put V first, then core args, then ARGM*
267
+ order = (
268
+ (("V",),),
269
+ tuple((r,) for r in ["ARG0","ARG1","ARG2","ARG3","ARG4"]),
270
+ (tuple(sorted([r for r in by_role if r.startswith("ARGM")])),)
271
+ )
272
+ ordered_roles = []
273
+ for group in order:
274
+ for r in itertools.chain.from_iterable(group):
275
+ if r in by_role: ordered_roles.append(r)
276
+ # add any leftover roles
277
+ for r in sorted(by_role):
278
+ if r not in ordered_roles: ordered_roles.append(r)
279
+ print("Predicate:")
280
+ print(f" {r:<8}: {pred}")
281
+ print("Roles:")
282
+ for r in ordered_roles:
283
+ joined = "; ".join(by_role[r])
284
+ print(f" {r:<8}: {paint(joined, r)}")
285
 
286
+ if show_grid:
287
+ # token/tag grid aligned by column width
288
+ colw = [max(len(w), len(t)) for w, t in zip(words, tags)]
289
+ tok_row = " ".join(w.ljust(colw[i]) for i, w in enumerate(words))
290
+ tag_row = " ".join((t if t != "O" else ".").ljust(colw[i]) for i, t in enumerate(tags))
291
+ print("\nTOKEN:", tok_row)
292
+ print("LABEL:", tag_row)
293
 
294
+ def main_predictor(model_path, bert_name, sentence):
 
 
 
 
 
 
 
 
295
 
296
+ device = "cuda" if torch.cuda.is_available() else "cpu"
297
+ # model_path = "/blue/bonniejdorr/youms/SRL-Aware_Model/model/best_srl_Sep_29.ckpt"
298
+ ckpt = torch.load(model_path, map_location=device)
299
+ hp = ckpt["hparams"]
300
 
301
+ model = SRL_BERT_model.PredicateAwareSRL(**hp).to(device)
302
+ model.load_state_dict(ckpt["state_dict"])
303
+ model.eval()
304
+
305
+ label2id = ckpt["label2id"]
306
+ id2label = {v:k for k,v in label2id.items()}
307
+
308
+ # bert_name = "bert-large-cased" or "bert-based-cased"
309
+ bert_name = bert_name
310
+ tokenizer = AutoTokenizer.from_pretrained(bert_name)
311
+
312
+ nlp = spacy.load("en_core_web_md")
313
+
314
+ words, frames = predict_srl_allennlp_like_spacy(
315
+ model, tokenizer, nlp, sentence, id2label,
316
+ device=device,
317
+ prob_threshold=0.40, # tune on dev; try 0.3–0.6
318
+ top_k=None,
319
+ pick_best_if_none=True
320
+ )
321
+ return words, frames
322
 
323
+ if __name__ =="__main__":
324
+ words, frames = main_predictor(model_path, bert_namem sentence)
325
+ print_srl_frames_pretty(words, frames, show_grid=True, color=False)