Upload predicator.py
Browse files- predicator.py +226 -42
predicator.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
-
## This is testing
|
| 2 |
|
|
|
|
|
|
|
|
|
|
| 3 |
import torch
|
| 4 |
|
| 5 |
@torch.no_grad()
|
|
@@ -64,6 +66,58 @@ def predict_srl_single(model, tokenizer, words, predicate_word_idx, id2label, de
|
|
| 64 |
tags = [id2label[i] for i in pred_ids]
|
| 65 |
return tags, logits.squeeze(0).cpu() # [L_word, num_labels]
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
def bio_to_spans(tags):
|
| 68 |
spans = []
|
| 69 |
i = 0
|
|
@@ -84,58 +138,188 @@ def bio_to_spans(tags):
|
|
| 84 |
return spans
|
| 85 |
|
| 86 |
@torch.no_grad()
|
| 87 |
-
def
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
|
|
|
|
|
|
|
|
|
| 98 |
results = []
|
| 99 |
-
for p in
|
| 100 |
-
tags, logits = predict_srl_single(
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
-
|
| 123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
-
|
| 131 |
-
sentence = "Hojeong decide to go to the school"
|
| 132 |
-
words, preds = predicator_srl(sentence)
|
| 133 |
-
print(words)
|
| 134 |
-
for r in preds:
|
| 135 |
-
print(f"Predicate: {r['predicate']} (idx {r['predicate_index']})")
|
| 136 |
-
print("Tags:", list(zip(words, r["tags"])))
|
| 137 |
-
print("Spans:", r["spans"]) # (ROLE, start, end) indices over words
|
| 138 |
-
print("-" * 60)
|
| 139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
|
| 2 |
+
from SRL_model import SRL_BERT_model
|
| 3 |
+
from transformers import AutoTokenizer
|
| 4 |
+
import spacy
|
| 5 |
import torch
|
| 6 |
|
| 7 |
@torch.no_grad()
|
|
|
|
| 66 |
tags = [id2label[i] for i in pred_ids]
|
| 67 |
return tags, logits.squeeze(0).cpu() # [L_word, num_labels]
|
| 68 |
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def spacy_verb_indices(nlp, sentence: str,word_spans):
|
| 72 |
+
"""
|
| 73 |
+
Map spaCy POS to each tokenizer word span by max-overlap token,
|
| 74 |
+
and return indices that are verbs (VB*, or POS in {VERB, AUX}).
|
| 75 |
+
"""
|
| 76 |
+
doc = nlp(sentence)
|
| 77 |
+
verb_idxs = []
|
| 78 |
+
for i, (wb, we) in enumerate(word_spans):
|
| 79 |
+
# find spaCy token with maximum overlap
|
| 80 |
+
best_tok, best_olap = None, 0
|
| 81 |
+
for tok in doc:
|
| 82 |
+
tb, te = tok.idx, tok.idx + len(tok)
|
| 83 |
+
olap = max(0, min(we, te) - max(wb, tb))
|
| 84 |
+
if olap > best_olap:
|
| 85 |
+
best_olap = olap
|
| 86 |
+
best_tok = tok
|
| 87 |
+
if best_tok is None:
|
| 88 |
+
continue
|
| 89 |
+
is_verb = best_tok.tag_.startswith("VB") or best_tok.pos_ in {"VERB", "AUX"}
|
| 90 |
+
if is_verb:
|
| 91 |
+
verb_idxs.append(i)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
return verb_idxs
|
| 95 |
+
|
| 96 |
+
def words_and_spans_from_tokenizer(sentence: str, nlp, tokenizer):
|
| 97 |
+
"""
|
| 98 |
+
Returns:
|
| 99 |
+
words : list[str] (tokenizer-aligned words)
|
| 100 |
+
spans : list[(start_char, end_char)] for each word
|
| 101 |
+
"""
|
| 102 |
+
enc = tokenizer(sentence, add_special_tokens=False, return_offsets_mapping=True)
|
| 103 |
+
word_ids = enc.word_ids()
|
| 104 |
+
offsets = enc["offset_mapping"]
|
| 105 |
+
# print("Offsets:", offsets)
|
| 106 |
+
|
| 107 |
+
words, spans, seen = [], [], set()
|
| 108 |
+
for wid, (b, e) in zip(word_ids, offsets):
|
| 109 |
+
if wid is None or wid in seen:
|
| 110 |
+
continue
|
| 111 |
+
seen.add(wid)
|
| 112 |
+
words.append(sentence[b:e])
|
| 113 |
+
spans.append((b, e))
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
doc = nlp(sentence)
|
| 117 |
+
words_final = [token for token in doc]
|
| 118 |
+
|
| 119 |
+
return words, spans, words_final
|
| 120 |
+
|
| 121 |
def bio_to_spans(tags):
|
| 122 |
spans = []
|
| 123 |
i = 0
|
|
|
|
| 138 |
return spans
|
| 139 |
|
| 140 |
@torch.no_grad()
|
| 141 |
+
def predict_srl_allennlp_like_spacy(
|
| 142 |
+
model, tokenizer, nlp, sentence, id2label,
|
| 143 |
+
device="cuda"):
|
| 144 |
+
|
| 145 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 146 |
+
model.eval()
|
| 147 |
+
|
| 148 |
+
# 0) tokenizer-aligned words (and spans for spaCy alignment)
|
| 149 |
+
words, spans, words_final = words_and_spans_from_tokenizer(sentence, nlp,tokenizer)
|
| 150 |
+
# words = words_and_spans_from_tokenizer(sentence, nlp)
|
| 151 |
+
# words = sentence.split(' ')
|
| 152 |
+
# print(spans)
|
| 153 |
+
n = len(words)
|
| 154 |
+
if n == 0:
|
| 155 |
+
return [], []
|
| 156 |
+
|
| 157 |
+
# 1) verb candidates via spaCy POS (aligned to tokenizer words by overlap)
|
| 158 |
+
verb_idxs = spacy_verb_indices(nlp, sentence, spans)
|
| 159 |
+
# verb_idxs = spacy_verb_indices(nlp, sentence)
|
| 160 |
+
if not verb_idxs:
|
| 161 |
+
# no verbs → either return empty or consider all tokens as a fallback
|
| 162 |
+
return words, []
|
| 163 |
+
|
| 164 |
+
# 2) find the predicate label id ('B-V' or 'V')
|
| 165 |
+
pred_ids = [i for i, t in id2label.items() if t in ("B-V", "V")]
|
| 166 |
+
if not pred_ids:
|
| 167 |
+
raise ValueError("Label set has no predicate tag ('B-V' or 'V').")
|
| 168 |
+
b_v_id = pred_ids[0]
|
| 169 |
|
| 170 |
+
keep = verb_idxs
|
| 171 |
+
|
| 172 |
+
# 4) PASS 2: final tagging with indicator ON
|
| 173 |
results = []
|
| 174 |
+
for p in keep:
|
| 175 |
+
tags, logits = predict_srl_single(
|
| 176 |
+
model, tokenizer, words, p, id2label, device=device
|
| 177 |
+
# , use_indicator=True
|
| 178 |
+
)
|
| 179 |
+
# if require_core_arg and not has_core_argument(tags):
|
| 180 |
+
# continue
|
| 181 |
+
p_bv = torch.softmax(logits[p], dim=-1)[b_v_id].item()
|
| 182 |
+
spans_out = bio_to_spans(tags)
|
| 183 |
+
results.append({
|
| 184 |
+
"predicate_index": p,
|
| 185 |
+
"predicate": words[p],
|
| 186 |
+
"p_bv": p_bv,
|
| 187 |
+
"tags": tags,
|
| 188 |
+
"spans": spans_out
|
| 189 |
+
})
|
| 190 |
+
|
| 191 |
+
return words_final, results
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def create_description(words, tag_list):
|
| 195 |
+
desc_list = []
|
| 196 |
+
for tok, tag in zip(words, tag_list):
|
| 197 |
+
desc_list.append("["+tag+": "+tok+"]")
|
| 198 |
+
|
| 199 |
+
return (' ').join(desc_list)
|
| 200 |
|
| 201 |
+
def print_srl_frames_pretty(words, frames, show_grid=True, color=False):
|
| 202 |
+
"""
|
| 203 |
+
Pretty-print SRL frames.
|
| 204 |
+
- Description: Token+Labels
|
| 205 |
+
- Frames: Predicate/Roles
|
| 206 |
+
- show_grid: also print a token/label grid aligned by column
|
| 207 |
+
- color: add simple ANSI colors per role (terminal only)
|
| 208 |
+
"""
|
| 209 |
+
import itertools
|
| 210 |
|
| 211 |
+
# tiny colorizer (terminal); safe no-op if color=False
|
| 212 |
+
ANSI = {
|
| 213 |
+
"ARG0": "\033[38;5;34m", "ARG1": "\033[38;5;33m", "ARG2": "\033[38;5;129m",
|
| 214 |
+
"ARG3": "\033[38;5;172m", "ARG4": "\033[38;5;166m", "ARGM": "\033[38;5;244m",
|
| 215 |
+
"V": "\033[1;37m", "RESET": "\033[0m"
|
| 216 |
+
}
|
| 217 |
+
def paint(txt, role):
|
| 218 |
+
if not color: return txt
|
| 219 |
+
key = "ARGM" if role.startswith("ARGM") else ("V" if role.endswith("V") or role=="V" else role)
|
| 220 |
+
return f"{ANSI.get(key, '')}{txt}{ANSI['RESET']}"
|
| 221 |
|
| 222 |
+
def spans_from_bio(tags):
|
| 223 |
+
spans = []
|
| 224 |
+
i = 0
|
| 225 |
+
while i < len(tags):
|
| 226 |
+
t = tags[i]
|
| 227 |
+
if t == "O":
|
| 228 |
+
i += 1; continue
|
| 229 |
+
if t.endswith("-V"): # you can include/exclude the V span as you like
|
| 230 |
+
spans.append(("V", i, i))
|
| 231 |
+
i += 1; continue
|
| 232 |
+
if t.startswith("B-"):
|
| 233 |
+
role = t[2:]
|
| 234 |
+
j = i + 1
|
| 235 |
+
while j < len(tags) and tags[j] == f"I-{role}":
|
| 236 |
+
j += 1
|
| 237 |
+
spans.append((role, i, j-1))
|
| 238 |
+
i = j
|
| 239 |
+
else:
|
| 240 |
+
i += 1
|
| 241 |
+
return spans
|
| 242 |
+
|
| 243 |
+
words = [word.text for word in words]
|
| 244 |
+
print("Sentence:", " ".join(words))
|
| 245 |
+
if not frames:
|
| 246 |
+
print(" (no predicates detected)")
|
| 247 |
+
return
|
| 248 |
|
| 249 |
+
for k, fr in enumerate(frames, 1):
|
| 250 |
+
tags = fr["tags"]
|
| 251 |
+
spans = fr.get("spans") or spans_from_bio(tags)
|
| 252 |
+
pred_idx = fr["predicate_index"]
|
| 253 |
+
pred = fr["predicate"]
|
| 254 |
+
p_bv = fr.get("p_bv", None)
|
| 255 |
|
| 256 |
+
print("\n" + "—"*60)
|
| 257 |
|
| 258 |
+
print(create_description(words, tags))
|
| 259 |
+
|
| 260 |
+
# Aggregate phrases per role for a clean summary
|
| 261 |
+
by_role = {}
|
| 262 |
+
for role, s, e in spans:
|
| 263 |
+
phrase = " ".join(words[s:e+1])
|
| 264 |
+
by_role.setdefault(role, []).append(phrase)
|
| 265 |
|
| 266 |
+
# Put V first, then core args, then ARGM*
|
| 267 |
+
order = (
|
| 268 |
+
(("V",),),
|
| 269 |
+
tuple((r,) for r in ["ARG0","ARG1","ARG2","ARG3","ARG4"]),
|
| 270 |
+
(tuple(sorted([r for r in by_role if r.startswith("ARGM")])),)
|
| 271 |
+
)
|
| 272 |
+
ordered_roles = []
|
| 273 |
+
for group in order:
|
| 274 |
+
for r in itertools.chain.from_iterable(group):
|
| 275 |
+
if r in by_role: ordered_roles.append(r)
|
| 276 |
+
# add any leftover roles
|
| 277 |
+
for r in sorted(by_role):
|
| 278 |
+
if r not in ordered_roles: ordered_roles.append(r)
|
| 279 |
+
print("Predicate:")
|
| 280 |
+
print(f" {r:<8}: {pred}")
|
| 281 |
+
print("Roles:")
|
| 282 |
+
for r in ordered_roles:
|
| 283 |
+
joined = "; ".join(by_role[r])
|
| 284 |
+
print(f" {r:<8}: {paint(joined, r)}")
|
| 285 |
|
| 286 |
+
if show_grid:
|
| 287 |
+
# token/tag grid aligned by column width
|
| 288 |
+
colw = [max(len(w), len(t)) for w, t in zip(words, tags)]
|
| 289 |
+
tok_row = " ".join(w.ljust(colw[i]) for i, w in enumerate(words))
|
| 290 |
+
tag_row = " ".join((t if t != "O" else ".").ljust(colw[i]) for i, t in enumerate(tags))
|
| 291 |
+
print("\nTOKEN:", tok_row)
|
| 292 |
+
print("LABEL:", tag_row)
|
| 293 |
|
| 294 |
+
def main_predictor(model_path, bert_name, sentence):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 297 |
+
# model_path = "/blue/bonniejdorr/youms/SRL-Aware_Model/model/best_srl_Sep_29.ckpt"
|
| 298 |
+
ckpt = torch.load(model_path, map_location=device)
|
| 299 |
+
hp = ckpt["hparams"]
|
| 300 |
|
| 301 |
+
model = SRL_BERT_model.PredicateAwareSRL(**hp).to(device)
|
| 302 |
+
model.load_state_dict(ckpt["state_dict"])
|
| 303 |
+
model.eval()
|
| 304 |
+
|
| 305 |
+
label2id = ckpt["label2id"]
|
| 306 |
+
id2label = {v:k for k,v in label2id.items()}
|
| 307 |
+
|
| 308 |
+
# bert_name = "bert-large-cased" or "bert-based-cased"
|
| 309 |
+
bert_name = bert_name
|
| 310 |
+
tokenizer = AutoTokenizer.from_pretrained(bert_name)
|
| 311 |
+
|
| 312 |
+
nlp = spacy.load("en_core_web_md")
|
| 313 |
+
|
| 314 |
+
words, frames = predict_srl_allennlp_like_spacy(
|
| 315 |
+
model, tokenizer, nlp, sentence, id2label,
|
| 316 |
+
device=device,
|
| 317 |
+
prob_threshold=0.40, # tune on dev; try 0.3–0.6
|
| 318 |
+
top_k=None,
|
| 319 |
+
pick_best_if_none=True
|
| 320 |
+
)
|
| 321 |
+
return words, frames
|
| 322 |
|
| 323 |
+
if __name__ =="__main__":
|
| 324 |
+
words, frames = main_predictor(model_path, bert_namem sentence)
|
| 325 |
+
print_srl_frames_pretty(words, frames, show_grid=True, color=False)
|