| """Word sense disambiguation case study -- Bamman & Burns (2020) Table 2. |
| |
| Reproduces the WSD experiment: for each polysemous lemma (201 total), |
| train a binary classifier (sense I vs sense II) on BERT embeddings with |
| 10-fold cross-validation. Pick the best epoch on dev, evaluate on test. |
| |
| Reference results (from original logs): |
| OVERALL: epoch 9, accuracy 0.754, n=1070 |
| """ |
|
|
| import random |
|
|
| import numpy as np |
| import pytest |
| import torch |
| import torch.optim as optim |
| from transformers import AutoTokenizer, BertModel |
|
|
| from case_study_utils import ( |
| BATCH_SIZE, |
| BertForSequenceLabeling, |
| WSD_DATA_PATH, |
| ) |
|
|
# Seed every RNG used below (splits/shuffles, model init, numpy ops) so the
# reproduction run is deterministic.
random.seed(1)
torch.manual_seed(0)
np.random.seed(0)


# Reference accuracy from the original experiment logs (see module docstring).
REF_ACCURACY = 0.754
# Maximum allowed absolute deviation from REF_ACCURACY for the test to pass.
TOLERANCE = 0.02
# Training epochs per lemma; the best epoch is picked by pooled dev accuracy.
MAX_EPOCHS = 100
|
|
|
|
| def _get_labs(before, target, after, label): |
| """Build a labeled sentence for WSD. |
| |
| Only the target word gets a real label (0 or 1); all context words |
| get -100 (ignored by CrossEntropyLoss). |
| """ |
| sent = [] |
| for word in before.split(" "): |
| if word: |
| sent.append((word, -100)) |
| sent.append((target, label)) |
| for word in after.split(" "): |
| if word: |
| sent.append((word, -100)) |
| return sent |
|
|
|
|
def _read_wsd_data(filename):
    """Read WSD data file, return dict of lemma -> {0: [...], 1: [...]}."""
    # Map the file's Roman-numeral sense labels onto binary class ids.
    sense_of = {"I": 0, "II": 1}
    lemmas = {}
    with open(filename) as f:
        for line in f:
            cols = line.split("\t")
            lemma, label = cols[0], cols[1]
            before, target, after = cols[2], cols[3], cols[4].rstrip()
            entry = lemmas.setdefault(lemma, {0: [], 1: []})
            sense = sense_of.get(label)
            # Lines with any other label are silently skipped, matching
            # the original if/elif behavior.
            if sense is not None:
                entry[sense].append(_get_labs(before, target, after, sense))
    return lemmas
|
|
|
|
| def _get_splits(data): |
| """10-fold cross-validation splits.""" |
| trains, tests, devs = [], [], [] |
| for _i in range(10): |
| trains.append([]) |
| tests.append([]) |
| devs.append([]) |
|
|
| for sense_idx in [0, 1]: |
| for idx, sent in enumerate(data[sense_idx]): |
| test_fold = idx % 10 |
| dev_fold = test_fold - 1 if test_fold > 0 else 9 |
| for i in range(10): |
| if i == test_fold: |
| tests[i].append(sent) |
| elif i == dev_fold: |
| devs[i].append(sent) |
| else: |
| trains[i].append(sent) |
|
|
| for i in range(10): |
| random.shuffle(trains[i]) |
| random.shuffle(tests[i]) |
| random.shuffle(devs[i]) |
|
|
| return trains, devs, tests |
|
|
|
|
| def _evaluate(model, batched_data, batched_mask, batched_labels, |
| batched_transforms, device): |
| """Evaluate model on batched data, return (correct, total).""" |
| model.eval() |
| cor = 0 |
| tot = 0 |
| with torch.no_grad(): |
| for b in range(len(batched_data)): |
| logits = model( |
| batched_data[b].to(device), |
| attention_mask=batched_mask[b], |
| transforms=batched_transforms[b], |
| ) |
| logits = logits.cpu() |
| size = batched_labels[b].shape |
| logits = logits.view(-1, size[1], 2) |
| preds = np.argmax(logits.numpy(), axis=2) |
| for row in range(size[0]): |
| for col in range(size[1]): |
| if batched_labels[b][row][col] != -100: |
| if preds[row][col] == batched_labels[b][row][col]: |
| cor += 1 |
| tot += 1 |
| return cor, tot |
|
|
|
|
@pytest.mark.slow
def test_wsd_accuracy(model_path):
    """Reproduce WSD case study from Bamman & Burns (2020).

    For each polysemous lemma, trains a fresh binary sense classifier on
    top of BERT, pools per-epoch dev/test correct counts across all
    lemmas, picks the best epoch by pooled dev accuracy, and compares the
    pooled test accuracy at that epoch against the reference value.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    tokenizer = AutoTokenizer.from_pretrained(
        model_path, trust_remote_code=True
    )
    data = _read_wsd_data(str(WSD_DATA_PATH))


    # Per-epoch correct/total counters, accumulated over ALL lemmas.
    dev_cors = [0.0] * MAX_EPOCHS
    test_cors = [0.0] * MAX_EPOCHS
    dev_n = [0.0] * MAX_EPOCHS
    test_n = [0.0] * MAX_EPOCHS


    for lemma_idx, lemma in enumerate(data):
        print(f"\n[{lemma_idx + 1}/{len(data)}] {lemma}")


        # Fresh, fully fine-tunable BERT + 2-way labeling head per lemma.
        bert_model = BertModel.from_pretrained(model_path)
        model = BertForSequenceLabeling(
            tokenizer, bert_model, freeze_bert=False, num_labels=2
        )
        model.to(device)


        trains, devs, tests = _get_splits(data[lemma])


        # NOTE(review): _get_splits builds 10 folds but only fold 0 is ever
        # trained/evaluated here, while the module docstring says 10-fold
        # CV — confirm the reference n=1070 corresponds to fold 0 only.
        train_b, train_m, train_l, train_t, _ = model.get_batches(
            trains[0], BATCH_SIZE
        )
        dev_b, dev_m, dev_l, dev_t, _ = model.get_batches(
            devs[0], BATCH_SIZE
        )
        test_b, test_m, test_l, test_t, _ = model.get_batches(
            tests[0], BATCH_SIZE
        )


        optimizer = optim.Adam(model.parameters(), lr=5e-5)


        for epoch in range(MAX_EPOCHS):
            model.train()
            for b in range(len(train_b)):
                # Model returns the loss directly when labels are given.
                loss = model(
                    train_b[b].to(device),
                    attention_mask=train_m[b],
                    transforms=train_t[b],
                    labels=train_l[b],
                )
                loss.backward()
                optimizer.step()
                model.zero_grad()


            # Accumulate this lemma's per-epoch dev counts into the pool.
            c, t = _evaluate(model, dev_b, dev_m, dev_l, dev_t, device)
            dev_cors[epoch] += c
            dev_n[epoch] += t


            # Same for the held-out test fold.
            c, t = _evaluate(model, test_b, test_m, test_l, test_t, device)
            test_cors[epoch] += c
            test_n[epoch] += t


        # Progress log: note this reports the POOLED dev accuracy over all
        # lemmas processed so far, not this lemma's accuracy alone.
        for epoch in range(MAX_EPOCHS):
            if dev_n[epoch] > 0:
                dev_acc = dev_cors[epoch] / dev_n[epoch]
                print(
                    f"  DEV: epoch={epoch} acc={dev_acc:.3f} "
                    f"lemma={lemma} n={dev_n[epoch]}"
                )


    # Model selection: epoch with the highest pooled dev accuracy.
    best_epoch = max(
        range(MAX_EPOCHS),
        key=lambda i: dev_cors[i] / dev_n[i] if dev_n[i] > 0 else 0,
    )
    test_accuracy = test_cors[best_epoch] / test_n[best_epoch]


    print(
        f"\nOVERALL: epoch={best_epoch}, "
        f"accuracy={test_accuracy:.3f}, "
        f"n={test_n[best_epoch]}"
    )
    print(f"Reference: epoch=9, accuracy={REF_ACCURACY}, n=1070")


    assert abs(test_accuracy - REF_ACCURACY) < TOLERANCE, (
        f"WSD accuracy {test_accuracy:.3f} outside tolerance "
        f"of {REF_ACCURACY} +/- {TOLERANCE}"
    )
|
|