| """Infilling (text emendation) case study — Bamman & Burns (2020) Table 3. |
| |
| Reproduces the masked word prediction experiment. For each example in the |
| emendation dataset, the target word is replaced with [MASK] and the model |
| predicts the missing word. Measures Precision@1, @10, @50. |
| |
| Reference results (from original logs): |
| Precision@1: 0.331 |
| Precision@10: 0.622 |
| Precision@50: 0.740 |
| n: 1161 |
| """ |
|
|
| import copy |
| import re |
| from typing import List |
|
|
| import pytest |
| import torch |
| from transformers import AutoTokenizer, BertForMaskedLM |
|
|
| from case_study_utils import INFILLING_DATA_PATH |
|
|
|
|
| def _tokenize_text(tokenizer, text: str) -> List[int]: |
| """Tokenize text word-by-word, matching the original LatinTokenizer behavior.""" |
| ids = [] |
| for word in text.split(): |
| word_ids = tokenizer.encode(word.lower(), add_special_tokens=False) |
| ids.extend(word_ids) |
| return ids |
|
|
|
|
# Reference precision values from the original Bamman & Burns (2020) logs
# (see module docstring); the test asserts agreement within TOLERANCE.
REF_P1 = 0.331
REF_P10 = 0.622
REF_P50 = 0.740
TOLERANCE = 0.01
|
|
|
|
| def _proc(model, tokenizer, token_ids, device): |
| """Predict the subtoken at the [MASK] position for multi-subtoken words.""" |
| mask_id = tokenizer.convert_tokens_to_ids("[MASK]") |
| mask_pos = token_ids.index(mask_id) |
| t = torch.LongTensor(token_ids).unsqueeze(0).to(device) |
| with torch.no_grad(): |
| preds = model(t)[0] |
| sorted_vals = torch.argsort(preds[0][mask_pos], descending=True) |
| predicted_index = sorted_vals[0].item() |
| return tokenizer.convert_ids_to_tokens(predicted_index) |
|
|
|
|
def _evaluate_one(model, tokenizer, text_before, text_after, truth, device):
    """Score one masked-word example.

    Builds the sequence ``[CLS] before [MASK] after [SEP]``, ranks the
    vocabulary at the mask position, and checks whether *truth* appears at
    rank 1 / within the top 10 / within the top 50. Candidates lacking the
    trailing "_" word-boundary marker are treated as a first subtoken and
    extended with one predicted continuation via ``_proc``.

    Returns a ``(p1, p10, p50)`` tuple of 0/1 indicators.
    """
    mask_id = tokenizer.convert_tokens_to_ids("[MASK]")
    token_ids = (
        [tokenizer.convert_tokens_to_ids("[CLS]")]
        + _tokenize_text(tokenizer, text_before)
        + [mask_id]
        + _tokenize_text(tokenizer, text_after)
        + [tokenizer.convert_tokens_to_ids("[SEP]")]
    )
    mask_pos = token_ids.index(mask_id)

    batch = torch.LongTensor(token_ids).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)[0]
    ranked = torch.argsort(logits[0][mask_pos], descending=True)

    truth_lower = truth.lower()
    p1 = p10 = p50 = 0
    for rank, candidate in enumerate(ranked[:50]):
        cand_id = candidate.item()
        cand_token = tokenizer.convert_ids_to_tokens(cand_id)

        # A candidate without the "_" marker is only a word prefix: insert
        # it before the mask and ask the model for the next subtoken.
        if cand_token.endswith("_"):
            word = cand_token
        else:
            extended = copy.deepcopy(token_ids)
            extended.insert(mask_pos, cand_id)
            word = f"{cand_token}{_proc(model, tokenizer, extended, device)}"

        word = re.sub(r"_$", "", word).lower()
        if word == truth_lower:
            if rank == 0:
                p1 = 1
            if rank < 10:
                p10 = 1
            p50 = 1  # rank is always < 50 inside this loop

    return p1, p10, p50
|
|
|
|
@pytest.mark.slow
def test_infilling_precision(model_path):
    """Reproduce the infilling case study from Bamman & Burns (2020).

    Streams the emendation dataset, evaluates every qualifying "disjoint"
    example, and asserts P@1/P@10/P@50 against the reference run within
    TOLERANCE.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = BertForMaskedLM.from_pretrained(model_path)
    model.to(device)
    model.eval()

    # Context-length window used by the original experiment.
    min_tokens, max_tokens = 10, 100

    hits_1 = 0
    hits_10 = 0
    hits_50 = 0
    n = 0

    with open(INFILLING_DATA_PATH) as handle:
        for raw_line in handle:
            fields = raw_line.split("\t")
            # Only well-formed "disjoint" rows participate in the benchmark.
            if len(fields) < 5 or fields[0] != "disjoint":
                continue

            before = fields[2]
            target = fields[3]
            if len(target) < 2:
                continue
            after = fields[4].rstrip()

            # Word count of the full context, counting the mask as one word.
            total = len(before.split()) + 1 + len(after.split())
            if not (min_tokens < total < max_tokens):
                continue

            p1, p10, p50 = _evaluate_one(
                model, tokenizer, before, after, target, device
            )
            hits_1 += p1
            hits_10 += p10
            hits_50 += p50
            n += 1

    assert n == 1161, f"Expected 1161 examples, got {n}"
    precision_1 = hits_1 / n
    precision_10 = hits_10 / n
    precision_50 = hits_50 / n

    print(f"Precision@1: {precision_1:.3f} (ref: {REF_P1})")
    print(f"Precision@10: {precision_10:.3f} (ref: {REF_P10})")
    print(f"Precision@50: {precision_50:.3f} (ref: {REF_P50})")

    assert abs(precision_1 - REF_P1) < TOLERANCE, (
        f"P@1 {precision_1:.3f} outside tolerance of {REF_P1} +/- {TOLERANCE}"
    )
    assert abs(precision_10 - REF_P10) < TOLERANCE, (
        f"P@10 {precision_10:.3f} outside tolerance of {REF_P10} +/- {TOLERANCE}"
    )
    assert abs(precision_50 - REF_P50) < TOLERANCE, (
        f"P@50 {precision_50:.3f} outside tolerance of {REF_P50} +/- {TOLERANCE}"
    )
|
|