yeomtong commited on
Commit
e1567ed
·
verified ·
1 Parent(s): 9fc7f0b

Upload 2 files

Browse files
Files changed (2) hide show
  1. predictor_dev.py +385 -0
  2. visualizer_dev.py +191 -0
predictor_dev.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model import PredicateAwareSRL
2
+ import torch, json
3
+ from transformers import AutoTokenizer
4
+ import spacy
5
+ from spacy import cli as spacy_cli
6
+
7
+ # ============ GLOBAL CACHE (NEW) ============
8
+ _CACHE = {
9
+ "ckpt_path": None,
10
+ "bert_name": None,
11
+ "spacy_model": None,
12
+ "device": None,
13
+ "model": None,
14
+ "tokenizer": None,
15
+ "label2id": None,
16
+ "id2label": None,
17
+ "hparams": None,
18
+ "nlp": None,
19
+ }
20
+
21
+ # --- Add near the top of your module (where _CACHE lives) ---
22
+ _CACHE = {
23
+ "model": None, "tokenizer": None, "id2label": None, "nlp": None, "device": None,
24
+ "ckpt_path": None, "bert_name": None, "spacy_model": None,
25
+ }
26
+
27
+ def srl_init(model_path, bert_name="bert-base-cased", spacy_model="en_core_web_md"):
28
+ """
29
+ Call ONCE per session to load and cache model/tokenizer/spaCy.
30
+ After this, you can call: prediction("your sentence here")
31
+ """
32
+ # reuse your existing loader logic
33
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+ ckpt = torch.load(model_path, map_location=device)
35
+ hp = ckpt.get("hparams", ckpt.get("hyper_parameters", {}))
36
+ if "bert_name" not in hp:
37
+ hp["bert_name"] = bert_name
38
+ if "num_labels" not in hp:
39
+ label2id = ckpt.get("label2id") or {v:k for k,v in ckpt["id2label"].items()}
40
+ hp["num_labels"] = len(label2id)
41
+
42
+ model = PredicateAwareSRL(**hp).to(device).eval()
43
+ state = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt
44
+ model.load_state_dict(state)
45
+
46
+ tokenizer = AutoTokenizer.from_pretrained(hp.get("bert_name", bert_name), use_fast=True)
47
+
48
+ try:
49
+ nlp = spacy.load(spacy_model, disable=["parser","ner","lemmatizer"])
50
+ except OSError:
51
+ spacy_cli.download(spacy_model)
52
+ nlp = spacy.load(spacy_model, disable=["parser","ner","lemmatizer"])
53
+
54
+ label2id = ckpt.get("label2id") or {v:k for k,v in ckpt["id2label"].items()}
55
+ id2label = {int(v): k for k, v in label2id.items()}
56
+
57
+ _CACHE.update({
58
+ "model": model, "tokenizer": tokenizer, "id2label": id2label,
59
+ "nlp": nlp, "device": device, "ckpt_path": model_path,
60
+ "bert_name": hp.get("bert_name", bert_name), "spacy_model": spacy_model,
61
+ })
62
+ torch.set_grad_enabled(False)
63
+
64
+ def _predict_cached(sentence):
65
+ """Internal: uses cached objects set by srl_init()."""
66
+ if _CACHE["model"] is None:
67
+ raise RuntimeError("Model not loaded. Call srl_init(ckpt_path, bert_name) once first.")
68
+ model = _CACHE["model"]
69
+ tokenizer = _CACHE["tokenizer"]
70
+ id2label = _CACHE["id2label"]
71
+ nlp = _CACHE["nlp"]
72
+ device = "cuda" if (_CACHE["device"].type == "cuda") else "cpu"
73
+
74
+ # your existing function name stays the same:
75
+ return predict_srl_allennlp_like_spacy(
76
+ model, tokenizer, nlp, sentence, id2label,
77
+ device=device, prob_threshold=0.40, top_k=None, pick_best_if_none=True
78
+ )
79
+
80
+ def _pick_device(dev=None): # NEW
81
+ if dev == "cpu":
82
+ return torch.device("cpu")
83
+ if dev and dev.startswith("cuda") and torch.cuda.is_available():
84
+ return torch.device(dev)
85
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
86
+
87
+ def _ensure_loaded(model_path, bert_name, spacy_model, model_cls): # NEW
88
+ """Load model/tokenizer/spaCy once and reuse."""
89
+ must_reload = (
90
+ _CACHE["model"] is None
91
+ or _CACHE["ckpt_path"] != model_path
92
+ or _CACHE["bert_name"] != bert_name
93
+ or _CACHE["spacy_model"] != spacy_model
94
+ )
95
+ if not must_reload:
96
+ return # already warm
97
+
98
+ device = _pick_device()
99
+ ckpt = torch.load(model_path, map_location=device)
100
+ h = ckpt.get("hparams", ckpt.get("hyper_parameters", {}))
101
+
102
+ # defaults if not present in ckpt
103
+ if "bert_name" not in h: h["bert_name"] = bert_name
104
+ if "num_labels" not in h:
105
+ label2id = ckpt.get("label2id")
106
+ if label2id is None and "id2label" in ckpt:
107
+ label2id = {v:k for k,v in ckpt["id2label"].items()}
108
+ h["num_labels"] = len(label2id) if label2id else 0
109
+
110
+ model = model_cls(**h).to(device).eval()
111
+ state = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt
112
+ model.load_state_dict(state)
113
+
114
+ tok = AutoTokenizer.from_pretrained(h.get("bert_name", bert_name), use_fast=True)
115
+
116
+ try:
117
+ nlp = spacy.load(spacy_model, disable=["parser","ner","lemmatizer"])
118
+ except OSError:
119
+ spacy_cli.download(spacy_model)
120
+ nlp = spacy.load(spacy_model, disable=["parser","ner","lemmatizer"])
121
+
122
+ label2id = ckpt.get("label2id")
123
+ if label2id is None and "id2label" in ckpt:
124
+ label2id = {v:k for k,v in ckpt["id2label"].items()}
125
+ id2label = {int(v): k for k, v in label2id.items()}
126
+
127
+ _CACHE.update({
128
+ "ckpt_path": model_path,
129
+ "bert_name": h.get("bert_name", bert_name),
130
+ "spacy_model": spacy_model,
131
+ "device": device,
132
+ "model": model,
133
+ "tokenizer": tok,
134
+ "label2id": label2id,
135
+ "id2label": id2label,
136
+ "hparams": h,
137
+ "nlp": nlp,
138
+ })
139
+ torch.set_grad_enabled(False)
140
+
141
+
142
+ def normalize_whitespace(s: str) -> str:
143
+ if s is None: return ""
144
+ return s.replace("\u00A0", " ").replace("\u2009", " ").strip()
145
+
146
+ def spacy_verb_indices(nlp, sentence: str):
147
+ doc = nlp(sentence)
148
+ return [i for i, t in enumerate(doc) if t.pos_ in ("VERB","AUX") or t.tag_.startswith("VB")]
149
+
150
+ def words_and_spans_spacy(sentence: str, nlp):
151
+ doc = nlp(sentence)
152
+ words = [t.text for t in doc]
153
+ spans = [(t.idx, t.idx + len(t.text)) for t in doc]
154
+ return words, spans
155
+
156
+ def bio_to_spans(tags):
157
+ spans = []; i = 0
158
+ while i < len(tags):
159
+ t = tags[i]
160
+ if t == "O" or t.endswith("-V"):
161
+ i += 1; continue
162
+ if t.startswith("B-"):
163
+ role = t[2:]; j = i+1
164
+ while j < len(tags) and tags[j] == f"I-{role}": j += 1
165
+ spans.append((role, i, j-1)); i = j
166
+ else:
167
+ i += 1
168
+ return spans
169
+
170
+
171
+ @torch.no_grad()
172
+ def predict_srl_single(model, tokenizer, words, predicate_word_idx, id2label, device="cuda"):
173
+ model.eval()
174
+ sent_enc = tokenizer(
175
+ words, is_split_into_words=True, add_special_tokens=False,
176
+ return_attention_mask=False, return_token_type_ids=False,
177
+ )
178
+ # word ids
179
+ try:
180
+ sent_word_ids = sent_enc.word_ids()
181
+ except Exception:
182
+ raise ValueError("Tokenizer must be fast (use_fast=True).")
183
+
184
+ sent_wp_ids = sent_enc["input_ids"]
185
+ if isinstance(sent_wp_ids[0], list):
186
+ sent_wp_ids = sent_wp_ids[0]
187
+
188
+ first_pos_by_wid = {}
189
+ for pos, wid in enumerate(sent_word_ids):
190
+ if wid is not None and wid not in first_pos_by_wid:
191
+ first_pos_by_wid[wid] = pos + 1
192
+
193
+ n_words = len(words)
194
+ word_first_wp_fullidx = torch.tensor(
195
+ [first_pos_by_wid[i] for i in range(n_words)], dtype=torch.long
196
+ ).unsqueeze(0)
197
+
198
+ pred_enc = tokenizer(
199
+ [words[predicate_word_idx]], is_split_into_words=True, add_special_tokens=False,
200
+ return_attention_mask=False, return_token_type_ids=False,
201
+ )
202
+ pred_wp_ids = pred_enc["input_ids"]
203
+ if isinstance(pred_wp_ids[0], list):
204
+ pred_wp_ids = pred_wp_ids[0]
205
+
206
+ cls_id, sep_id = tokenizer.cls_token_id, tokenizer.sep_token_id
207
+ input_ids = [cls_id] + sent_wp_ids + [sep_id] + pred_wp_ids + [sep_id]
208
+ token_type_ids = [0] * (1 + len(sent_wp_ids) + 1) + [1] * (len(pred_wp_ids) + 1)
209
+ attention_mask = [1] * len(input_ids)
210
+
211
+ device = _pick_device(device)
212
+ input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)
213
+ token_type_ids = torch.tensor(token_type_ids).unsqueeze(0).to(device)
214
+ attention_mask = torch.tensor(attention_mask).unsqueeze(0).to(device)
215
+
216
+ sent_len = torch.tensor([n_words], dtype=torch.long).to(device)
217
+ sentence_mask = torch.ones(1, n_words, dtype=torch.bool).to(device)
218
+ pred_word_idx = torch.tensor([predicate_word_idx], dtype=torch.long).to(device)
219
+ indicator = torch.zeros(1, n_words, dtype=torch.long).to(device)
220
+ indicator[0, predicate_word_idx] = 1
221
+ word_first_wp_fullidx = word_first_wp_fullidx.to(device)
222
+
223
+ logits, _ = model(
224
+ input_ids=input_ids,
225
+ token_type_ids=token_type_ids,
226
+ attention_mask=attention_mask,
227
+ word_first_wp_fullidx=word_first_wp_fullidx,
228
+ sentence_mask=sentence_mask,
229
+ sent_lens=sent_len,
230
+ pred_word_idx=pred_word_idx,
231
+ indicator=indicator,
232
+ labels=None,
233
+ )
234
+ pred_ids = logits.argmax(-1).squeeze(0).tolist()
235
+ tags = [id2label[i] for i in pred_ids]
236
+ return tags, logits.squeeze(0).cpu()
237
+
238
+
239
+ def _encode_sentence_once(words, tokenizer): # NEW
240
+ enc = tokenizer(
241
+ words, is_split_into_words=True, add_special_tokens=False,
242
+ return_attention_mask=False, return_token_type_ids=False,
243
+ )
244
+ sent_wp_ids = enc["input_ids"]
245
+ if isinstance(sent_wp_ids[0], list):
246
+ sent_wp_ids = sent_wp_ids[0]
247
+ wid = enc.word_ids()
248
+ first_pos = {}
249
+ for pos, w in enumerate(wid):
250
+ if w is not None and w not in first_pos:
251
+ first_pos[w] = pos + 1 # +1 for [CLS]
252
+ n_words = len(words)
253
+ word_first = torch.tensor([first_pos[i] for i in range(n_words)], dtype=torch.long)
254
+ return sent_wp_ids, word_first, n_words
255
+
256
+ @torch.no_grad()
257
+ def _batch_predict_verbs(model, tokenizer, words, verb_idxs, id2label, device): # NEW
258
+ """One forward pass for all verbs in the sentence."""
259
+ device = _pick_device(device)
260
+ sent_wp_ids, word_first_1, n_words = _encode_sentence_once(words, tokenizer)
261
+ cls_id, sep_id = tokenizer.cls_token_id, tokenizer.sep_token_id
262
+
263
+ ids_list, tt_list, am_list = [], [], []
264
+ pred_idx_list, ind_list, wf_list = [], [], []
265
+
266
+ for p in verb_idxs:
267
+ pred_wp_ids = tokenizer(
268
+ [words[p]], is_split_into_words=True, add_special_tokens=False,
269
+ return_attention_mask=False, return_token_type_ids=False,
270
+ )["input_ids"]
271
+ if isinstance(pred_wp_ids[0], list):
272
+ pred_wp_ids = pred_wp_ids[0]
273
+
274
+ ids = [cls_id] + sent_wp_ids + [sep_id] + pred_wp_ids + [sep_id]
275
+ tt = [0]*(1 + len(sent_wp_ids) + 1) + [1]*(len(pred_wp_ids) + 1)
276
+ am = [1]*len(ids)
277
+
278
+ ids_list.append(torch.tensor(ids, dtype=torch.long))
279
+ tt_list.append(torch.tensor(tt, dtype=torch.long))
280
+ am_list.append(torch.tensor(am, dtype=torch.long))
281
+ pred_idx_list.append(torch.tensor(p, dtype=torch.long))
282
+ ind = torch.zeros(n_words, dtype=torch.long); ind[p] = 1
283
+ ind_list.append(ind)
284
+ wf_list.append(word_first_1.clone())
285
+
286
+ # pad
287
+ def pad_1d(seq, pad_id=0):
288
+ L = max(x.numel() for x in seq)
289
+ out = torch.full((len(seq), L), pad_id, dtype=seq[0].dtype)
290
+ for i, x in enumerate(seq):
291
+ out[i, :x.numel()] = x
292
+ return out
293
+
294
+ pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
295
+ input_ids = pad_1d(ids_list, pad_id).to(device)
296
+ token_type_ids = pad_1d(tt_list, 0).to(device)
297
+ attention_mask = pad_1d(am_list, 0).to(device)
298
+
299
+ B = len(verb_idxs)
300
+ sent_lens = torch.full((B,), n_words, dtype=torch.long, device=device)
301
+ sentence_mask = torch.ones(B, n_words, dtype=torch.bool, device=device)
302
+ pred_word_idx = torch.stack(pred_idx_list).to(device)
303
+ indicator = torch.stack(ind_list).to(device)
304
+ word_first_wp_fullidx = torch.stack(wf_list).to(device)
305
+
306
+ logits, _ = model(
307
+ input_ids=input_ids,
308
+ token_type_ids=token_type_ids,
309
+ attention_mask=attention_mask,
310
+ word_first_wp_fullidx=word_first_wp_fullidx,
311
+ sentence_mask=sentence_mask,
312
+ sent_lens=sent_lens,
313
+ pred_word_idx=pred_word_idx,
314
+ indicator=indicator,
315
+ labels=None,
316
+ ) # [B, n_words, C]
317
+
318
+ results = []
319
+ for row, p in enumerate(verb_idxs):
320
+ row_logits = logits[row]
321
+ tags = [id2label[i] for i in row_logits.argmax(-1).tolist()]
322
+ results.append((p, tags, row_logits))
323
+ return results
324
+
325
+
326
+ @torch.no_grad()
327
+ def predict_srl_allennlp_like_spacy(
328
+ model, tokenizer, nlp, sentence, id2label,
329
+ device="cuda",
330
+ prob_threshold=0.50,
331
+ top_k=None,
332
+ pick_best_if_none=True
333
+ ):
334
+ model.eval()
335
+ words, _ = words_and_spans_spacy(sentence, nlp)
336
+ if not words:
337
+ return [], []
338
+
339
+ verb_idxs = spacy_verb_indices(nlp, sentence)
340
+ if not verb_idxs:
341
+ return words, []
342
+
343
+ # one forward for all verbs (fast path)
344
+ batch_out = _batch_predict_verbs(model, tokenizer, words, verb_idxs, id2label, device)
345
+ b_v_id = next((i for i,t in id2label.items() if t in ("B-V","V")), None)
346
+
347
+ frames = []
348
+ for p, tags, row_logits in batch_out:
349
+ p_bv = float(torch.softmax(row_logits[p], dim=-1)[b_v_id].item()) if b_v_id is not None else 1.0
350
+ frames.append({
351
+ "predicate_index": p,
352
+ "predicate": words[p],
353
+ "p_bv": p_bv,
354
+ "tags": tags,
355
+ "spans": bio_to_spans(tags)
356
+ })
357
+
358
+ # optional thresholding / top-k
359
+ if prob_threshold is not None:
360
+ keep = [f for f in frames if f["p_bv"] >= prob_threshold]
361
+ if not keep and pick_best_if_none and frames:
362
+ keep = [max(frames, key=lambda r: r["p_bv"])]
363
+ frames = keep
364
+ if top_k is not None and len(frames) > top_k:
365
+ frames = sorted(frames, key=lambda r: r["p_bv"], reverse=True)[:top_k]
366
+
367
+ return words, frames
368
+
369
+ def main_predictor(model_path, bert_name, sentence, spacy_model="en_core_web_md"):
370
+ sentence = normalize_whitespace(sentence)
371
+ from model import PredicateAwareSRL # keep your import style
372
+ _ensure_loaded(model_path, bert_name, spacy_model, PredicateAwareSRL) # NEW: cache/warm
373
+ model = _CACHE["model"]
374
+ tokenizer = _CACHE["tokenizer"]
375
+ id2label = _CACHE["id2label"]
376
+ nlp = _CACHE["nlp"]
377
+ device = _CACHE["device"]
378
+
379
+ words, frames = predict_srl_allennlp_like_spacy(
380
+ model, tokenizer, nlp, sentence, id2label,
381
+ device=str(device), prob_threshold=0.40, top_k=None, pick_best_if_none=True
382
+ )
383
+ return words, frames
384
+
385
+
visualizer_dev.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from predictor import main_predictor
2
+ import re
3
+ import itertools
4
+
5
+ def bio_brackets_to_spans(text: str) -> str:
6
+ """
7
+ Collapse BIO bracket chunks into non-BIO spans.
8
+ Example:
9
+ [B-ARG2: of] [I-ARG2: the] [I-ARG2: orchards] → [ARG2: of the orchards]
10
+ [B-V: take] → [V: take]
11
+ Non-bracket text (spaces, punctuation, quotes) is preserved.
12
+ """
13
+
14
+ BIO_RE = re.compile(r"\[(B|I)-([A-Za-z0-9\-]+):\s*([^\]]+?)\]")
15
+
16
+ out = []
17
+ i = 0
18
+ matches = list(BIO_RE.finditer(text))
19
+
20
+ m = 0
21
+ cursor = 0
22
+ while m < len(matches):
23
+ # plain text before next BIO chunk
24
+ out.append(text[cursor:matches[m].start()])
25
+
26
+ # start a run
27
+ prefix, role, tok = matches[m].groups()
28
+ tokens = [tok]
29
+ cursor = matches[m].end()
30
+ m += 1
31
+
32
+ # absorb subsequent I-<same role> chunks if only whitespace between
33
+ while m < len(matches):
34
+ between = text[cursor:matches[m].start()]
35
+ p2, role2, tok2 = matches[m].groups()
36
+ if role2 == role and p2 == "I" and between.strip() == "":
37
+ tokens.append(tok2)
38
+ cursor = matches[m].end()
39
+ m += 1
40
+ else:
41
+ break
42
+
43
+ # output merged span (drop B-/I-), keep V as just "V"
44
+ out.append(f"[{role}: {' '.join(tokens)}]")
45
+
46
+ # trailing text
47
+ out.append(text[cursor:])
48
+ return "".join(out)
49
+
50
+ def create_description(words, tag_list):
51
+ desc_list = []
52
+ for tok, tag in zip(words, tag_list):
53
+ if tag != 'O' :
54
+ desc_list.append("["+tag+": "+tok+"]")
55
+ else:
56
+ desc_list.append(tok)
57
+ desc_str_temp = (' ').join(desc_list)
58
+
59
+ return bio_brackets_to_spans(desc_str_temp)
60
+
61
+ def create_dict(words, frames):
62
+ final_dict = {}
63
+ verb = []
64
+ for f in frames:
65
+ temp_dict = {}
66
+ temp_dict['verb'] = f['predicate']
67
+ temp_dict['description'] = create_description(words, f['tags'])
68
+ temp_dict['tags'] = f['tags']
69
+ verb.append(temp_dict)
70
+ final_dict['verbs'] = verb
71
+ final_dict['words'] = words
72
+
73
+ return final_dict
74
+
75
+ def print_srl_frames_pretty(words, frames, show_grid=True, color=False):
76
+ """
77
+ Pretty-print SRL frames.
78
+ - Description: Token+Labels
79
+ - Frames: Predicate/Roles
80
+ - show_grid: also print a token/label grid aligned by column
81
+ - color: add simple ANSI colors per role (terminal only)
82
+ """
83
+
84
+
85
+ # tiny colorizer (terminal-only); safe no-op if color=False
86
+ ANSI = {
87
+ "ARG0": "\033[38;5;34m", "ARG1": "\033[38;5;33m", "ARG2": "\033[38;5;129m",
88
+ "ARG3": "\033[38;5;172m", "ARG4": "\033[38;5;166m", "ARGM": "\033[38;5;244m",
89
+ "V": "\033[1;37m", "RESET": "\033[0m"
90
+ }
91
+ def paint(txt, role):
92
+ if not color: return txt
93
+ key = "ARGM" if role.startswith("ARGM") else ("V" if role.endswith("V") or role=="V" else role)
94
+ return f"{ANSI.get(key, '')}{txt}{ANSI['RESET']}"
95
+
96
+ def spans_from_bio(tags):
97
+ spans = []
98
+ i = 0
99
+ while i < len(tags):
100
+ t = tags[i]
101
+ if t == "O":
102
+ i += 1; continue
103
+ if t.endswith("-V"): # you can include/exclude the V span as you like
104
+ spans.append(("V", i, i))
105
+ i += 1; continue
106
+ if t.startswith("B-"):
107
+ role = t[2:]
108
+ j = i + 1
109
+ while j < len(tags) and tags[j] == f"I-{role}":
110
+ j += 1
111
+ spans.append((role, i, j-1))
112
+ i = j
113
+ else:
114
+ i += 1
115
+ return spans
116
+
117
+ # words = [word.text for word in words]
118
+ print("Sentence:", " ".join(words))
119
+ if not frames:
120
+ print(" (no predicates detected)")
121
+ return
122
+
123
+ for k, fr in enumerate(frames, 1):
124
+ tags = fr["tags"]
125
+ spans = fr.get("spans") or spans_from_bio(tags)
126
+ pred_idx = fr["predicate_index"]
127
+ pred = fr["predicate"]
128
+ p_bv = fr.get("p_bv", None)
129
+
130
+ print("\n" + "—"*60)
131
+ # head = f"Frame {k} — predicate: {pred} (idx {pred_idx})"
132
+ # if p_bv is not None:
133
+ # head += f" P(B-V)={p_bv:.3f}"
134
+ # print(head)
135
+
136
+ print(create_description(words, tags))
137
+
138
+ # Aggregate phrases per role for a clean summary
139
+ by_role = {}
140
+ for role, s, e in spans:
141
+ phrase = " ".join(words[s:e+1])
142
+ by_role.setdefault(role, []).append(phrase)
143
+
144
+ # Put V first, then core args, then ARGM*
145
+ order = (
146
+ (("V",),),
147
+ tuple((r,) for r in ["ARG0","ARG1","ARG2","ARG3","ARG4"]),
148
+ (tuple(sorted([r for r in by_role if r.startswith("ARGM")])),)
149
+ )
150
+ ordered_roles = []
151
+ for group in order:
152
+ for r in itertools.chain.from_iterable(group):
153
+ if r in by_role: ordered_roles.append(r)
154
+ # add any leftover roles
155
+ # for r in sorted(by_role):
156
+ # if r not in ordered_roles: ordered_roles.append(r)
157
+ # print("Predicate:")
158
+ # print(f" {r:<8}: {pred}")
159
+ # print("Roles:")
160
+ # for r in ordered_roles:
161
+ # joined = "; ".join(by_role[r])
162
+ # print(f" {r:<8}: {paint(joined, r)}")
163
+
164
+ if show_grid:
165
+ # token/tag grid aligned by column width
166
+ colw = [max(len(w), len(t)) for w, t in zip(words, tags)]
167
+ tok_row = " ".join(w.ljust(colw[i]) for i, w in enumerate(words))
168
+ tag_row = " ".join((t if t != "O" else ".").ljust(colw[i]) for i, t in enumerate(tags))
169
+ print("\nTOKEN:", tok_row)
170
+ print("LABEL:", tag_row)
171
+
172
+ def prediction(model_path, bert_name, sentence):
173
+ words, frames = main_predictor(model_path, bert_name, sentence)
174
+ # assumes you have print_srl_frames_pretty already defined somewhere
175
+ try:
176
+ print_srl_frames_pretty(words, frames, show_grid=True, color=False)
177
+ except NameError:
178
+ # fallback
179
+ print("Sentence:", " ".join(words))
180
+ for fr in frames:
181
+ print(f"\nPredicate: {fr['predicate']} P(B-V)={fr['p_bv']:.3f}")
182
+ print("Tags:", list(zip(words, fr['tags'])))
183
+ print("Spans:", fr['spans'])
184
+
185
+ def prediction_formatted(model_path, bert_name, sentence):
186
+ words, frames = main_predictor(model_path, bert_name, sentence)
187
+ # assumes you have create_dict, otherwise return raw
188
+ try:
189
+ return create_dict(words, frames)
190
+ except NameError:
191
+ return {"words": words, "frames": frames}