| """Shared utilities for Bamman & Burns (2020) case study tests. |
| |
| Provides the subword-to-word transform matrix approach used by all four |
| case studies: POS tagging, WSD, infilling, and contextual nearest neighbors. |
| """ |
|
|
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from torch import nn |
| from torch.nn import CrossEntropyLoss |
|
|
| |
| |
| |
# Hidden size of the encoder; 768 matches bert-base — confirm against the
# actual checkpoint used.
BERT_DIM = 768
BATCH_SIZE = 32
DROPOUT_RATE = 0.25

# Tokens that word_to_subtokens() passes through unchanged instead of
# running them through the subword tokenizer.
SPECIAL_TOKENS = {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"}

# Repository-relative data locations (this file is assumed to live one
# directory below the repo root).
REPO_ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = REPO_ROOT / "data"
CASE_STUDY_DIR = DATA_DIR / "case_studies"
WSD_DATA_PATH = CASE_STUDY_DIR / "wsd" / "latin.sense.data"
INFILLING_DATA_PATH = CASE_STUDY_DIR / "infilling" / "emendation_filtered.txt"
|
|
|
|
| |
| |
| |
def word_to_subtokens(tokenizer, word):
    """Map one word to its list of subtoken strings.

    Words that are BERT special tokens ([CLS], [SEP], etc.) bypass the
    subword pipeline and come back as a one-element list; every other
    word is split by the tokenizer, mirroring the original
    LatinTokenizer.tokenize() behavior.
    """
    return [word] if word in SPECIAL_TOKENS else tokenizer.tokenize(word)
|
|
|
|
| |
| |
| |
def get_batches(tokenizer, sentences, max_batch, has_labels=True):
    """Tokenize and batch sentences with subword-to-word transform matrices.

    Each word is tokenized individually (matching original behavior).
    The transform matrix averages subword representations back to
    word-level representations.

    tokenizer: object providing tokenize() and convert_tokens_to_ids().
    sentences: list of sentences, where each sentence is a list of items.
        If has_labels=True, each item is [word, label, ...] (list/tuple)
        and label must be convertible with int().
        If has_labels=False, each item is a word string.
    max_batch: initial batch size; shrunk mid-stream for long sequences
        (see the end of the batching loop).

    Returns:
        If has_labels: (data, masks, labels, transforms, ordering)
        If not: (data, masks, transforms, ordering)
        data/masks/labels/transforms are lists of per-batch tensors;
        ordering is the np.argsort permutation (position in the sorted
        sequence -> original sentence index), needed to restore the
        caller's original sentence order.
    """
    all_data = []
    all_masks = []
    all_labels = [] if has_labels else None
    all_transforms = []

    for sentence in sentences:
        tok_ids = []
        input_mask = []
        labels = [] if has_labels else None
        transform = []

        # First pass: tokenize every word so the sentence's total
        # subtoken count n (width of each transform row) is known.
        all_toks = []
        n = 0
        for item in sentence:
            word = item[0] if has_labels else item
            toks = word_to_subtokens(tokenizer, word)
            all_toks.append(toks)
            n += len(toks)

        # Second pass: one transform row per word, carrying weight
        # 1/len(toks) on each of that word's subtoken positions, so
        # matmul(transform, subword_reprs) yields the mean of the
        # word's subword vectors.
        cur = 0
        for idx, item in enumerate(sentence):
            toks = all_toks[idx]
            ind = list(np.zeros(n))
            for j in range(cur, cur + len(toks)):
                ind[j] = 1.0 / len(toks)
            cur += len(toks)
            transform.append(ind)
            tok_ids.extend(tokenizer.convert_tokens_to_ids(toks))
            input_mask.extend(np.ones(len(toks)))
            if has_labels:
                labels.append(int(item[1]))

        all_data.append(tok_ids)
        all_masks.append(input_mask)
        if has_labels:
            all_labels.append(labels)
        all_transforms.append(transform)

    # Sort sentences by subtoken length so each batch pads to a similar
    # length, minimizing wasted computation on padding.
    lengths = np.array([len(l) for l in all_data])
    ordering = np.argsort(lengths)

    ordered_data = [None] * len(all_data)
    ordered_masks = [None] * len(all_data)
    ordered_labels = [None] * len(all_data) if has_labels else None
    ordered_transforms = [None] * len(all_data)

    for i, ind in enumerate(ordering):
        ordered_data[i] = all_data[ind]
        ordered_masks[i] = all_masks[ind]
        if has_labels:
            ordered_labels[i] = all_labels[ind]
        ordered_transforms[i] = all_transforms[ind]

    batched_data = []
    batched_mask = []
    batched_labels = [] if has_labels else None
    batched_transforms = []

    i = 0
    current_batch = max_batch

    while i < len(ordered_data):
        # NOTE: these slices hold references to the same inner lists as
        # ordered_*, so the padding below mutates them in place; this is
        # harmless because each sentence lands in exactly one batch.
        bd = ordered_data[i:i + current_batch]
        bm = ordered_masks[i:i + current_batch]
        bl = ordered_labels[i:i + current_batch] if has_labels else None
        bt = ordered_transforms[i:i + current_batch]

        # ml: max subtoken length in this batch; max_words: max word count.
        ml = max(len(s) for s in bd)
        max_words = max(len(t) for t in bt)

        for j in range(len(bd)):
            blen = len(bd[j])
            # Pad token ids (0 assumed to be the [PAD] id — confirm
            # against the tokenizer vocab) and masks to ml, widening
            # every existing transform row by one zero column per pad.
            for _k in range(blen, ml):
                bd[j].append(0)
                bm[j].append(0)
                for z in range(len(bt[j])):
                    bt[j][z].append(0)
            if has_labels:
                blab = len(bl[j])
                # -100 matches CrossEntropyLoss(ignore_index=-100) in
                # BertForSequenceLabeling, so padded words carry no loss.
                for _k in range(blab, max_words):
                    bl[j].append(-100)
            # Pad the transform with all-zero rows up to max_words so
            # every matrix in the batch is (max_words, ml).
            for _k in range(len(bt[j]), max_words):
                bt[j].append(np.zeros(ml))

        batched_data.append(torch.LongTensor(bd))
        batched_mask.append(torch.FloatTensor(bm))
        if has_labels:
            batched_labels.append(torch.LongTensor(bl))
        batched_transforms.append(torch.FloatTensor(bt))

        i += current_batch
        # Sentences are length-sorted, so once sequences get long,
        # shrink subsequent batches to bound peak memory use.
        if ml > 100:
            current_batch = 12
        if ml > 200:
            current_batch = 6

    if has_labels:
        return batched_data, batched_mask, batched_labels, batched_transforms, ordering
    return batched_data, batched_mask, batched_transforms, ordering
|
|
|
|
| |
| |
| |
class BertForSequenceLabeling(nn.Module):
    """BERT encoder with a linear head for word-level sequence labeling.

    Used by the POS tagging and WSD case studies: the (optionally
    frozen) encoder produces subword representations, a transform
    matrix averages them to word level, and a linear layer classifies
    each word.
    """

    def __init__(self, tokenizer, bert_model, freeze_bert=False,
                 num_labels=2, hidden_size=BERT_DIM):
        super().__init__()
        self.tokenizer = tokenizer
        self.num_labels = num_labels
        self.bert = bert_model
        # Keep the encoder in eval mode; optionally freeze its weights
        # so only the linear head trains.
        self.bert.eval()
        if freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False
        # NOTE(review): this dropout layer is constructed but never
        # applied in forward() — confirm that is intentional.
        self.dropout = nn.Dropout(DROPOUT_RATE)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, transforms=None,
                labels=None):
        """Encode, average subwords to words, and classify each word.

        Returns the mean cross-entropy loss (positions labeled -100 are
        ignored) when labels are provided, otherwise per-word logits.
        """
        device = input_ids.device
        # Move the optional tensors onto the input's device.
        attention_mask = (None if attention_mask is None
                          else attention_mask.to(device))
        transforms = None if transforms is None else transforms.to(device)
        labels = None if labels is None else labels.to(device)

        subword_reprs = self.bert(input_ids, attention_mask=attention_mask)[0]
        # (words x subwords) @ (subwords x dim) -> word-level representations.
        word_reprs = torch.matmul(transforms, subword_reprs)
        logits = self.classifier(word_reprs)

        if labels is None:
            return logits
        loss_fct = CrossEntropyLoss(ignore_index=-100)
        return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

    def get_batches(self, sentences, max_batch):
        """Tokenize and batch with subword-to-word transform matrices.

        Thin wrapper delegating to the module-level get_batches().
        """
        return get_batches(
            self.tokenizer, sentences, max_batch, has_labels=True
        )
|
|