# latin-bert/tests/test_pos_tagging.py
# NOTE: Hugging Face page-header residue converted to comments so the file
# parses (author: diyclassics; commit f04d50f — "refactor: extract shared
# case study utils and move data to tracked paths").
"""POS tagging case study -- Bamman & Burns (2020) Table 1.
Reproduces the POS tagging experiment: fine-tune a linear classifier on top
of Latin BERT embeddings, evaluate on UD Latin treebanks.
Reference results (from original logs):
Perseus: 0.943
PROIEL: 0.982
ITTB: 0.988
"""
import subprocess
from pathlib import Path
import numpy as np
import pytest
import torch
import torch.optim as optim
from transformers import AutoTokenizer, BertModel
from case_study_utils import (
BATCH_SIZE,
BERT_DIM,
BertForSequenceLabeling,
)
# Seed both torch and numpy so fine-tuning runs are reproducible.
torch.manual_seed(0)
np.random.seed(0)
# Maximum absolute deviation allowed from the published accuracies below.
TOLERANCE = 0.01
# UD Latin treebanks evaluated in Bamman & Burns (2020), Table 1.
UD_REPOS = {
    "perseus": "https://github.com/UniversalDependencies/UD_Latin-Perseus.git",
    "proiel": "https://github.com/UniversalDependencies/UD_Latin-PROIEL.git",
    "ittb": "https://github.com/UniversalDependencies/UD_Latin-ITTB.git",
}
# Reference accuracies from the original paper's logs (see module docstring).
REFERENCE_ACCURACY = {
    "perseus": 0.943,
    "proiel": 0.982,
    "ittb": 0.988,
}
def _read_conllu_annotations(filename, tagset, labeled=True):
"""Read CoNLL-U file, return list of sentences."""
sentences = []
sentence = [["[CLS]", -100, -1, filename]]
sentence_id = 0
with open(filename, encoding="utf-8") as f:
for line in f:
if line.startswith("#"):
continue
if line == "\n":
sentence_id += 1
sentence.append(["[SEP]", -100, -1, filename])
sentences.append(sentence)
sentence = [["[CLS]", -100, -1, filename]]
else:
cols = line.rstrip().split("\t")
if "-" in cols[0] or "." in cols[0]:
continue
word = cols[1].lower()
label = tagset[cols[3]] if labeled else 0
sentence.append([word, label, sentence_id, filename])
sentence.append(["[SEP]", -100, -1, filename])
if len(sentence) > 2:
sentences.append(sentence)
return sentences
def _generate_tagset(filenames):
"""Generate POS tagset from CoNLL-U files."""
tags = {}
for filename in filenames:
with open(filename) as f:
for line in f:
if line.startswith("#") or len(line.rstrip()) == 0:
continue
cols = line.rstrip().split("\t")
if "-" in cols[0] or "." in cols[0]:
continue
tags[cols[3]] = 1
return {tag: idx for idx, tag in enumerate(tags)}
def _snapshot_state(model):
    """Return a CPU copy of *model*'s state dict, safe to restore later."""
    return {k: v.cpu().clone() for k, v in model.state_dict().items()}


def _evaluate_split(model, data, mask, labels, transforms, num_labels, device):
    """Compute token-level accuracy on one batched split.

    Positions whose gold label is -100 ([CLS]/[SEP]/padding) are excluded
    from both the numerator and the denominator.
    """
    model.eval()
    cor = tot = 0
    with torch.no_grad():
        for b in range(len(data)):
            logits = model(
                data[b].to(device),
                attention_mask=mask[b],
                transforms=transforms[b],
            )
            size = labels[b].shape
            logits = logits.view(-1, size[1], num_labels)
            preds = np.argmax(logits.cpu().numpy(), axis=2)
            for row in range(size[0]):
                for col in range(size[1]):
                    if labels[b][row][col] != -100:
                        if preds[row][col] == labels[b][row][col]:
                            cor += 1
                        tot += 1
    return cor / tot if tot > 0 else 0


def _train_and_evaluate(treebank_name, treebank_dir, device, model_path):
    """Train a POS tagger on a UD treebank and return its test accuracy.

    Fine-tunes Latin BERT with a sequence-labeling head for 5 epochs,
    selecting the epoch with the best dev accuracy when a dev split exists
    (otherwise the final epoch's weights are used).

    Args:
        treebank_name: Treebank identifier (kept for call-site clarity).
        treebank_dir: Directory containing the cloned treebank's .conllu files.
        device: torch.device to train and evaluate on.
        model_path: Path/identifier of the Latin BERT checkpoint.

    Returns:
        Token-level accuracy on the treebank's test split.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, trust_remote_code=True
    )
    bert_model = BertModel.from_pretrained(model_path)

    conllu_files = sorted(Path(treebank_dir).glob("*.conllu"))
    train_file = [f for f in conllu_files if "train" in f.name][0]
    test_file = [f for f in conllu_files if "test" in f.name][0]
    dev_files = [f for f in conllu_files if "dev" in f.name]

    # Tagset is built from ALL splits so no split can contain an unseen tag.
    tagset = _generate_tagset([str(f) for f in conllu_files])
    num_labels = len(tagset)

    model = BertForSequenceLabeling(
        tokenizer, bert_model, freeze_bert=False, num_labels=num_labels,
        hidden_size=BERT_DIM
    )
    model.to(device)

    train_sents = _read_conllu_annotations(str(train_file), tagset)
    train_data, train_mask, train_labels, train_transforms, _ = \
        model.get_batches(train_sents, BATCH_SIZE)

    test_sents = _read_conllu_annotations(str(test_file), tagset)
    test_data, test_mask, test_labels, test_transforms, _ = \
        model.get_batches(test_sents, BATCH_SIZE)

    if dev_files:
        dev_sents = _read_conllu_annotations(str(dev_files[0]), tagset)
        dev_data, dev_mask, dev_labels, dev_transforms, _ = \
            model.get_batches(dev_sents, BATCH_SIZE)
    else:
        dev_data = None

    optimizer = optim.Adam(model.parameters(), lr=5e-5)
    best_score = 0
    best_state = None
    best_epoch = 0

    for epoch in range(5):
        model.train()
        big_loss = 0
        for b in range(len(train_data)):
            loss = model(
                train_data[b].to(device),
                attention_mask=train_mask[b],
                transforms=train_transforms[b],
                labels=train_labels[b],
            )
            big_loss += loss.item()
            loss.backward()
            optimizer.step()
            model.zero_grad()
        print(f" epoch {epoch}: loss={big_loss:.2f}")

        if dev_data is not None:
            # Model selection: snapshot the weights of the best dev epoch.
            score = _evaluate_split(
                model, dev_data, dev_mask, dev_labels, dev_transforms,
                num_labels, device,
            )
            print(f" epoch {epoch}: dev accuracy={score:.4f}")
            if score > best_score:
                best_score = score
                best_state = _snapshot_state(model)
                best_epoch = epoch
        else:
            # No dev split: fall back to the last epoch's weights.
            best_state = _snapshot_state(model)
            best_epoch = epoch

    print(f" best epoch: {best_epoch}")
    if best_state is not None:
        model.load_state_dict(best_state)

    return _evaluate_split(
        model, test_data, test_mask, test_labels, test_transforms,
        num_labels, device,
    )
@pytest.fixture(scope="module")
def ud_treebanks(tmp_path_factory):
    """Clone each UD Latin treebank repo into a module-scoped temp dir."""
    base = tmp_path_factory.mktemp("ud_latin")
    paths = {}
    for name, url in UD_REPOS.items():
        # Destination directory mirrors the repo name, e.g.
        # ".../UD_Latin-ITTB.git" -> "UD_Latin-ITTB".
        repo_dir = url.rsplit("/", 1)[-1]
        if repo_dir.endswith(".git"):
            repo_dir = repo_dir[: -len(".git")]
        dest = base / repo_dir
        subprocess.run(
            ["git", "clone", "--depth=1", url, str(dest)], check=True
        )
        paths[name] = dest
    return paths
@pytest.mark.slow
@pytest.mark.parametrize("treebank", ["perseus", "proiel", "ittb"])
def test_pos_tagging(ud_treebanks, treebank, model_path):
    """Reproduce POS tagging case study from Bamman & Burns (2020)."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    accuracy = _train_and_evaluate(
        treebank, ud_treebanks[treebank], device, model_path
    )
    expected = REFERENCE_ACCURACY[treebank]
    print(f"\n{treebank}: accuracy={accuracy:.3f} (ref={expected:.3f})")
    assert abs(accuracy - expected) < TOLERANCE, (
        f"{treebank} accuracy {accuracy:.3f} outside tolerance "
        f"of {expected} +/- {TOLERANCE}"
    )