Delete testing.py
Browse files- testing.py +0 -80
testing.py
DELETED
|
@@ -1,80 +0,0 @@
|
|
| 1 |
-
from SRL_model import SRL_BERT_model
|
| 2 |
-
from collections import Counter
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
-
def bio_to_spans(tags):
|
| 6 |
-
spans = []
|
| 7 |
-
i = 0
|
| 8 |
-
while i < len(tags):
|
| 9 |
-
t = tags[i]
|
| 10 |
-
if t == "O" or t.endswith("-V"):
|
| 11 |
-
i += 1; continue
|
| 12 |
-
if t.startswith("B-"):
|
| 13 |
-
role = t[2:]; j = i + 1
|
| 14 |
-
while j < len(tags) and tags[j] == f"I-{role}":
|
| 15 |
-
j += 1
|
| 16 |
-
spans.append((role, i, j-1))
|
| 17 |
-
i = j
|
| 18 |
-
else:
|
| 19 |
-
i += 1
|
| 20 |
-
return spans
|
| 21 |
-
|
| 22 |
-
@torch.no_grad()
|
| 23 |
-
def eval_span_f1(model, dataloader, id2label, device="cuda"):
|
| 24 |
-
model.eval()
|
| 25 |
-
tp = fp = fn = 0
|
| 26 |
-
for batch in dataloader:
|
| 27 |
-
gold = batch["labels"] # [B, Lw]
|
| 28 |
-
mask = (gold != -100)
|
| 29 |
-
|
| 30 |
-
batch = {k:(v.to(device) if torch.is_tensor(v) else v) for k,v in batch.items()}
|
| 31 |
-
logits, _ = model(**batch)
|
| 32 |
-
pred = logits.argmax(-1).cpu() # [B, Lw]
|
| 33 |
-
print(pred)
|
| 34 |
-
for g_seq, p_seq, m in zip(gold, pred, mask):
|
| 35 |
-
gl = [id2label[int(i)] for i in g_seq[m].tolist()]
|
| 36 |
-
pl = [id2label[int(i)] for i in p_seq[m].tolist()]
|
| 37 |
-
G = Counter(bio_to_spans(gl))
|
| 38 |
-
P = Counter(bio_to_spans(pl))
|
| 39 |
-
# micro counts
|
| 40 |
-
common = G & P
|
| 41 |
-
tp += sum(common.values())
|
| 42 |
-
fp += sum(P.values()) - sum(common.values())
|
| 43 |
-
fn += sum(G.values()) - sum(common.values())
|
| 44 |
-
|
| 45 |
-
prec = tp / (tp + fp + 1e-12)
|
| 46 |
-
rec = tp / (tp + fn + 1e-12)
|
| 47 |
-
f1 = 2 * prec * rec / (prec + rec + 1e-12)
|
| 48 |
-
return prec, rec, f1
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
if __name__ =="__main__":
|
| 52 |
-
|
| 53 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 54 |
-
ckpt_path = "/blue/bonniejdorr/youms/SRL-Aware_Model/model/best_srl_Sep_29.ckpt" # <-- change if needed
|
| 55 |
-
ckpt = torch.load(ckpt_path, map_location=device)
|
| 56 |
-
hp = ckpt["hparams"]
|
| 57 |
-
|
| 58 |
-
model = SRL_BERT_model.PredicateAwareSRL(**hp).to(device)
|
| 59 |
-
model.load_state_dict(ckpt["state_dict"])
|
| 60 |
-
model.eval()
|
| 61 |
-
|
| 62 |
-
label2id = ckpt["label2id"]
|
| 63 |
-
id2label = {v: k for k, v in label2id.items()}
|
| 64 |
-
|
| 65 |
-
h = ckpt.get("hparams", {
|
| 66 |
-
"bert_name": "bert-base-cased",
|
| 67 |
-
"num_labels": len(label2id),
|
| 68 |
-
"use_indicator": True,
|
| 69 |
-
"use_distance": True,
|
| 70 |
-
"indicator_dim": 10,
|
| 71 |
-
"lstm_hidden": 768,
|
| 72 |
-
"mlp_hidden": 300,
|
| 73 |
-
"pos_dim": 50,
|
| 74 |
-
"max_distance": 128,
|
| 75 |
-
"dropout": 0.1,
|
| 76 |
-
})
|
| 77 |
-
|
| 78 |
-
#test_loader from SRL_BERT_model
|
| 79 |
-
prec, rec, span_f1 = eval_span_f1(model, test_loader, id2label, device=device)
|
| 80 |
-
print(f"[TEST-SPAN] P={prec:.3f} R={rec:.3f} F1={span_f1:.3f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|