import itertools
import re

# NOTE(review): main_predictor is not used in this module; presumably it is
# re-exported for callers importing it from here — verify before removing.
from predictor import main_predictor, _predict_cached, srl_init

# One "[B-ROLE: token]" / "[I-ROLE: token]" chunk; compiled once at import
# time instead of on every bio_brackets_to_spans call.
_BIO_RE = re.compile(r"\[(B|I)-([A-Za-z0-9\-]+):\s*([^\]]+?)\]")

# Core argument roles, listed in the order the role summary prints them.
_CORE_ROLES = ("ARG0", "ARG1", "ARG2", "ARG3", "ARG4")

# ANSI escape codes used by the role summary when color=True (terminal only).
_ROLE_COLORS = {
    "ARG0": "\033[38;5;34m",
    "ARG1": "\033[38;5;33m",
    "ARG2": "\033[38;5;129m",
    "ARG3": "\033[38;5;172m",
    "ARG4": "\033[38;5;166m",
    "ARGM": "\033[38;5;244m",
    "V": "\033[1;37m",
    "RESET": "\033[0m",
}


def bio_brackets_to_spans(text: str) -> str:
    """
    Collapse BIO bracket chunks in ``text`` into non-BIO spans.

    Example:
      [B-ARG2: of] [I-ARG2: the] [I-ARG2: orchards] → [ARG2: of the orchards]
      [B-V: take] → [V: take]

    A run starts at any B-/I- chunk and absorbs each following ``I-`` chunk
    of the same role that is separated from it only by whitespace.
    Non-bracket text (spaces, punctuation, quotes) is preserved.
    """
    out = []
    matches = list(_BIO_RE.finditer(text))
    m = 0
    cursor = 0
    while m < len(matches):
        # Plain text between the previous run and the next BIO chunk.
        out.append(text[cursor:matches[m].start()])
        _prefix, role, tok = matches[m].groups()
        tokens = [tok]
        cursor = matches[m].end()
        m += 1
        # Absorb subsequent I-<role> chunks separated only by whitespace.
        while m < len(matches):
            nxt = matches[m]
            prefix, next_role, next_tok = nxt.groups()
            between = text[cursor:nxt.start()]
            if prefix != "I" or next_role != role or between.strip():
                break
            tokens.append(next_tok)
            cursor = nxt.end()
            m += 1
        # Emit the merged span without its B-/I- prefix.
        out.append(f"[{role}: {' '.join(tokens)}]")
    # Trailing text after the last chunk.
    out.append(text[cursor:])
    return "".join(out)


def create_description(words, tag_list):
    """
    Render a tagged sentence as a bracketed description string.

    Consecutive tokens of one BIO span are merged, e.g. words
    ["The", "cat", "sat"] with tags ["B-ARG0", "I-ARG0", "B-V"] give
    "[ARG0: The cat] [V: sat]". A span starts at any B-/I- tag and absorbs the
    immediately following I- tags of the same role. Tokens tagged "O" appear
    as-is; tags without a "B-"/"I-" prefix are kept verbatim as "[TAG: token]".

    Spans are merged directly from the tags (not by formatting and re-parsing
    text), so tokens containing "[" or "]" cannot corrupt the result. If
    ``words`` and ``tag_list`` differ in length, the extra items are ignored.
    """
    # Each piece is [role, tokens] for a BIO run, or [None, text] otherwise.
    pieces = []
    for tok, tag in zip(words, tag_list):
        prefix, _, role = tag.partition("-")
        if prefix in ("B", "I") and role:
            last = pieces[-1] if pieces else None
            if prefix == "I" and last is not None and last[0] == role:
                last[1].append(tok)
            else:
                pieces.append([role, [tok]])
        elif tag == "O":
            pieces.append([None, tok])
        else:
            pieces.append([None, f"[{tag}: {tok}]"])
    return " ".join(
        text if role is None else f"[{role}: {' '.join(text)}]"
        for role, text in pieces
    )


def create_dict(words, frames):
    """
    Build a dict summary of SRL frames.

    Returns {"verbs": [...], "words": words}; each verb entry has the keys
    "verb" (frame["predicate"]), "description" (create_description of the
    frame's tags) and "tags" (frame["tags"]).
    """
    verbs = [
        {
            "verb": f["predicate"],
            "description": create_description(words, f["tags"]),
            "tags": f["tags"],
        }
        for f in frames
    ]
    return {"verbs": verbs, "words": words}


def _spans_from_bio(tags):
    """Decode BIO tags into (role, start, end) spans with inclusive ``end``.

    Every "*-V" tag yields its own one-token ("V", i, i) span; a "B-ROLE" tag
    starts a span extended over the following "I-ROLE" tags; stray "I-" and
    other tags are skipped.
    """
    spans = []
    i = 0
    while i < len(tags):
        tag = tags[i]
        if tag.endswith("-V"):
            spans.append(("V", i, i))
            i += 1
        elif tag.startswith("B-"):
            role = tag[2:]
            j = i + 1
            while j < len(tags) and tags[j] == f"I-{role}":
                j += 1
            spans.append((role, i, j - 1))
            i = j
        else:
            i += 1
    return spans


def _paint(text, role, color):
    """Wrap ``text`` in the ANSI color for ``role`` when ``color`` is true."""
    if not color:
        return text
    if role.startswith("ARGM"):
        key = "ARGM"
    elif role.endswith("V"):
        key = "V"
    else:
        key = role
    return f"{_ROLE_COLORS.get(key, '')}{text}{_ROLE_COLORS['RESET']}"


def _print_role_summary(words, frame, color):
    """Print the frame's predicate and its argument phrases grouped by role.

    Roles are ordered V, ARG0-ARG4, ARGM-* (sorted), then any others (sorted).
    Uses frame["spans"] when present — assumed to be (role, start, end)
    triples like _spans_from_bio returns — otherwise decodes frame["tags"].
    """
    spans = frame.get("spans") or _spans_from_bio(frame["tags"])
    by_role = {}
    for role, start, end in spans:
        by_role.setdefault(role, []).append(" ".join(words[start:end + 1]))

    argm = sorted(r for r in by_role if r.startswith("ARGM"))
    preferred = [r for r in itertools.chain(("V",), _CORE_ROLES, argm) if r in by_role]
    leftovers = sorted(r for r in by_role if r not in preferred)

    head = f"Predicate: {frame.get('predicate', '?')} (idx {frame.get('predicate_index', '?')})"
    p_bv = frame.get("p_bv")
    if p_bv is not None:
        head += f"  P(B-V)={p_bv:.3f}"
    print(head)
    print("Roles:")
    for role in preferred + leftovers:
        joined = "; ".join(by_role[role])
        print(f"  {role:<8}: {_paint(joined, role, color)}")


def _print_grid(words, tags):
    """Print tokens and labels as two column-aligned rows ("O" shown as ".").

    Only the overlapping prefix is shown if the lengths differ.
    """
    cols = [(w, "." if t == "O" else t) for w, t in zip(words, tags)]
    widths = [max(len(w), len(label)) for w, label in cols]
    tok_row = " ".join(w.ljust(width) for (w, _), width in zip(cols, widths))
    tag_row = " ".join(label.ljust(width) for (_, label), width in zip(cols, widths))
    print("\nTOKEN:", tok_row)
    print("LABEL:", tag_row)


def print_srl_frames_pretty(words, frames, show_grid=True, color=False, show_roles=False):
    """
    Pretty-print SRL frames.

    Prints the sentence, then for every frame a separator line and the
    bracketed description (see create_description), followed optionally by:

    - show_roles: the predicate and argument phrases grouped by role
    - show_grid: a token/label grid aligned by column
    - color: ANSI colors per role in the role summary (terminal only)

    Each frame is a dict with at least "tags"; "spans", "predicate",
    "predicate_index" and "p_bv" are read only by the role summary.
    """
    print("Sentence:", " ".join(words))
    if not frames:
        print(" (no predicates detected)")
        return

    for frame in frames:
        tags = frame["tags"]
        print("\n" + "—" * 60)
        print(create_description(words, tags))
        if show_roles:
            _print_role_summary(words, frame, color)
        if show_grid:
            _print_grid(words, tags)


def _resolve_prediction(args, func_name):
    """Run the model for the (sentence) / (model_path, bert_name, sentence) overloads.

    Returns the (words, frames) pair produced by _predict_cached.
    Raises TypeError (naming ``func_name``) for any other argument count.
    """
    if len(args) == 1:
        (sentence,) = args
    elif len(args) == 3:
        model_path, bert_name, sentence = args
        # One-shot mode: load the model before predicting.
        srl_init(model_path, bert_name)
    else:
        raise TypeError(
            f"{func_name}(...) expects (sentence) OR (model_path, bert_name, sentence)"
        )
    return _predict_cached(sentence)


def prediction(*args):
    """
    Predict SRL frames for a sentence and pretty-print them.

    Two modes:
      - prediction(sentence)                         # fast path (uses cache)
      - prediction(model_path, bert_name, sentence)  # backward-compatible one-shot

    Returns None. Raises TypeError for any other number of arguments.
    """
    words, frames = _resolve_prediction(args, "prediction")
    print_srl_frames_pretty(words, frames, show_grid=True, color=False)


def prediction_formatted(*args):
    """Same overload behavior as prediction(), but return create_dict(words, frames) instead of printing."""
    words, frames = _resolve_prediction(args, "prediction_formatted")
    return create_dict(words, frames)