# latin-bert/tests/test_infilling.py
"""Infilling (text emendation) case study — Bamman & Burns (2020) Table 3.
Reproduces the masked word prediction experiment. For each example in the
emendation dataset, the target word is replaced with [MASK] and the model
predicts the missing word. Measures Precision@1, @10, @50.
Reference results (from original logs):
Precision@1: 0.331
Precision@10: 0.622
Precision@50: 0.740
n: 1161
"""
import copy
import re
from typing import List
import pytest
import torch
from transformers import AutoTokenizer, BertForMaskedLM
from case_study_utils import INFILLING_DATA_PATH
def _tokenize_text(tokenizer, text: str) -> List[int]:
    """Encode *text* one whitespace-separated word at a time.

    Each word is lowercased and encoded without special tokens, and the
    per-word subtoken ids are concatenated — mirroring the original
    LatinTokenizer's word-by-word behavior.
    """
    return [
        token_id
        for word in text.split()
        for token_id in tokenizer.encode(word.lower(), add_special_tokens=False)
    ]
# Reference precisions reproduced from the original experiment logs
# (Bamman & Burns 2020, Table 3) over n = 1161 emendation examples.
REF_P1 = 0.331
REF_P10 = 0.622
REF_P50 = 0.740
# Absolute tolerance when comparing recomputed precision to the reference.
TOLERANCE = 0.01
def _proc(model, tokenizer, token_ids, device):
    """Return the model's top-1 token string at the [MASK] position.

    Used to complete multi-subtoken words: the caller inserts a candidate
    first subtoken just before [MASK], and this call predicts the
    continuation subtoken at that mask.
    """
    mask_token_id = tokenizer.convert_tokens_to_ids("[MASK]")
    mask_position = token_ids.index(mask_token_id)
    batch = torch.LongTensor(token_ids).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)[0]
    ranking = torch.argsort(logits[0][mask_position], descending=True)
    best_id = int(ranking[0])
    return tokenizer.convert_ids_to_tokens(best_id)
def _evaluate_one(model, tokenizer, text_before, text_after, truth, device):
    """Evaluate a single infilling example. Returns (p1, p10, p50).

    The [MASK] token is placed between the two tokenized context windows and
    the model's top-50 predictions at the mask are compared against *truth*
    (case-insensitively). A candidate subtoken that does not end with "_" is
    treated as the first piece of a multi-subtoken word and completed with one
    extra forward pass via ``_proc`` before comparison.

    Returns:
        (p1, p10, p50): 0/1 hit indicators at rank 1, rank <= 10, rank <= 50.
    """
    before_ids = _tokenize_text(tokenizer, text_before)
    after_ids = _tokenize_text(tokenizer, text_after)
    mask_id = tokenizer.convert_tokens_to_ids("[MASK]")
    cls_id = tokenizer.convert_tokens_to_ids("[CLS]")
    sep_id = tokenizer.convert_tokens_to_ids("[SEP]")
    token_ids = [cls_id] + before_ids + [mask_id] + after_ids + [sep_id]
    mask_pos = token_ids.index(mask_id)
    t = torch.LongTensor(token_ids).unsqueeze(0).to(device)
    p1 = p10 = p50 = 0
    truth_lower = truth.lower()  # loop-invariant: hoisted out of the ranking loop
    with torch.no_grad():
        preds = model(t)[0]
    sorted_vals = torch.argsort(preds[0][mask_pos], descending=True)
    for k, p in enumerate(sorted_vals[:50]):
        predicted_index = p.item()
        predicted_token = tokenizer.convert_ids_to_tokens(predicted_index)
        suffix = ""
        if not predicted_token.endswith("_"):
            # Multi-subtoken word: re-insert the candidate before [MASK] and
            # let the model predict the continuation subtoken.
            uptokens = list(token_ids)  # shallow copy suffices: flat list of ints
            uptokens.insert(mask_pos, predicted_index)
            suffix = _proc(model, tokenizer, uptokens, device)
        predicted_word = re.sub(r"_$", "", f"{predicted_token}{suffix}").lower()
        if predicted_word == truth_lower:
            if k == 0:
                p1 = 1
            if k < 10:
                p10 = 1
            p50 = 1  # k is always < 50 inside the [:50] slice
            # A later match could only re-set flags already set (flags are
            # monotone), so stop here and skip further _proc forward passes.
            break
    return p1, p10, p50
@pytest.mark.slow
def test_infilling_precision(model_path):
    """Reproduce infilling case study from Bamman & Burns (2020)."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = BertForMaskedLM.from_pretrained(model_path)
    model.to(device)
    model.eval()

    # Length filter applied to each example (context + mask token).
    min_tokens, max_tokens = 10, 100

    hits1 = hits10 = hits50 = n = 0
    with open(INFILLING_DATA_PATH) as fh:
        for raw_line in fh:
            fields = raw_line.split("\t")
            # Keep only well-formed rows from the "disjoint" split.
            if len(fields) < 5 or fields[0] != "disjoint":
                continue
            text_before, truth = fields[2], fields[3]
            if len(truth) < 2:
                continue
            text_after = fields[4].rstrip()
            total_tokens = len(text_before.split()) + 1 + len(text_after.split())
            if not (min_tokens < total_tokens < max_tokens):
                continue
            r1, r10, r50 = _evaluate_one(
                model, tokenizer, text_before, text_after, truth, device
            )
            hits1 += r1
            hits10 += r10
            hits50 += r50
            n += 1

    assert n == 1161, f"Expected 1161 examples, got {n}"

    precision_1 = hits1 / n
    precision_10 = hits10 / n
    precision_50 = hits50 / n
    print(f"Precision@1: {precision_1:.3f} (ref: {REF_P1})")
    print(f"Precision@10: {precision_10:.3f} (ref: {REF_P10})")
    print(f"Precision@50: {precision_50:.3f} (ref: {REF_P50})")

    assert abs(precision_1 - REF_P1) < TOLERANCE, (
        f"P@1 {precision_1:.3f} outside tolerance of {REF_P1} +/- {TOLERANCE}"
    )
    assert abs(precision_10 - REF_P10) < TOLERANCE, (
        f"P@10 {precision_10:.3f} outside tolerance of {REF_P10} +/- {TOLERANCE}"
    )
    assert abs(precision_50 - REF_P50) < TOLERANCE, (
        f"P@50 {precision_50:.3f} outside tolerance of {REF_P50} +/- {TOLERANCE}"
    )