yeomtong committed on
Commit
2a53d7b
·
verified ·
1 Parent(s): 237a0b9

Upload visualizer_up.py

Browse files
Files changed (1) hide show
  1. visualizer_up.py +213 -0
visualizer_up.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from predictor_up import main_predictor, _predict_cached, srl_init
2
+ import re
3
+ import itertools
4
+
5
# One bracketed BIO chunk: "[B-ROLE: token]" or "[I-ROLE: token]".
# Compiled once at import time instead of on every call.
_BIO_CHUNK_RE = re.compile(r"\[(B|I)-([A-Za-z0-9\-]+):\s*([^\]]+?)\]")


def bio_brackets_to_spans(text: str) -> str:
    """Collapse bracketed BIO chunks into bracketed role spans.

    A run starts at any bracketed chunk and absorbs every following
    ``I-<same role>`` chunk that is separated from the run by whitespace
    only. The run is emitted as ``[ROLE: tok1 tok2 ...]`` with the
    ``B-``/``I-`` prefix dropped and the tokens joined by single spaces.

    Examples:
        [B-ARG2: of] [I-ARG2: the] [I-ARG2: orchards] -> [ARG2: of the orchards]
        [B-V: take] -> [V: take]

    Text outside the brackets (spaces, punctuation, quotes) is preserved,
    except for the whitespace between chunks that were merged.

    Args:
        text: String containing zero or more ``[B-X: tok]`` / ``[I-X: tok]``
            chunks.

    Returns:
        The text with every run of chunks replaced by a single span.
    """
    pieces = []
    cursor = 0          # end offset of the last chunk consumed
    run_role = None     # role of the run currently being collected
    run_tokens = []     # tokens collected for that run

    for chunk in _BIO_CHUNK_RE.finditer(text):
        prefix, role, token = chunk.groups()
        gap = text[cursor:chunk.start()]

        continues_run = (
            bool(run_tokens)
            and prefix == "I"
            and role == run_role
            and not gap.strip()
        )
        if continues_run:
            # The whitespace gap is dropped; tokens are re-joined with " ".
            run_tokens.append(token)
        else:
            if run_tokens:
                pieces.append(f"[{run_role}: {' '.join(run_tokens)}]")
            # Plain text between the previous run and this chunk.
            pieces.append(gap)
            run_role, run_tokens = role, [token]
        cursor = chunk.end()

    if run_tokens:
        pieces.append(f"[{run_role}: {' '.join(run_tokens)}]")

    # Trailing text after the last chunk (the whole text if no chunk matched).
    pieces.append(text[cursor:])
    return "".join(pieces)
49
+
50
def create_description(words, tag_list):
    """Render a tagged sentence as text with bracketed role spans.

    Each token whose tag is not 'O' is wrapped as "[<tag>: <token>]";
    tokens tagged 'O' are kept as-is. The pieces are joined with single
    spaces and the result is passed through bio_brackets_to_spans.

    Args:
        words: Sequence of token strings.
        tag_list: Sequence of tag strings paired with words by position
            (pairing stops at the shorter of the two sequences).

    Returns:
        The string returned by bio_brackets_to_spans.
    """
    pieces = [
        tok if tag == 'O' else "[" + tag + ": " + tok + "]"
        for tok, tag in zip(words, tag_list)
    ]
    return bio_brackets_to_spans(' '.join(pieces))
60
+
61
def create_dict(words, frames):
    """Package the words and per-frame results into a plain dict.

    Args:
        words: Sequence of token strings for the sentence.
        frames: Iterable of mappings, each read for its 'predicate' and
            'tags' entries.

    Returns:
        A dict with two keys:
          'verbs' - one entry per frame holding 'verb' (the frame's
                    'predicate'), 'description' (create_description of the
                    words and the frame's tags) and 'tags';
          'words' - the words argument, unchanged.
    """
    verb_entries = [
        {
            'verb': frame['predicate'],
            'description': create_description(words, frame['tags']),
            'tags': frame['tags'],
        }
        for frame in frames
    ]
    return {'verbs': verb_entries, 'words': words}
74
+
75
# Display order for the role summary: predicate first, then core arguments.
_CORE_ROLE_ORDER = ("V", "ARG0", "ARG1", "ARG2", "ARG3", "ARG4")

# Terminal colors per role; every ARGM-* role shares the "ARGM" entry.
_ANSI_BY_ROLE = {
    "ARG0": "\033[38;5;34m", "ARG1": "\033[38;5;33m", "ARG2": "\033[38;5;129m",
    "ARG3": "\033[38;5;172m", "ARG4": "\033[38;5;166m", "ARGM": "\033[38;5;244m",
    "V": "\033[1;37m",
}
_ANSI_RESET = "\033[0m"


def _paint(txt, role, color):
    """Wrap txt in the ANSI color registered for role; no-op unless color."""
    if not color:
        return txt
    if role.startswith("ARGM"):
        key = "ARGM"
    elif role.endswith("V"):
        key = "V"
    else:
        key = role
    return f"{_ANSI_BY_ROLE.get(key, '')}{txt}{_ANSI_RESET}"


def _spans_from_bio(tags):
    """Decode BIO tags into (role, start, end) spans with inclusive end."""
    spans = []
    i = 0
    while i < len(tags):
        tag = tags[i]
        if tag.endswith("-V"):
            # Every predicate token becomes its own one-token "V" span.
            spans.append(("V", i, i))
            i += 1
        elif tag.startswith("B-"):
            role = tag[2:]
            j = i + 1
            while j < len(tags) and tags[j] == f"I-{role}":
                j += 1
            spans.append((role, i, j - 1))
            i = j
        else:
            # "O", or an I- tag with no opening B- tag: not part of a span.
            i += 1
    return spans


def _ordered_roles(by_role):
    """Order roles as V, ARG0..ARG4, sorted ARGM*, then any others sorted."""
    core = [role for role in _CORE_ROLE_ORDER if role in by_role]
    modifiers = sorted(role for role in by_role if role.startswith("ARGM"))
    listed = set(core) | set(modifiers)
    leftovers = sorted(role for role in by_role if role not in listed)
    return core + modifiers + leftovers


def print_srl_frames_pretty(words, frames, show_grid=True, color=False,
                            show_roles=False):
    """Pretty-print SRL frames to stdout.

    Prints the sentence, then for every frame a separator line and the
    bracketed description produced by create_description. Optional
    sections follow per frame.

    Args:
        words: Sequence of token strings.
        frames: Sequence of mappings. Each frame must provide 'tags' (one
            tag string per token). When show_roles is true, an optional
            'spans' entry is used instead of decoding the tags; it is
            assumed to hold (role, start, end) triples with an inclusive
            end index - confirm against the producer of the frames.
        show_grid: Also print TOKEN/LABEL rows aligned by column. Columns
            are built from (word, tag) pairs, so a length mismatch between
            words and tags truncates the grid instead of raising.
        color: Add ANSI colors (terminal only) to the role summary. Has no
            effect unless show_roles is true.
        show_roles: Also print a per-role phrase summary (V first, then
            ARG0-ARG4, then ARGM-* sorted, then any other roles sorted).
            Off by default, which keeps the default output unchanged.

    Returns:
        None.
    """
    print("Sentence:", " ".join(words))
    if not frames:
        print(" (no predicates detected)")
        return

    for fr in frames:
        tags = fr["tags"]

        print("\n" + "—"*60)
        print(create_description(words, tags))

        if show_roles:
            spans = fr.get("spans") or _spans_from_bio(tags)
            # Collect the phrases of each role; a role may occur repeatedly.
            by_role = {}
            for role, start, end in spans:
                phrase = " ".join(words[start:end + 1])
                by_role.setdefault(role, []).append(phrase)
            print("Roles:")
            for role in _ordered_roles(by_role):
                joined = "; ".join(by_role[role])
                print(f"  {role:<8}: {_paint(joined, role, color)}")

        if show_grid:
            # One column per (word, tag) pair, wide enough for both cells.
            columns = [(w, t, max(len(w), len(t))) for w, t in zip(words, tags)]
            tok_row = " ".join(w.ljust(width) for w, _, width in columns)
            tag_row = " ".join(
                (t if t != "O" else ".").ljust(width) for _, t, width in columns
            )
            print("\nTOKEN:", tok_row)
            print("LABEL:", tag_row)
171
+
172
def _run_srl(args, caller):
    """Resolve the overloaded arguments and return _predict_cached's result.

    Shared by prediction and prediction_formatted so both accept exactly
    the same call shapes.

    Args:
        args: The positional arguments the public function received.
        caller: Public function name, used in the error message.

    Returns:
        Whatever _predict_cached(sentence) returns; callers unpack it as
        (words, frames).

    Raises:
        TypeError: If args has neither 1 nor 3 elements.
    """
    if len(args) == 1:
        (sentence,) = args
    elif len(args) == 3:
        model_path, bert_name, sentence = args
        # One-shot mode: initialise via srl_init first, then predict.
        srl_init(model_path, bert_name)
    else:
        raise TypeError(
            f"{caller}(...) expects (sentence) OR (model_path, bert_name, sentence)"
        )
    return _predict_cached(sentence)


def prediction(*args):
    """Run SRL on a sentence and pretty-print the frames.

    Two modes:
      - prediction(sentence)                         # fast path (uses cache)
      - prediction(model_path, bert_name, sentence)  # backward-compatible one-shot

    The frames are printed with print_srl_frames_pretty (grid on, color
    off). Errors raised while printing propagate to the caller; they are
    no longer masked by a NameError fallback.

    Raises:
        TypeError: If called with a number of arguments other than 1 or 3.
    """
    words, frames = _run_srl(args, "prediction")
    print_srl_frames_pretty(words, frames, show_grid=True, color=False)


def prediction_formatted(*args):
    """Same overload behavior as prediction, but return a dict instead of printing.

    Returns:
        The dict built by create_dict(words, frames).

    Raises:
        TypeError: If called with a number of arguments other than 1 or 3.
    """
    words, frames = _run_srl(args, "prediction_formatted")
    return create_dict(words, frames)