yeomtong commited on
Commit
db3c303
·
verified ·
1 Parent(s): 47a2c8c

Delete testing.py

Browse files
Files changed (1) hide show
  1. testing.py +0 -80
testing.py DELETED
@@ -1,80 +0,0 @@
1
- from SRL_model import SRL_BERT_model
2
- from collections import Counter
3
- import torch
4
-
5
- def bio_to_spans(tags):
6
- spans = []
7
- i = 0
8
- while i < len(tags):
9
- t = tags[i]
10
- if t == "O" or t.endswith("-V"):
11
- i += 1; continue
12
- if t.startswith("B-"):
13
- role = t[2:]; j = i + 1
14
- while j < len(tags) and tags[j] == f"I-{role}":
15
- j += 1
16
- spans.append((role, i, j-1))
17
- i = j
18
- else:
19
- i += 1
20
- return spans
21
-
22
- @torch.no_grad()
23
- def eval_span_f1(model, dataloader, id2label, device="cuda"):
24
- model.eval()
25
- tp = fp = fn = 0
26
- for batch in dataloader:
27
- gold = batch["labels"] # [B, Lw]
28
- mask = (gold != -100)
29
-
30
- batch = {k:(v.to(device) if torch.is_tensor(v) else v) for k,v in batch.items()}
31
- logits, _ = model(**batch)
32
- pred = logits.argmax(-1).cpu() # [B, Lw]
33
- print(pred)
34
- for g_seq, p_seq, m in zip(gold, pred, mask):
35
- gl = [id2label[int(i)] for i in g_seq[m].tolist()]
36
- pl = [id2label[int(i)] for i in p_seq[m].tolist()]
37
- G = Counter(bio_to_spans(gl))
38
- P = Counter(bio_to_spans(pl))
39
- # micro counts
40
- common = G & P
41
- tp += sum(common.values())
42
- fp += sum(P.values()) - sum(common.values())
43
- fn += sum(G.values()) - sum(common.values())
44
-
45
- prec = tp / (tp + fp + 1e-12)
46
- rec = tp / (tp + fn + 1e-12)
47
- f1 = 2 * prec * rec / (prec + rec + 1e-12)
48
- return prec, rec, f1
49
-
50
-
51
- if __name__ =="__main__":
52
-
53
- device = "cuda" if torch.cuda.is_available() else "cpu"
54
- ckpt_path = "/blue/bonniejdorr/youms/SRL-Aware_Model/model/best_srl_Sep_29.ckpt" # <-- change if needed
55
- ckpt = torch.load(ckpt_path, map_location=device)
56
- hp = ckpt["hparams"]
57
-
58
- model = SRL_BERT_model.PredicateAwareSRL(**hp).to(device)
59
- model.load_state_dict(ckpt["state_dict"])
60
- model.eval()
61
-
62
- label2id = ckpt["label2id"]
63
- id2label = {v: k for k, v in label2id.items()}
64
-
65
- h = ckpt.get("hparams", {
66
- "bert_name": "bert-base-cased",
67
- "num_labels": len(label2id),
68
- "use_indicator": True,
69
- "use_distance": True,
70
- "indicator_dim": 10,
71
- "lstm_hidden": 768,
72
- "mlp_hidden": 300,
73
- "pos_dim": 50,
74
- "max_distance": 128,
75
- "dropout": 0.1,
76
- })
77
-
78
- #test_loader from SRL_BERT_model
79
- prec, rec, span_f1 = eval_span_f1(model, test_loader, id2label, device=device)
80
- print(f"[TEST-SPAN] P={prec:.3f} R={rec:.3f} F1={span_f1:.3f}")