latin-bert / tests / case_study_utils.py
diyclassics's picture
refactor: extract shared case study utils and move data to tracked paths
f04d50f
"""Shared utilities for Bamman & Burns (2020) case study tests.
Provides the subword-to-word transform matrix approach used by all four
case studies: POS tagging, WSD, infilling, and contextual nearest neighbors.
"""
from pathlib import Path
import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Hidden size of the BERT encoder output; default width of the linear head
# in BertForSequenceLabeling below.
BERT_DIM = 768
# Shared batch size for the case studies — presumably passed as max_batch to
# get_batches by the importing tests; verify against callers.
BATCH_SIZE = 32
# Dropout probability for the classification head's nn.Dropout layer.
DROPOUT_RATE = 0.25
# Special tokens that should not go through subword encoding
SPECIAL_TOKENS = {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"}
# Data paths (relative to repo root)
REPO_ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = REPO_ROOT / "data"
CASE_STUDY_DIR = DATA_DIR / "case_studies"
WSD_DATA_PATH = CASE_STUDY_DIR / "wsd" / "latin.sense.data"
INFILLING_DATA_PATH = CASE_STUDY_DIR / "infilling" / "emendation_filtered.txt"
# ---------------------------------------------------------------------------
# Tokenization helpers
# ---------------------------------------------------------------------------
def word_to_subtokens(tokenizer, word):
    """Return the list of subtoken strings for a single word.

    Special tokens ([CLS], [SEP], etc.) are passed through untouched as a
    one-element list; every other word goes through the tokenizer's subword
    pipeline, matching the original LatinTokenizer.tokenize() behavior.
    """
    return [word] if word in SPECIAL_TOKENS else tokenizer.tokenize(word)
# ---------------------------------------------------------------------------
# Batching with transform matrices
# ---------------------------------------------------------------------------
def get_batches(tokenizer, sentences, max_batch, has_labels=True):
    """Tokenize and batch sentences with subword-to-word transform matrices.

    Each word is tokenized individually (matching original behavior).
    The per-sentence transform matrix averages subword representations
    back to word-level representations via a single matmul (see
    BertForSequenceLabeling.forward).

    Args:
        tokenizer: object providing .tokenize(word) and
            .convert_tokens_to_ids(tokens).
        sentences: list of sentences, where each sentence is a list of items.
            If has_labels=True, each item is [word, label, ...] (list/tuple)
            and label must be int-convertible.
            If has_labels=False, each item is a word string.
        max_batch: initial batch size; shrunk to 12 (then 6) once a batch's
            max subword length exceeds 100 (then 200) — see loop tail below.
        has_labels: whether items carry labels.

    Returns:
        If has_labels: (data, masks, labels, transforms, ordering)
        If not: (data, masks, transforms, ordering)
        data/masks/labels/transforms are lists of per-batch tensors;
        ordering is the np.argsort permutation applied to the sentences
        (they are emitted sorted by subword length), which callers can
        use to restore the original order.
    """
    all_data = []
    all_masks = []
    all_labels = [] if has_labels else None
    all_transforms = []
    for sentence in sentences:
        tok_ids = []
        input_mask = []
        labels = [] if has_labels else None
        transform = []
        # First pass: get subtokens for each word
        all_toks = []
        n = 0  # total subword count for this sentence
        for item in sentence:
            word = item[0] if has_labels else item
            toks = word_to_subtokens(tokenizer, word)
            all_toks.append(toks)
            n += len(toks)
        # Second pass: build transform matrix and collect IDs
        cur = 0
        for idx, item in enumerate(sentence):
            toks = all_toks[idx]
            # One transform row per word: 1/len(toks) over that word's
            # subword positions, zero elsewhere — matmul with the hidden
            # states then averages the word's subword vectors.
            ind = list(np.zeros(n))
            for j in range(cur, cur + len(toks)):
                ind[j] = 1.0 / len(toks)
            cur += len(toks)
            transform.append(ind)
            tok_ids.extend(tokenizer.convert_tokens_to_ids(toks))
            input_mask.extend(np.ones(len(toks)))
            if has_labels:
                labels.append(int(item[1]))
        all_data.append(tok_ids)
        all_masks.append(input_mask)
        if has_labels:
            all_labels.append(labels)
        all_transforms.append(transform)
    # Sort sentences by subword length so each batch is near-uniform and
    # padding waste is minimized; `ordering` lets callers undo the sort.
    lengths = np.array([len(l) for l in all_data])
    ordering = np.argsort(lengths)
    ordered_data = [None] * len(all_data)
    ordered_masks = [None] * len(all_data)
    ordered_labels = [None] * len(all_data) if has_labels else None
    ordered_transforms = [None] * len(all_data)
    for i, ind in enumerate(ordering):
        ordered_data[i] = all_data[ind]
        ordered_masks[i] = all_masks[ind]
        if has_labels:
            ordered_labels[i] = all_labels[ind]
        ordered_transforms[i] = all_transforms[ind]
    batched_data = []
    batched_mask = []
    batched_labels = [] if has_labels else None
    batched_transforms = []
    i = 0
    current_batch = max_batch
    while i < len(ordered_data):
        # NOTE: these slices alias the inner lists, which are padded
        # in place below; each list is consumed exactly once, so the
        # mutation is safe here.
        bd = ordered_data[i:i + current_batch]
        bm = ordered_masks[i:i + current_batch]
        bl = ordered_labels[i:i + current_batch] if has_labels else None
        bt = ordered_transforms[i:i + current_batch]
        ml = max(len(s) for s in bd)  # max subword length in this batch
        max_words = max(len(t) for t in bt)  # max word count in this batch
        for j in range(len(bd)):
            blen = len(bd[j])
            # Pad token IDs and masks out to ml subword positions; each
            # padding step also widens every existing transform row by
            # one zero column so rows stay ml wide.
            for _k in range(blen, ml):
                bd[j].append(0)
                bm[j].append(0)
                for z in range(len(bt[j])):
                    bt[j][z].append(0)
            if has_labels:
                blab = len(bl[j])
                # -100 matches CrossEntropyLoss(ignore_index=-100), so
                # padded word positions contribute nothing to the loss.
                for _k in range(blab, max_words):
                    bl[j].append(-100)
            # Add all-zero rows up to the batch's max word count.
            for _k in range(len(bt[j]), max_words):
                bt[j].append(np.zeros(ml))
        batched_data.append(torch.LongTensor(bd))
        batched_mask.append(torch.FloatTensor(bm))
        if has_labels:
            batched_labels.append(torch.LongTensor(bl))
        batched_transforms.append(torch.FloatTensor(bt))
        i += current_batch
        # Shrink subsequent batches once sequences get long; since the
        # sentences are length-sorted, later batches only get longer.
        if ml > 100:
            current_batch = 12
        if ml > 200:
            current_batch = 6
    if has_labels:
        return batched_data, batched_mask, batched_labels, batched_transforms, ordering
    return batched_data, batched_mask, batched_transforms, ordering
# ---------------------------------------------------------------------------
# Sequence labeling model (used by POS and WSD)
# ---------------------------------------------------------------------------
class BertForSequenceLabeling(nn.Module):
    """BERT encoder plus a linear head for word-level sequence labeling.

    Shared by the POS tagging and WSD case studies. The encoder is put
    into eval mode at construction; when ``freeze_bert`` is True its
    parameters are excluded from gradient updates, so only the linear
    head trains.

    NOTE(review): ``self.dropout`` is constructed but never applied in
    ``forward`` — reproduced as-is to keep behavior identical.
    """

    def __init__(self, tokenizer, bert_model, freeze_bert=False,
                 num_labels=2, hidden_size=BERT_DIM):
        super().__init__()
        self.tokenizer = tokenizer
        self.num_labels = num_labels
        self.bert = bert_model
        self.bert.eval()  # keep encoder in inference mode
        if freeze_bert:
            # Exclude every encoder weight from the optimizer's reach.
            for weight in self.bert.parameters():
                weight.requires_grad = False
        self.dropout = nn.Dropout(DROPOUT_RATE)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, transforms=None,
                labels=None):
        """Encode, project to labels, and optionally compute the loss.

        input_ids: LongTensor of subword IDs.
        attention_mask: optional mask over subword positions; moved to
            input_ids' device when provided.
        transforms: subword-to-word averaging matrix produced by
            get_batches(); moved to input_ids' device when provided.
        labels: optional word-level label IDs; positions equal to -100
            are ignored by the loss.

        Returns the scalar cross-entropy loss when ``labels`` is given,
        otherwise the word-level logits.
        """
        dev = input_ids.device
        attention_mask = None if attention_mask is None else attention_mask.to(dev)
        transforms = None if transforms is None else transforms.to(dev)
        labels = None if labels is None else labels.to(dev)
        encoded = self.bert(input_ids, attention_mask=attention_mask)[0]
        # Average each word's subword vectors into one word vector.
        word_vectors = torch.matmul(transforms, encoded)
        logits = self.classifier(word_vectors)
        if labels is None:
            return logits
        criterion = CrossEntropyLoss(ignore_index=-100)
        return criterion(logits.view(-1, self.num_labels), labels.view(-1))

    def get_batches(self, sentences, max_batch):
        """Batch ``sentences`` via the module-level get_batches() with labels."""
        return get_batches(self.tokenizer, sentences, max_batch,
                           has_labels=True)