yeomtong committed on
Commit
607c8e1
·
verified ·
1 Parent(s): ab9f44d

Upload 2 files

Browse files
Files changed (2) hide show
  1. predictor.py +224 -0
  2. visualizer.py +182 -0
predictor.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from model import PredicateAwareSRL
3
+ from transformers import AutoTokenizer
4
+ import spacy
5
+ from spacy import cli as spacy_cli
6
+ import torch
7
+
8
@torch.no_grad()
def predict_srl_single(
    model, tokenizer, words, predicate_word_idx, id2label, device="cuda"
):
    """Tag one pre-tokenized sentence with SRL labels for a single predicate.

    The model input is assembled as ``[CLS] sentence [SEP] predicate [SEP]``
    and one label is predicted per entry of ``words``.

    Args:
        model: Callable SRL model. It is invoked with the keyword arguments
            built below and must return a ``(logits, _)`` pair.
        tokenizer: Hugging Face *fast* tokenizer (``.word_ids()`` is required).
        words: One string per word, e.g. ``[t.text for t in nlp(sentence)]``.
        predicate_word_idx: Index into ``words`` of the predicate. Negative
            indices count from the end, as with normal list indexing.
        id2label: Mapping from label id to tag string.
        device: Device string or ``torch.device``. CUDA is used only when it
            is both requested and available; otherwise everything runs on CPU.

    Returns:
        ``(tags, logits)``: the predicted tag for each word, and the logits
        with the batch dimension squeezed out and moved to CPU (expected to
        be ``[n_words, num_labels]``).

    Raises:
        IndexError: ``predicate_word_idx`` is out of range for ``words``
            (this includes an empty ``words``).
        ValueError: the tokenizer cannot report word ids, or at least one
            word produced no subword, so words cannot be aligned to subwords.
    """
    model.eval()

    n_words = len(words)
    if not -n_words <= predicate_word_idx < n_words:
        raise IndexError(
            f"predicate_word_idx {predicate_word_idx} is out of range "
            f"for a sentence of {n_words} word(s)"
        )
    # Normalize negative indices so the index tensors sent to the model are valid.
    predicate_word_idx %= n_words

    # --- sentence subwords ---
    sent_enc = tokenizer(
        words,
        is_split_into_words=True,
        add_special_tokens=False,
        return_attention_mask=False,
        return_token_type_ids=False,
    )

    # Require a *fast* tokenizer to get word_ids
    try:
        sent_word_ids = sent_enc.word_ids()
    except Exception as err:
        raise ValueError(
            "Tokenizer must be a *fast* tokenizer to use .word_ids(). "
            "Initialize with use_fast=True."
        ) from err

    sent_wp_ids = sent_enc["input_ids"]
    # HF may return [[...]] vs [...] depending on version; normalize to a flat list
    if sent_wp_ids and isinstance(sent_wp_ids[0], list):
        sent_wp_ids = sent_wp_ids[0]

    # First-subword position per word, in the full sequence (after [CLS]).
    first_pos_by_wid = {}
    for pos, wid in enumerate(sent_word_ids):
        if wid is not None:
            first_pos_by_wid.setdefault(wid, pos + 1)  # +1 for the [CLS] added below

    # A word the tokenizer dropped entirely would otherwise surface as a bare
    # KeyError below; report it explicitly.
    missing = [i for i in range(n_words) if i not in first_pos_by_wid]
    if missing:
        raise ValueError(
            f"Tokenizer produced no subwords for word index(es) {missing} "
            "(e.g. whitespace-only tokens); cannot align words to subwords."
        )

    word_first_wp_fullidx = torch.tensor(
        [first_pos_by_wid[i] for i in range(n_words)], dtype=torch.long
    ).unsqueeze(0)

    # --- predicate subwords (surface form only) ---
    pred_enc = tokenizer(
        [words[predicate_word_idx]],
        is_split_into_words=True,
        add_special_tokens=False,
        return_attention_mask=False,
        return_token_type_ids=False,
    )
    pred_wp_ids = pred_enc["input_ids"]
    if pred_wp_ids and isinstance(pred_wp_ids[0], list):
        pred_wp_ids = pred_wp_ids[0]

    # --- assemble full input: [CLS] sent [SEP] pred [SEP] ---
    cls_id, sep_id = tokenizer.cls_token_id, tokenizer.sep_token_id
    input_ids = [cls_id] + sent_wp_ids + [sep_id] + pred_wp_ids + [sep_id]
    token_type_ids = [0] * (1 + len(sent_wp_ids) + 1) + [1] * (len(pred_wp_ids) + 1)
    attention_mask = [1] * len(input_ids)

    # --- tensors ---
    # str() lets callers pass either "cuda:0" or torch.device("cuda:0").
    use_cuda = torch.cuda.is_available() and "cuda" in str(device)
    device = torch.device(device if use_cuda else "cpu")

    input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)
    token_type_ids = torch.tensor(token_type_ids, device=device).unsqueeze(0)
    attention_mask = torch.tensor(attention_mask, device=device).unsqueeze(0)

    sent_len = torch.tensor([n_words], dtype=torch.long, device=device)
    sentence_mask = torch.ones(1, n_words, dtype=torch.bool, device=device)
    pred_word_idx = torch.tensor([predicate_word_idx], dtype=torch.long, device=device)
    indicator = torch.zeros(1, n_words, dtype=torch.long, device=device)
    indicator[0, predicate_word_idx] = 1
    word_first_wp_fullidx = word_first_wp_fullidx.to(device)

    # --- forward ---
    logits, _ = model(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
        word_first_wp_fullidx=word_first_wp_fullidx,
        sentence_mask=sentence_mask,
        sent_lens=sent_len,
        pred_word_idx=pred_word_idx,
        indicator=indicator,
        labels=None,
    )

    pred_ids = logits.argmax(-1).squeeze(0).tolist()
    tags = [id2label[i] for i in pred_ids]
    return tags, logits.squeeze(0).cpu()
97
+
98
+
99
def spacy_verb_indices(nlp, sentence: str):
    """
    Parse ``sentence`` with ``nlp`` and return the positions (0..n-1) of the
    tokens that look verbal: coarse POS is VERB or AUX, or the fine-grained
    tag starts with "VB".
    """
    verbal_pos = ("VERB", "AUX")
    hits = []
    for position, token in enumerate(nlp(sentence)):
        if token.pos_ in verbal_pos:
            hits.append(position)
        elif token.tag_.startswith("VB"):
            hits.append(position)
    return hits
105
+
106
+
107
def words_and_spans_spacy(sentence: str, nlp):
    """
    Tokenize ``sentence`` with ``nlp``.

    Returns:
      words : list[str]          token texts
      spans : list[(start,end)]  character offsets of each token, computed
                                 from the token's ``idx`` and text length
    """
    texts = []
    offsets = []
    for token in nlp(sentence):
        begin = token.idx
        texts.append(token.text)
        offsets.append((begin, begin + len(token.text)))
    return texts, offsets
117
+
118
def bio_to_spans(tags):
    """Collect argument spans from a BIO tag sequence.

    Every ``B-<role>`` tag opens a span that extends over the directly
    following ``I-<role>`` tags. Tags ending in ``-V`` (the predicate), ``O``
    tags and ``I-`` tags without a preceding matching ``B-`` yield nothing.

    Returns a list of ``(role, start, end)`` tuples with inclusive indices.
    """
    total = len(tags)
    found = []
    for start, tag in enumerate(tags):
        # Only a non-predicate B- tag can open a span.
        if not tag.startswith("B-") or tag.endswith("-V"):
            continue
        label = tag[2:]
        continuation = "I-" + label
        stop = start + 1
        while stop < total and tags[stop] == continuation:
            stop += 1
        found.append((label, start, stop - 1))
    return found
136
+
137
+
138
+
139
@torch.no_grad()
def predict_srl_allennlp_like_spacy(
    model, tokenizer, nlp, sentence, id2label,
    device="cuda",
    prob_threshold=0.50,
    top_k=None,
    pick_best_if_none=True
):
    """Run SRL once per verb candidate that spaCy finds in ``sentence``.

    Args:
        model, tokenizer, id2label, device: forwarded to ``predict_srl_single``.
        nlp: spaCy pipeline (any callable returning tokens with ``text``,
            ``pos_`` and ``tag_``).
        sentence: Raw sentence text.
        prob_threshold: Keep only frames whose predicate-tag probability at
            the predicate word is at least this value; ``None`` disables the
            filter.
        top_k: If given, only the first ``top_k`` verb candidates (in sentence
            order) are scored.
        pick_best_if_none: When the threshold removes every frame, keep the
            single frame with the highest predicate probability instead.

    Returns:
        ``(words, frames)``. ``frames`` is a list of dicts with the keys
        ``predicate_index``, ``predicate``, ``p_bv``, ``tags`` and ``spans``.
        An empty sentence returns ``([], [])``; a sentence without verb
        candidates returns ``(words, [])``.

    Raises:
        ValueError: ``id2label`` contains neither ``'B-V'`` nor ``'V'``.
    """
    model.eval()

    # Parse once: words and predicate candidates both come from the same Doc,
    # instead of running the spaCy pipeline twice on the same text.
    doc = nlp(sentence)
    words = [t.text for t in doc]
    if not words:
        return [], []

    # Same candidate rule as spacy_verb_indices: VERB/AUX by coarse POS, or a
    # fine-grained tag starting with "VB".
    verb_idxs = [
        i for i, t in enumerate(doc)
        if t.pos_ in ("VERB", "AUX") or t.tag_.startswith("VB")
    ]
    if not verb_idxs:
        return words, []  # no predicates found

    # find predicate label id
    pred_ids = [i for i, t in id2label.items() if t in ("B-V", "V")]
    if not pred_ids:
        raise ValueError("Label set has no predicate tag ('B-V' or 'V').")
    b_v_id = pred_ids[0]

    keep = verb_idxs if top_k is None else verb_idxs[:top_k]

    results = []
    for p in keep:
        tags, logits = predict_srl_single(
            model, tokenizer, words, p, id2label, device=device
        )
        # Probability that the predicate word itself is tagged as predicate.
        p_bv = torch.softmax(logits[p], dim=-1)[b_v_id].item()
        results.append({
            "predicate_index": p,
            "predicate": words[p],
            "p_bv": p_bv,
            "tags": tags,
            "spans": bio_to_spans(tags),
        })

    # optional thresholding
    if prob_threshold is not None:
        passed = [r for r in results if r["p_bv"] >= prob_threshold]
        if not passed and pick_best_if_none and results:
            passed = [max(results, key=lambda r: r["p_bv"])]
        results = passed

    return words, results
195
+
196
+
197
def main_predictor(model_path, bert_name, sentence, spacy_model="en_core_web_md"):
    """Load a checkpoint, tokenizer and spaCy pipeline, then run SRL on ``sentence``.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load``. It may
            hold ``hparams``/``hyper_parameters``, ``state_dict``/
            ``model_state_dict`` (or be a bare state dict) and must hold
            ``label2id`` or ``id2label``.
        bert_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
        sentence: Raw sentence text.
        spacy_model: spaCy pipeline name; it is downloaded if loading fails
            with ``OSError``.

    Returns:
        ``(words, frames)`` as produced by ``predict_srl_allennlp_like_spacy``
        with ``prob_threshold=0.40``.

    Raises:
        KeyError: the checkpoint has neither ``label2id`` nor ``id2label``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a trusted source.
    ckpt = torch.load(model_path, map_location=device)
    hp = ckpt.get("hparams", ckpt.get("hyper_parameters", {}))

    model = PredicateAwareSRL(**hp).to(device)
    state = ckpt.get("state_dict", ckpt.get("model_state_dict", ckpt))
    model.load_state_dict(state)
    model.eval()

    # Label ids are coerced to int because the predictor looks labels up by
    # the integer argmax; a checkpoint that stored ids as strings (e.g. after
    # a JSON round trip) would otherwise fail at prediction time.
    if "label2id" in ckpt:
        id2label = {int(idx): label for label, idx in ckpt["label2id"].items()}
    elif "id2label" in ckpt:
        id2label = {int(idx): label for idx, label in ckpt["id2label"].items()}
    else:
        raise KeyError("Checkpoint contains neither 'label2id' nor 'id2label'.")

    tokenizer = AutoTokenizer.from_pretrained(bert_name, use_fast=True)

    try:
        nlp = spacy.load(spacy_model)
    except OSError:
        # Pipeline not installed: fetch it via the imported spacy CLI module,
        # then retry the load.
        spacy_cli.download(spacy_model)
        nlp = spacy.load(spacy_model)

    words, frames = predict_srl_allennlp_like_spacy(
        model, tokenizer, nlp, sentence, id2label,
        device=device, prob_threshold=0.40, top_k=None, pick_best_if_none=True
    )
    return words, frames
223
+
224
+
visualizer.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from predictor import main_predictor
3
+ import re
4
+ import itertools
5
+
6
def bio_brackets_to_spans(text: str) -> str:
    """
    Merge runs of BIO bracket chunks into single role-labelled brackets.

    A chunk looks like "[B-ROLE: token]" or "[I-ROLE: token]". A chunk starts
    a run; following "I-" chunks with the same role join that run as long as
    only whitespace separates them. Each run is emitted as "[ROLE: tok tok]".
    Example:
      [B-ARG2: of] [I-ARG2: the] [I-ARG2: orchards]  →  [ARG2: of the orchards]
      [B-V: take]                                    →  [V: take]
    Everything outside the brackets is copied through unchanged.
    """
    chunk_re = re.compile(r"\[(B|I)-([A-Za-z0-9\-]+):\s*([^\]]+?)\]")

    pieces = []
    consumed = 0        # offset just past the last chunk handled so far
    run_role = None
    run_tokens = []     # tokens of the run being built; empty means no open run

    for chunk in chunk_re.finditer(text):
        kind, chunk_role, token = chunk.groups()
        gap = text[consumed:chunk.start()]
        extends_run = (
            bool(run_tokens)
            and kind == "I"
            and chunk_role == run_role
            and not gap.strip()
        )
        if extends_run:
            run_tokens.append(token)
        else:
            # Close the open run (if any), keep the text in between verbatim,
            # and start a new run with this chunk.
            if run_tokens:
                pieces.append(f"[{run_role}: {' '.join(run_tokens)}]")
            pieces.append(gap)
            run_role = chunk_role
            run_tokens = [token]
        consumed = chunk.end()

    if run_tokens:
        pieces.append(f"[{run_role}: {' '.join(run_tokens)}]")

    # trailing text
    pieces.append(text[consumed:])
    return "".join(pieces)
50
+
51
def create_description(words, tag_list):
    """Render a tagged sentence as text with merged role brackets.

    Each word whose tag is not 'O' is wrapped as "[<tag>: <word>]"; the pieces
    are joined with single spaces and passed through bio_brackets_to_spans.
    Words and tags are paired positionally (extra items on either side are
    ignored).
    """
    rendered = [
        tok if tag == 'O' else "[" + tag + ": " + tok + "]"
        for tok, tag in zip(words, tag_list)
    ]
    return bio_brackets_to_spans(' '.join(rendered))
61
+
62
def create_dict(words, frames):
    """Convert predictor frames into a plain dictionary.

    Returns ``{'verbs': [...], 'words': words}`` where every entry of
    ``'verbs'`` holds the frame's predicate (``'verb'``), its bracketed
    description (``'description'``) and its raw tag list (``'tags'``).
    """
    verb_entries = [
        {
            'verb': frame['predicate'],
            'description': create_description(words, frame['tags']),
            'tags': frame['tags'],
        }
        for frame in frames
    ]
    return {'verbs': verb_entries, 'words': words}
75
+
76
def print_srl_frames_pretty(words, frames, show_grid=True, color=False):
    """
    Pretty-print SRL frames to stdout.

    Prints the sentence, then for every frame a separator line and the
    bracketed description built by create_description.

    - words: list of word strings
    - frames: list of dicts; only the "tags" entry of each frame is read
    - show_grid: also print a token/label grid aligned by column
      ("O" labels are shown as ".")
    - color: accepted for backward compatibility; it currently has no effect
      on the output

    Returns None.
    """
    print("Sentence:", " ".join(words))
    if not frames:
        print(" (no predicates detected)")
        return

    for fr in frames:
        tags = fr["tags"]

        print("\n" + "—"*60)
        print(create_description(words, tags))

        if show_grid:
            # Pair words and tags positionally so a length mismatch truncates
            # the grid instead of raising IndexError.
            pairs = list(zip(words, tags))
            widths = [max(len(w), len(t)) for w, t in pairs]
            tok_row = " ".join(
                w.ljust(width) for (w, _), width in zip(pairs, widths)
            )
            tag_row = " ".join(
                (t if t != "O" else ".").ljust(width)
                for (_, t), width in zip(pairs, widths)
            )
            print("\nTOKEN:", tok_row)
            print("LABEL:", tag_row)
172
+
173
+
174
def prediction(model_path, bert_name, sentence):
    """Run the predictor on ``sentence`` and print the frames with the grid."""
    predicted = main_predictor(model_path, bert_name, sentence)
    sentence_words, srl_frames = predicted
    print_srl_frames_pretty(
        sentence_words,
        srl_frames,
        show_grid=True,
        color=False,
    )
177
+
178
def prediction_formatted(model_path, bert_name, sentence):
    """Run the predictor on ``sentence`` and return the create_dict result."""
    sentence_words, srl_frames = main_predictor(model_path, bert_name, sentence)
    return create_dict(sentence_words, srl_frames)