# Source: latin-bert / tests / test_wsd.py
# Last commit: diyclassics -- "refactor: extract shared case study utils
# and move data to tracked paths" (f04d50f)
"""Word sense disambiguation case study -- Bamman & Burns (2020) Table 2.
Reproduces the WSD experiment: for each polysemous lemma (201 total),
train a binary classifier (sense I vs sense II) on BERT embeddings with
10-fold cross-validation. Pick the best epoch on dev, evaluate on test.
Reference results (from original logs):
OVERALL: epoch 9, accuracy 0.754, n=1070
"""
import random
import numpy as np
import pytest
import torch
import torch.optim as optim
from transformers import AutoTokenizer, BertModel
from case_study_utils import (
BATCH_SIZE,
BertForSequenceLabeling,
WSD_DATA_PATH,
)
# Pin every RNG the experiment touches (data shuffling, weight init,
# dropout) so a rerun is repeatable.
random.seed(1)
torch.manual_seed(0)
np.random.seed(0)
# Reference accuracy from the original Bamman & Burns (2020) logs quoted in
# the module docstring; the reproduction must land within TOLERANCE of it.
REF_ACCURACY = 0.754
TOLERANCE = 0.02  # WSD has more variance due to per-lemma training
# Training epochs per lemma; the best epoch is selected on pooled dev counts.
MAX_EPOCHS = 100
def _get_labs(before, target, after, label):
"""Build a labeled sentence for WSD.
Only the target word gets a real label (0 or 1); all context words
get -100 (ignored by CrossEntropyLoss).
"""
sent = []
for word in before.split(" "):
if word:
sent.append((word, -100))
sent.append((target, label))
for word in after.split(" "):
if word:
sent.append((word, -100))
return sent
def _read_wsd_data(filename):
    """Read WSD data file, return dict of lemma -> {0: [...], 1: [...]}."""
    # Tab-separated columns: lemma, sense label ("I"/"II"), left context,
    # target word, right context.
    sense_of = {"I": 0, "II": 1}
    lemmas = {}
    with open(filename) as f:
        for line in f:
            cols = line.split("\t")
            lemma, label, before, target = cols[0], cols[1], cols[2], cols[3]
            after = cols[4].rstrip()
            # Every lemma seen gets both sense buckets, even if the label
            # is neither "I" nor "II".
            buckets = lemmas.setdefault(lemma, {0: [], 1: []})
            sense = sense_of.get(label)
            if sense is not None:
                buckets[sense].append(_get_labs(before, target, after, sense))
    return lemmas
def _get_splits(data):
"""10-fold cross-validation splits."""
trains, tests, devs = [], [], []
for _i in range(10):
trains.append([])
tests.append([])
devs.append([])
for sense_idx in [0, 1]:
for idx, sent in enumerate(data[sense_idx]):
test_fold = idx % 10
dev_fold = test_fold - 1 if test_fold > 0 else 9
for i in range(10):
if i == test_fold:
tests[i].append(sent)
elif i == dev_fold:
devs[i].append(sent)
else:
trains[i].append(sent)
for i in range(10):
random.shuffle(trains[i])
random.shuffle(tests[i])
random.shuffle(devs[i])
return trains, devs, tests
def _evaluate(model, batched_data, batched_mask, batched_labels,
batched_transforms, device):
"""Evaluate model on batched data, return (correct, total)."""
model.eval()
cor = 0
tot = 0
with torch.no_grad():
for b in range(len(batched_data)):
logits = model(
batched_data[b].to(device),
attention_mask=batched_mask[b],
transforms=batched_transforms[b],
)
logits = logits.cpu()
size = batched_labels[b].shape
logits = logits.view(-1, size[1], 2)
preds = np.argmax(logits.numpy(), axis=2)
for row in range(size[0]):
for col in range(size[1]):
if batched_labels[b][row][col] != -100:
if preds[row][col] == batched_labels[b][row][col]:
cor += 1
tot += 1
return cor, tot
@pytest.mark.slow
def test_wsd_accuracy(model_path):
    """Reproduce WSD case study from Bamman & Burns (2020).

    For each polysemous lemma, fine-tunes a fresh binary classifier on BERT
    embeddings, accumulating dev/test correct counts per epoch pooled over
    all lemmas. The epoch with the best pooled dev accuracy is then scored
    on test and compared against the reference accuracy.

    NOTE(review): although ``_get_splits`` builds all 10 folds, only
    fold 0 is trained and evaluated here -- presumably to bound runtime;
    confirm the reference numbers (n=1070) were produced the same way.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, trust_remote_code=True
    )
    data = _read_wsd_data(str(WSD_DATA_PATH))
    # Pooled correct/total counts per epoch, accumulated across all lemmas.
    dev_cors = [0.0] * MAX_EPOCHS
    test_cors = [0.0] * MAX_EPOCHS
    dev_n = [0.0] * MAX_EPOCHS
    test_n = [0.0] * MAX_EPOCHS
    for lemma_idx, lemma in enumerate(data):
        print(f"\n[{lemma_idx + 1}/{len(data)}] {lemma}")
        # Fresh, unfrozen model per lemma: each lemma is an independent
        # binary sense-classification task.
        bert_model = BertModel.from_pretrained(model_path)
        model = BertForSequenceLabeling(
            tokenizer, bert_model, freeze_bert=False, num_labels=2
        )
        model.to(device)
        trains, devs, tests = _get_splits(data[lemma])
        # Only fold 0 of the 10-fold splits is used (see docstring note).
        train_b, train_m, train_l, train_t, _ = model.get_batches(
            trains[0], BATCH_SIZE
        )
        dev_b, dev_m, dev_l, dev_t, _ = model.get_batches(
            devs[0], BATCH_SIZE
        )
        test_b, test_m, test_l, test_t, _ = model.get_batches(
            tests[0], BATCH_SIZE
        )
        optimizer = optim.Adam(model.parameters(), lr=5e-5)
        for epoch in range(MAX_EPOCHS):
            model.train()
            for b in range(len(train_b)):
                # With labels supplied, the model returns the loss directly.
                loss = model(
                    train_b[b].to(device),
                    attention_mask=train_m[b],
                    transforms=train_t[b],
                    labels=train_l[b],
                )
                loss.backward()
                optimizer.step()
                model.zero_grad()
            # Score this epoch on dev and test; epoch selection happens
            # later on the pooled dev counts.
            c, t = _evaluate(model, dev_b, dev_m, dev_l, dev_t, device)
            dev_cors[epoch] += c
            dev_n[epoch] += t
            c, t = _evaluate(model, test_b, test_m, test_l, test_t, device)
            test_cors[epoch] += c
            test_n[epoch] += t
    # NOTE(review): `lemma` below is the last lemma from the loop above,
    # but the printed counts are pooled over ALL lemmas -- the label in
    # this log line is misleading; verify before relying on it.
    for epoch in range(MAX_EPOCHS):
        if dev_n[epoch] > 0:
            dev_acc = dev_cors[epoch] / dev_n[epoch]
            print(
                f" DEV: epoch={epoch} acc={dev_acc:.3f} "
                f"lemma={lemma} n={dev_n[epoch]}"
            )
    # Model selection: pick the epoch with the best pooled dev accuracy,
    # then report the test accuracy at that epoch.
    best_epoch = max(
        range(MAX_EPOCHS),
        key=lambda i: dev_cors[i] / dev_n[i] if dev_n[i] > 0 else 0,
    )
    test_accuracy = test_cors[best_epoch] / test_n[best_epoch]
    print(
        f"\nOVERALL: epoch={best_epoch}, "
        f"accuracy={test_accuracy:.3f}, "
        f"n={test_n[best_epoch]}"
    )
    print(f"Reference: epoch=9, accuracy={REF_ACCURACY}, n=1070")
    assert abs(test_accuracy - REF_ACCURACY) < TOLERANCE, (
        f"WSD accuracy {test_accuracy:.3f} outside tolerance "
        f"of {REF_ACCURACY} +/- {TOLERANCE}"
    )