"""Shared utilities for Bamman & Burns (2020) case study tests.

Provides the subword-to-word transform matrix approach used by all four
case studies: POS tagging, WSD, infilling, and contextual nearest
neighbors.
"""

from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

BERT_DIM = 768
BATCH_SIZE = 32
DROPOUT_RATE = 0.25

# Special tokens that should not go through subword encoding
SPECIAL_TOKENS = {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"}

# Data paths (relative to repo root)
REPO_ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = REPO_ROOT / "data"
CASE_STUDY_DIR = DATA_DIR / "case_studies"
WSD_DATA_PATH = CASE_STUDY_DIR / "wsd" / "latin.sense.data"
INFILLING_DATA_PATH = CASE_STUDY_DIR / "infilling" / "emendation_filtered.txt"


# ---------------------------------------------------------------------------
# Tokenization helpers
# ---------------------------------------------------------------------------

def word_to_subtokens(tokenizer, word):
    """Get subtoken strings for a single word.

    Special tokens ([CLS], [SEP], etc.) are returned as-is. Regular words
    are tokenized through the subword pipeline, matching the original
    LatinTokenizer.tokenize() behavior.
    """
    if word in SPECIAL_TOKENS:
        return [word]
    return tokenizer.tokenize(word)


def _encode_sentence(tokenizer, sentence, has_labels):
    """Tokenize one sentence and build its subword-to-word transform.

    Each word is tokenized individually (matching original behavior).

    Returns:
        (tok_ids, input_mask, labels, transform) where ``labels`` is None
        when ``has_labels`` is False. ``transform`` has one row per word;
        row i holds 1/k over the k subtoken positions of word i, so
        multiplying it against subword representations averages them back
        to word level.
    """
    # First pass: subtokens per word, so we know the total subtoken count.
    subtoks = []
    for item in sentence:
        word = item[0] if has_labels else item
        subtoks.append(word_to_subtokens(tokenizer, word))
    n_subtokens = sum(len(t) for t in subtoks)

    tok_ids = []
    input_mask = []
    labels = [] if has_labels else None
    transform = []

    # Second pass: build the averaging rows and collect token ids.
    cur = 0
    for item, toks in zip(sentence, subtoks):
        # NOTE(review): assumes every word yields >= 1 subtoken; a tokenizer
        # returning [] would raise ZeroDivisionError here — confirm upstream.
        row = [0.0] * n_subtokens
        weight = 1.0 / len(toks)
        for j in range(cur, cur + len(toks)):
            row[j] = weight
        cur += len(toks)
        transform.append(row)
        tok_ids.extend(tokenizer.convert_tokens_to_ids(toks))
        input_mask.extend([1.0] * len(toks))
        if has_labels:
            labels.append(int(item[1]))

    return tok_ids, input_mask, labels, transform


# ---------------------------------------------------------------------------
# Batching with transform matrices
# ---------------------------------------------------------------------------

def get_batches(tokenizer, sentences, max_batch, has_labels=True):
    """Tokenize and batch sentences with subword-to-word transform matrices.

    Each word is tokenized individually (matching original behavior). The
    transform matrix averages subword representations back to word-level
    representations. Sentences are sorted by subtoken length so each batch
    holds similar-length inputs; ``ordering`` lets callers map results back
    to the input order.

    Args:
        tokenizer: object with ``tokenize`` and ``convert_tokens_to_ids``.
        sentences: list of sentences, where each sentence is a list of
            items. If has_labels=True, each item is [word, label, ...]
            (list/tuple). If has_labels=False, each item is a word string.
        max_batch: initial batch size; automatically reduced for batches
            containing long sentences (see below).

    Returns:
        If has_labels: (data, masks, labels, transforms, ordering)
        If not:       (data, masks, transforms, ordering)
    """
    encoded = [_encode_sentence(tokenizer, s, has_labels) for s in sentences]
    all_data = [e[0] for e in encoded]
    all_masks = [e[1] for e in encoded]
    all_labels = [e[2] for e in encoded] if has_labels else None
    all_transforms = [e[3] for e in encoded]

    # Sort by subtoken length to minimize padding within each batch.
    lengths = np.array([len(d) for d in all_data])
    ordering = np.argsort(lengths)
    ordered_data = [all_data[ind] for ind in ordering]
    ordered_masks = [all_masks[ind] for ind in ordering]
    ordered_labels = (
        [all_labels[ind] for ind in ordering] if has_labels else None
    )
    ordered_transforms = [all_transforms[ind] for ind in ordering]

    batched_data = []
    batched_mask = []
    batched_labels = [] if has_labels else None
    batched_transforms = []

    i = 0
    current_batch = max_batch
    while i < len(ordered_data):
        bd = ordered_data[i:i + current_batch]
        bm = ordered_masks[i:i + current_batch]
        bl = ordered_labels[i:i + current_batch] if has_labels else None
        bt = ordered_transforms[i:i + current_batch]

        max_len = max(len(s) for s in bd)       # longest subtoken sequence
        max_words = max(len(t) for t in bt)     # longest word sequence

        for j in range(len(bd)):
            # Pad token ids and mask out to the batch's max subtoken length.
            pad = max_len - len(bd[j])
            bd[j].extend([0] * pad)
            bm[j].extend([0.0] * pad)
            # Widen every existing transform row to max_len columns.
            for row in bt[j]:
                row.extend([0.0] * pad)
            if has_labels:
                # -100 matches CrossEntropyLoss(ignore_index=-100) below.
                bl[j].extend([-100] * (max_words - len(bl[j])))
            # Add all-zero rows up to the batch's max word count.
            for _ in range(max_words - len(bt[j])):
                bt[j].append([0.0] * max_len)

        batched_data.append(torch.LongTensor(bd))
        batched_mask.append(torch.FloatTensor(bm))
        if has_labels:
            batched_labels.append(torch.LongTensor(bl))
        batched_transforms.append(torch.FloatTensor(bt))

        i += current_batch
        # Attention cost grows quadratically with sequence length, so once
        # padded length crosses a threshold, shrink SUBSEQUENT batches
        # (sorting guarantees later batches are at least as long).
        if max_len > 200:
            current_batch = 6
        elif max_len > 100:
            current_batch = 12

    if has_labels:
        return (batched_data, batched_mask, batched_labels,
                batched_transforms, ordering)
    return batched_data, batched_mask, batched_transforms, ordering


# ---------------------------------------------------------------------------
# Sequence labeling model (used by POS and WSD)
# ---------------------------------------------------------------------------

class BertForSequenceLabeling(nn.Module):
    """BERT + linear classifier for sequence labeling.

    Used by the POS tagging and WSD case studies. A linear head is trained
    on word-level (transform-averaged) BERT representations; pass
    ``freeze_bert=True`` to freeze the encoder and train only the head.
    """

    def __init__(self, tokenizer, bert_model, freeze_bert=False,
                 num_labels=2, hidden_size=BERT_DIM):
        """
        Args:
            tokenizer: tokenizer compatible with get_batches().
            bert_model: a pretrained encoder whose forward returns a tuple
                with the sequence output first.
            freeze_bert: if True, disable gradients for all encoder params.
            num_labels: size of the label inventory.
            hidden_size: encoder hidden dimension (BERT base = 768).
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.num_labels = num_labels
        self.bert = bert_model
        self.bert.eval()
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
        self.dropout = nn.Dropout(DROPOUT_RATE)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, transforms=None,
                labels=None):
        """Run the encoder, average subwords to words, and classify.

        Returns the scalar loss when ``labels`` is given (positions labeled
        -100 are ignored), otherwise logits of shape
        (batch, max_words, num_labels).
        """
        device = input_ids.device
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        if transforms is not None:
            transforms = transforms.to(device)
        if labels is not None:
            labels = labels.to(device)

        outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        # Average subword vectors back to word-level vectors.
        out = torch.matmul(transforms, sequence_output)
        # Bug fix: the dropout layer was constructed but never applied;
        # apply it before the classifier (no-op in eval mode).
        out = self.dropout(out)
        logits = self.classifier(out)

        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            return loss_fct(
                logits.view(-1, self.num_labels),
                labels.view(-1)
            )
        return logits

    def get_batches(self, sentences, max_batch):
        """Tokenize and batch with subword-to-word transform matrices.

        Delegates to the module-level get_batches() function.
        """
        return get_batches(self.tokenizer, sentences, max_batch,
                           has_labels=True)