srl_bert_model / visualizer_up.py
yeomtong's picture
Upload visualizer_up.py
2a53d7b verified
from predictor_up import main_predictor, _predict_cached, srl_init
import re
import itertools
def bio_brackets_to_spans(text: str) -> str:
    """
    Rewrite bracketed BIO chunks in ``text`` as merged, prefix-free spans.

    A chunk has the form ``[B-ROLE: token]`` or ``[I-ROLE: token]``. A chunk
    opens a span; every immediately following ``I-`` chunk with the same role,
    separated from the previous chunk by nothing but whitespace, is folded
    into that span. The separating whitespace is replaced by a single space
    between the tokens.

    Examples:
        [B-ARG2: of] [I-ARG2: the] [I-ARG2: orchards] → [ARG2: of the orchards]
        [B-V: take] → [V: take]

    Everything outside the matched chunks (spaces, punctuation, quotes, plain
    words) is copied through unchanged. An ``I-`` chunk with no span to
    continue simply opens its own span.
    """
    chunk_pattern = re.compile(r"\[(B|I)-([A-Za-z0-9\-]+):\s*([^\]]+?)\]")

    pieces = []
    open_role = None      # role of the span being accumulated; None = no open span
    open_tokens = []
    consumed = 0          # index in `text` just past the last processed chunk

    for chunk in chunk_pattern.finditer(text):
        marker, label, word = chunk.groups()
        gap = text[consumed:chunk.start()]

        extends_open_span = (
            open_role is not None
            and marker == "I"
            and label == open_role
            and not gap.strip()
        )
        if extends_open_span:
            # The whitespace-only gap is dropped; tokens are re-joined with " ".
            open_tokens.append(word)
        else:
            # Close the previous span (if any) before emitting the text that
            # separates it from this chunk, then open a new span here.
            if open_role is not None:
                pieces.append(f"[{open_role}: {' '.join(open_tokens)}]")
            pieces.append(gap)
            open_role = label
            open_tokens = [word]

        consumed = chunk.end()

    if open_role is not None:
        pieces.append(f"[{open_role}: {' '.join(open_tokens)}]")

    # Whatever follows the last chunk (or the whole text if nothing matched).
    pieces.append(text[consumed:])
    return "".join(pieces)
def create_description(words, tag_list):
    """
    Render one SRL frame as a bracketed string with BIO runs collapsed.

    Example:
        words    = ["He", "ate", "an", "apple"]
        tag_list = ["B-ARG0", "B-V", "B-ARG1", "I-ARG1"]
        returns  "[ARG0: He] [V: ate] [ARG1: an apple]"

    Rules:
    - Tokens tagged ``O`` are emitted as plain text.
    - A ``B-ROLE`` / ``I-ROLE`` tag opens a span; directly following
      ``I-ROLE`` tags with the same role extend it.
    - Any other non-``O`` tag is emitted verbatim as ``[tag: token]``.
    - Pieces are joined with single spaces.

    Spans are merged from the (token, tag) pairs themselves rather than by
    re-parsing a rendered string, so a token that contains ``]`` (for example
    a citation such as ``[1]``) can no longer split a span in two.

    Args:
        words: sequence of token strings.
        tag_list: sequence of tag strings, aligned with ``words``. The two are
            paired with ``zip``, so surplus items in the longer one are ignored.

    Returns:
        The description string ("" when there are no pairs).
    """
    pieces = []
    open_role = None      # role of the span currently being extended, if any
    open_tokens = []

    def close_span():
        # Emit the accumulated span, if there is one, and reset the state.
        nonlocal open_role, open_tokens
        if open_role is not None:
            pieces.append(f"[{open_role}: {' '.join(open_tokens)}]")
            open_role, open_tokens = None, []

    for tok, tag in zip(words, tag_list):
        if tag == 'O':
            close_span()
            pieces.append(tok)
            continue

        # Same tag grammar the bracket-collapsing regex in this file accepts.
        parsed = re.fullmatch(r"(B|I)-([A-Za-z0-9\-]+)", tag)
        if parsed is None:
            # Not a BIO tag: keep it exactly as labelled.
            close_span()
            pieces.append("[" + tag + ": " + tok + "]")
            continue

        prefix, role = parsed.groups()
        if prefix == "I" and role == open_role:
            open_tokens.append(tok)
        else:
            close_span()
            open_role, open_tokens = role, [tok]

    close_span()
    return ' '.join(pieces)
def create_dict(words, frames):
    """
    Package a sentence and its SRL frames into a plain dictionary.

    Args:
        words: the sentence tokens; stored as-is under ``'words'``.
        frames: iterable of frame dicts, each providing ``'predicate'`` and
            ``'tags'``.

    Returns:
        ``{'verbs': [...], 'words': words}`` where each ``'verbs'`` entry is
        ``{'verb': <predicate>, 'description': <bracketed string>,
        'tags': <the frame's tags>}``, in the same order as ``frames``.
    """
    verb_entries = [
        {
            'verb': frame['predicate'],
            'description': create_description(words, frame['tags']),
            'tags': frame['tags'],
        }
        for frame in frames
    ]
    return {'verbs': verb_entries, 'words': words}
def print_srl_frames_pretty(words, frames, show_grid=True, color=False):
    """
    Print a sentence followed by one block of output per SRL frame.

    Output layout:
    - ``Sentence: <tokens joined by spaces>``
    - `` (no predicates detected)`` and nothing else when ``frames`` is empty
    - otherwise, per frame: a rule of 60 em dashes, the bracketed description
      from ``create_description`` and, if ``show_grid`` is true, a
      ``TOKEN:`` / ``LABEL:`` grid with columns padded to a common width
      (``O`` labels are shown as ``.``).

    Args:
        words: sequence of token strings.
        frames: sequence of frame dicts; only the ``"tags"`` entry is read.
        show_grid: whether to print the token/label grid for each frame.
        color: accepted for backward compatibility only. Nothing printed here
            is colourised, so the value has no effect.

    Returns:
        None. All results go to standard output.
    """
    print("Sentence:", " ".join(words))
    if not frames:
        print(" (no predicates detected)")
        return

    for fr in frames:
        tags = fr["tags"]
        print("\n" + "—"*60)
        print(create_description(words, tags))

        if show_grid:
            # Pair tokens and labels once so widths and both rows always refer
            # to the same columns. zip_longest pads the shorter side with ""
            # so a length mismatch yields blank cells, not an IndexError.
            columns = list(itertools.zip_longest(words, tags, fillvalue=""))
            widths = [max(len(w), len(t)) for w, t in columns]
            tok_row = " ".join(
                w.ljust(width) for (w, _), width in zip(columns, widths)
            )
            tag_row = " ".join(
                (t if t != "O" else ".").ljust(width)
                for (_, t), width in zip(columns, widths)
            )
            print("\nTOKEN:", tok_row)
            print("LABEL:", tag_row)
def prediction(*args):
    """
    Run SRL on a sentence and print the result.

    Call forms:
    - ``prediction(sentence)``: predicts via ``_predict_cached`` directly.
    - ``prediction(model_path, bert_name, sentence)``: calls
      ``srl_init(model_path, bert_name)`` first, then predicts.

    Raises:
        TypeError: if called with any other number of positional arguments.
    """
    arg_count = len(args)
    if arg_count not in (1, 3):
        raise TypeError("prediction(...) expects (sentence) OR (model_path, bert_name, sentence)")

    if arg_count == 3:
        # One-shot form: initialise with the given model before predicting.
        model_path, bert_name, sentence = args
        srl_init(model_path, bert_name)
    else:
        (sentence,) = args

    words, frames = _predict_cached(sentence)

    try:
        print_srl_frames_pretty(words, frames, show_grid=True, color=False)
    except NameError:
        # Plain fallback used when the pretty-printer hits a NameError; it
        # reads 'predicate', 'p_bv', 'tags' and 'spans' from every frame.
        print("Sentence:", " ".join(words))
        for fr in frames:
            print(f"\nPredicate: {fr['predicate']} P(B-V)={fr['p_bv']:.3f}")
            print("Tags:", list(zip(words, fr['tags'])))
            print("Spans:", fr['spans'])
def prediction_formatted(*args):
    """
    Run SRL on a sentence and return the result as a dictionary.

    Call forms:
    - ``prediction_formatted(sentence)``
    - ``prediction_formatted(model_path, bert_name, sentence)``, which calls
      ``srl_init(model_path, bert_name)`` before predicting.

    Returns:
        The value of ``create_dict(words, frames)``; if that call raises
        NameError, ``{"words": words, "frames": frames}`` instead.

    Raises:
        TypeError: if called with any other number of positional arguments.
    """
    if len(args) not in (1, 3):
        raise TypeError("prediction_formatted(...) expects (sentence) OR (model_path, bert_name, sentence)")

    # Everything before the sentence (nothing, or model_path + bert_name).
    *init_args, sentence = args
    if init_args:
        model_path, bert_name = init_args
        srl_init(model_path, bert_name)

    words, frames = _predict_cached(sentence)

    try:
        formatted = create_dict(words, frames)
    except NameError:
        formatted = {"words": words, "frames": frames}
    return formatted