| """POS tagging case study -- Bamman & Burns (2020) Table 1. |
| |
| Reproduces the POS tagging experiment: fine-tune a linear classifier on top |
| of Latin BERT embeddings, evaluate on UD Latin treebanks. |
| |
| Reference results (from original logs): |
| Perseus: 0.943 |
| PROIEL: 0.982 |
| ITTB: 0.988 |
| """ |
|
|
| import subprocess |
| from pathlib import Path |
|
|
| import numpy as np |
| import pytest |
| import torch |
| import torch.optim as optim |
| from transformers import AutoTokenizer, BertModel |
|
|
| from case_study_utils import ( |
| BATCH_SIZE, |
| BERT_DIM, |
| BertForSequenceLabeling, |
| ) |
|
|
# Fix RNG seeds so fine-tuning runs are repeatable across invocations.
torch.manual_seed(0)
np.random.seed(0)


# Maximum allowed absolute deviation from the reference accuracies below.
TOLERANCE = 0.01


# Git URLs of the Universal Dependencies Latin treebanks used in the study.
UD_REPOS = {
    "perseus": "https://github.com/UniversalDependencies/UD_Latin-Perseus.git",
    "proiel": "https://github.com/UniversalDependencies/UD_Latin-PROIEL.git",
    "ittb": "https://github.com/UniversalDependencies/UD_Latin-ITTB.git",
}


# Per-treebank test accuracies from the original experiment logs (module
# docstring); each run must land within TOLERANCE of these.
REFERENCE_ACCURACY = {
    "perseus": 0.943,
    "proiel": 0.982,
    "ittb": 0.988,
}
|
|
|
|
| def _read_conllu_annotations(filename, tagset, labeled=True): |
| """Read CoNLL-U file, return list of sentences.""" |
| sentences = [] |
| sentence = [["[CLS]", -100, -1, filename]] |
| sentence_id = 0 |
|
|
| with open(filename, encoding="utf-8") as f: |
| for line in f: |
| if line.startswith("#"): |
| continue |
| if line == "\n": |
| sentence_id += 1 |
| sentence.append(["[SEP]", -100, -1, filename]) |
| sentences.append(sentence) |
| sentence = [["[CLS]", -100, -1, filename]] |
| else: |
| cols = line.rstrip().split("\t") |
| if "-" in cols[0] or "." in cols[0]: |
| continue |
| word = cols[1].lower() |
| label = tagset[cols[3]] if labeled else 0 |
| sentence.append([word, label, sentence_id, filename]) |
|
|
| sentence.append(["[SEP]", -100, -1, filename]) |
| if len(sentence) > 2: |
| sentences.append(sentence) |
| return sentences |
|
|
|
|
| def _generate_tagset(filenames): |
| """Generate POS tagset from CoNLL-U files.""" |
| tags = {} |
| for filename in filenames: |
| with open(filename) as f: |
| for line in f: |
| if line.startswith("#") or len(line.rstrip()) == 0: |
| continue |
| cols = line.rstrip().split("\t") |
| if "-" in cols[0] or "." in cols[0]: |
| continue |
| tags[cols[3]] = 1 |
| return {tag: idx for idx, tag in enumerate(tags)} |
|
|
|
|
def _evaluate_split(model, device, data, mask, labels, transforms, num_labels):
    """Return token-level accuracy of `model` on one batched data split.

    Positions labeled -100 ([CLS]/[SEP]/padding sentinels) are excluded
    from the count. Returns 0 when no position is scorable.
    """
    model.eval()
    cor = tot = 0
    with torch.no_grad():
        for b in range(len(data)):
            logits = model(
                data[b].to(device),
                attention_mask=mask[b],
                transforms=transforms[b],
            )
            size = labels[b].shape
            logits = logits.view(-1, size[1], num_labels)
            preds = np.argmax(logits.cpu().numpy(), axis=2)
            for row in range(size[0]):
                for col in range(size[1]):
                    if labels[b][row][col] != -100:
                        if preds[row][col] == labels[b][row][col]:
                            cor += 1
                        tot += 1
    return cor / tot if tot > 0 else 0


def _train_and_evaluate(treebank_name, treebank_dir, device, model_path):
    """Train a POS tagger on a UD treebank and return test accuracy.

    Args:
        treebank_name: Treebank identifier. Kept for caller compatibility;
            not used inside this function.
        treebank_dir: Directory containing the treebank's .conllu files.
        device: torch.device to train and evaluate on.
        model_path: Path or hub id of the pretrained model/tokenizer.

    Returns:
        Token-level accuracy on the test split, using the model state from
        the best dev epoch (or the last epoch when there is no dev split).
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, trust_remote_code=True
    )
    bert_model = BertModel.from_pretrained(model_path)

    conllu_files = sorted(Path(treebank_dir).glob("*.conllu"))
    train_file = [f for f in conllu_files if "train" in f.name][0]
    test_file = [f for f in conllu_files if "test" in f.name][0]
    dev_files = [f for f in conllu_files if "dev" in f.name]

    # Build the tagset over all splits so dev/test tags are never unseen.
    tagset = _generate_tagset([str(f) for f in conllu_files])
    num_labels = len(tagset)

    model = BertForSequenceLabeling(
        tokenizer, bert_model, freeze_bert=False, num_labels=num_labels,
        hidden_size=BERT_DIM
    )
    model.to(device)

    train_sents = _read_conllu_annotations(str(train_file), tagset)
    train_data, train_mask, train_labels, train_transforms, _ = \
        model.get_batches(train_sents, BATCH_SIZE)

    test_sents = _read_conllu_annotations(str(test_file), tagset)
    test_data, test_mask, test_labels, test_transforms, _ = \
        model.get_batches(test_sents, BATCH_SIZE)

    if dev_files:
        dev_sents = _read_conllu_annotations(str(dev_files[0]), tagset)
        dev_data, dev_mask, dev_labels, dev_transforms, _ = \
            model.get_batches(dev_sents, BATCH_SIZE)
    else:
        dev_data = None

    optimizer = optim.Adam(model.parameters(), lr=5e-5)
    best_score = 0
    best_state = None
    best_epoch = 0

    for epoch in range(5):
        model.train()
        big_loss = 0
        for b in range(len(train_data)):
            loss = model(
                train_data[b].to(device),
                attention_mask=train_mask[b],
                transforms=train_transforms[b],
                labels=train_labels[b],
            )
            big_loss += loss.item()
            loss.backward()
            optimizer.step()
            model.zero_grad()

        print(f" epoch {epoch}: loss={big_loss:.2f}")

        if dev_data is not None:
            score = _evaluate_split(
                model, device, dev_data, dev_mask, dev_labels,
                dev_transforms, num_labels
            )
            print(f" epoch {epoch}: dev accuracy={score:.4f}")
            if score > best_score:
                best_score = score
                # Snapshot on CPU so the checkpoint does not pin GPU memory.
                best_state = {
                    k: v.cpu().clone()
                    for k, v in model.state_dict().items()
                }
                best_epoch = epoch
        else:
            # No dev split: always keep the latest epoch's weights.
            best_state = {
                k: v.cpu().clone()
                for k, v in model.state_dict().items()
            }
            best_epoch = epoch

    print(f" best epoch: {best_epoch}")

    if best_state is not None:
        model.load_state_dict(best_state)

    return _evaluate_split(
        model, device, test_data, test_mask, test_labels,
        test_transforms, num_labels
    )
|
|
|
|
@pytest.fixture(scope="module")
def ud_treebanks(tmp_path_factory):
    """Clone the UD Latin treebank repos into a module-scoped temp dir."""
    base = tmp_path_factory.mktemp("ud_latin")
    # Repo directory names do not follow one casing rule; map the two
    # all-caps exceptions explicitly and capitalize the rest.
    special_dirs = {
        "ittb": "UD_Latin-ITTB",
        "proiel": "UD_Latin-PROIEL",
    }
    paths = {}
    for name, url in UD_REPOS.items():
        dir_name = special_dirs.get(name, f"UD_Latin-{name.capitalize()}")
        dest = base / dir_name
        subprocess.run(
            ["git", "clone", "--depth=1", url, str(dest)], check=True
        )
        paths[name] = dest
    return paths
|
|
|
|
@pytest.mark.slow
@pytest.mark.parametrize("treebank", ["perseus", "proiel", "ittb"])
def test_pos_tagging(ud_treebanks, treebank, model_path):
    """End-to-end check: fine-tuned tagger accuracy matches the published
    Bamman & Burns (2020) figure within TOLERANCE."""
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    ref = REFERENCE_ACCURACY[treebank]
    accuracy = _train_and_evaluate(
        treebank, ud_treebanks[treebank], device, model_path
    )
    print(f"\n{treebank}: accuracy={accuracy:.3f} (ref={ref:.3f})")
    deviation = abs(accuracy - ref)
    assert deviation < TOLERANCE, (
        f"{treebank} accuracy {accuracy:.3f} outside tolerance "
        f"of {ref} +/- {TOLERANCE}"
    )
|
|