| from predictor import main_predictor, _predict_cached, srl_init |
| import re |
| import itertools |
|
|
def bio_brackets_to_spans(text: str) -> str:
    """Collapse BIO bracket chunks into non-BIO spans.

    Example:
        [B-ARG2: of] [I-ARG2: the] [I-ARG2: orchards] → [ARG2: of the orchards]
        [B-V: take] → [V: take]

    A chunk is merged into the span before it only when it carries the ``I``
    prefix, has the same role, and is separated from the previous chunk by
    nothing but whitespace. Every other chunk starts a new span. Non-bracket
    text (spaces, punctuation, quotes) is preserved.

    Args:
        text: String that may contain ``[B-ROLE: token]`` / ``[I-ROLE: token]``
            chunks.

    Returns:
        ``text`` with each run of mergeable chunks rewritten as a single
        ``[ROLE: tok1 tok2 ...]`` span.
    """
    bio_re = re.compile(r"\[(B|I)-([A-Za-z0-9\-]+):\s*([^\]]+?)\]")

    pieces = []
    cursor = 0          # index just past the last chunk consumed so far
    open_role = None    # role of the span currently being accumulated
    open_tokens = []    # tokens of that span; empty means no span is open

    for match in bio_re.finditer(text):
        prefix, role, token = match.groups()
        gap = text[cursor:match.start()]

        continues_span = (
            bool(open_tokens)
            and prefix == "I"
            and role == open_role
            and not gap.strip()
        )
        if continues_span:
            # The whitespace gap is dropped; tokens are re-joined by one space.
            open_tokens.append(token)
        else:
            if open_tokens:
                pieces.append(f"[{open_role}: {' '.join(open_tokens)}]")
            # Text between the previous span and this chunk is kept verbatim.
            pieces.append(gap)
            open_role = role
            open_tokens = [token]
        cursor = match.end()

    if open_tokens:
        pieces.append(f"[{open_role}: {' '.join(open_tokens)}]")

    # Whatever follows the final chunk (or the whole text if none matched).
    pieces.append(text[cursor:])
    return "".join(pieces)
|
|
def create_description(words, tag_list):
    """Render tokens and their BIO tags as a bracketed description string.

    Tokens tagged ``'O'`` are emitted as-is. A tag of the form ``B-ROLE`` or
    ``I-ROLE`` (ROLE made of letters, digits and hyphens) opens a span, and
    directly following ``I-ROLE`` tokens with the same role are appended to
    it, giving ``[ROLE: tok1 tok2 ...]``. Any other tag is rendered literally
    as ``[tag: tok]``. All pieces are joined by single spaces.

    Spans are built straight from the tag sequence instead of formatting a
    string and re-parsing it, so tokens that contain ``]`` or that look like
    a bracket chunk themselves cannot corrupt the result.

    Args:
        words: Sequence of token strings.
        tag_list: Sequence of tag strings paired positionally with ``words``;
            pairing stops at the shorter of the two sequences.

    Returns:
        The description string (empty when there are no tokens).
    """
    bio_tag = re.compile(r"(B|I)-([A-Za-z0-9\-]+)")

    parts = []
    span_role = None    # role of the span being accumulated
    span_tokens = []    # its tokens; empty means no span is open

    for tok, tag in zip(words, tag_list):
        parsed = bio_tag.fullmatch(tag) if tag != 'O' else None

        if (
            parsed
            and parsed.group(1) == "I"
            and span_tokens
            and parsed.group(2) == span_role
        ):
            span_tokens.append(tok)
            continue

        # Anything else ends the span that was open, if any.
        if span_tokens:
            parts.append(f"[{span_role}: {' '.join(span_tokens)}]")
            span_tokens = []

        if parsed:
            # B-ROLE, or an I-ROLE with nothing to attach to, opens a new span.
            span_role = parsed.group(2)
            span_tokens = [tok]
        elif tag == 'O':
            parts.append(tok)
        else:
            # Tag is not in BIO form; keep it verbatim inside the brackets.
            parts.append(f"[{tag}: {tok}]")

    if span_tokens:
        parts.append(f"[{span_role}: {' '.join(span_tokens)}]")

    return " ".join(parts)
|
|
def create_dict(words, frames):
    """Package the tokens and per-frame results into one dictionary.

    Args:
        words: Sentence tokens; stored unchanged under ``'words'``.
        frames: Iterable of mappings, each read for its ``'predicate'`` and
            ``'tags'`` entries.

    Returns:
        ``{'verbs': [...], 'words': words}`` where each ``'verbs'`` item holds
        ``'verb'`` (the frame's predicate), ``'description'`` (the result of
        ``create_description(words, tags)``) and ``'tags'``.
    """
    verb_entries = [
        {
            'verb': frame['predicate'],
            'description': create_description(words, frame['tags']),
            'tags': frame['tags'],
        }
        for frame in frames
    ]
    return {'verbs': verb_entries, 'words': words}
|
|
def print_srl_frames_pretty(words, frames, show_grid=True, color=False):
    """Pretty-print SRL frames to stdout.

    Output layout:
      - one ``Sentence:`` line with the tokens joined by spaces;
      - `` (no predicates detected)`` when ``frames`` is empty;
      - otherwise, per frame: a separator rule, the bracketed description
        produced by ``create_description(words, tags)`` and, when
        ``show_grid`` is true, a TOKEN/LABEL grid aligned by column.

    Args:
        words: Sequence of token strings.
        frames: Sequence of frame mappings. Only each frame's ``"tags"``
            entry is read here.
        show_grid: Also print the token/label grid for every frame.
        color: Accepted for backward compatibility. Nothing that is currently
            printed is colourised, so the flag has no effect.

    Returns:
        None.
    """
    print("Sentence:", " ".join(words))
    if not frames:
        print(" (no predicates detected)")
        return

    for frame in frames:
        tags = frame["tags"]

        print("\n" + "—"*60)
        print(create_description(words, tags))

        if show_grid:
            # One (token, label, width) triple per column. "O" is shown as "."
            # so labelled tokens stand out. Pairing via zip stops at the
            # shorter of words/tags, matching the description line above.
            columns = [
                (word, tag if tag != "O" else ".", max(len(word), len(tag)))
                for word, tag in zip(words, tags)
            ]
            tok_row = " ".join(word.ljust(width) for word, _, width in columns)
            tag_row = " ".join(label.ljust(width) for _, label, width in columns)
            print("\nTOKEN:", tok_row)
            print("LABEL:", tag_row)
|
|
def prediction(*args):
    """Run SRL prediction on a sentence and print the result.

    Accepted call shapes:
      - prediction(sentence)
            Calls ``_predict_cached(sentence)`` directly.
      - prediction(model_path, bert_name, sentence)
            Calls ``srl_init(model_path, bert_name)`` first, then
            ``_predict_cached(sentence)``.

    The frames are printed with ``print_srl_frames_pretty``; if that call
    raises ``NameError``, a plain dump of each frame is printed instead.

    Raises:
        TypeError: If called with a number of positional arguments other
            than 1 or 3.
    """
    argc = len(args)
    if argc not in (1, 3):
        raise TypeError("prediction(...) expects (sentence) OR (model_path, bert_name, sentence)")

    if argc == 3:
        model_path, bert_name, sentence = args
        srl_init(model_path, bert_name)
    else:
        (sentence,) = args
    words, frames = _predict_cached(sentence)

    try:
        print_srl_frames_pretty(words, frames, show_grid=True, color=False)
    except NameError:
        # Plain fallback; reads 'predicate', 'p_bv', 'tags' and 'spans'
        # from every frame.
        print("Sentence:", " ".join(words))
        for frame in frames:
            print(f"\nPredicate: {frame['predicate']} P(B-V)={frame['p_bv']:.3f}")
            print("Tags:", list(zip(words, frame['tags'])))
            print("Spans:", frame['spans'])
|
|
def prediction_formatted(*args):
    """Run SRL prediction on a sentence and return the result as a dict.

    Accepts the same two call shapes as ``prediction``:
      - prediction_formatted(sentence)
      - prediction_formatted(model_path, bert_name, sentence), which calls
        ``srl_init(model_path, bert_name)`` before predicting.

    Returns:
        The value of ``create_dict(words, frames)``; if that call raises
        ``NameError``, ``{"words": words, "frames": frames}`` instead.

    Raises:
        TypeError: If called with a number of positional arguments other
            than 1 or 3.
    """
    if len(args) == 3:
        srl_init(args[0], args[1])
    elif len(args) != 1:
        raise TypeError("prediction_formatted(...) expects (sentence) OR (model_path, bert_name, sentence)")

    # In both accepted shapes the sentence is the last positional argument.
    words, frames = _predict_cached(args[-1])

    try:
        result = create_dict(words, frames)
    except NameError:
        result = {"words": words, "frames": frames}
    return result