"""Shared utilities for Bamman & Burns (2020) case study tests.
Provides the subword-to-word transform matrix approach used by all four
case studies: POS tagging, WSD, infilling, and contextual nearest neighbors.
"""
from pathlib import Path
import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Hidden size of the BERT encoder output; default input width of the
# classification head in BertForSequenceLabeling below.
BERT_DIM = 768
# Batch size constant exported for callers; not referenced within this module.
BATCH_SIZE = 32
# Dropout probability used by the classification head's nn.Dropout layer.
DROPOUT_RATE = 0.25
# Special tokens that should not go through subword encoding
SPECIAL_TOKENS = {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"}
# Data paths (relative to repo root)
# NOTE(review): assumes this file sits exactly one directory below the repo
# root — confirm if the file is ever moved.
REPO_ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = REPO_ROOT / "data"
CASE_STUDY_DIR = DATA_DIR / "case_studies"
WSD_DATA_PATH = CASE_STUDY_DIR / "wsd" / "latin.sense.data"
INFILLING_DATA_PATH = CASE_STUDY_DIR / "infilling" / "emendation_filtered.txt"
# ---------------------------------------------------------------------------
# Tokenization helpers
# ---------------------------------------------------------------------------
def word_to_subtokens(tokenizer, word):
    """Return the list of subtoken strings for a single word.

    Special marker tokens ([CLS], [SEP], [PAD], [UNK], [MASK]) bypass the
    subword pipeline and are returned unchanged as a one-element list;
    every other word is delegated to ``tokenizer.tokenize()``, matching
    the original LatinTokenizer.tokenize() behavior.
    """
    if word not in SPECIAL_TOKENS:
        return tokenizer.tokenize(word)
    return [word]
# ---------------------------------------------------------------------------
# Batching with transform matrices
# ---------------------------------------------------------------------------
def get_batches(tokenizer, sentences, max_batch, has_labels=True):
    """Tokenize and batch sentences with subword-to-word transform matrices.

    Each word is tokenized individually (matching original behavior).
    The transform matrix averages subword representations back to
    word-level representations: row w of a sentence's transform holds
    1/len(subtoks(w)) over the columns of word w's subtokens, 0 elsewhere.

    Args:
        tokenizer: object providing tokenize() and convert_tokens_to_ids().
        sentences: list of sentences, where each sentence is a list of items.
            If has_labels=True, each item is [word, label, ...] (list/tuple)
            and label must be convertible with int().
            If has_labels=False, each item is a word string.
        max_batch: initial number of sentences per batch; shrinks to 12 once
            a batch's subtoken length exceeds 100 and to 6 beyond 200,
            presumably to cap memory for long sequences.
        has_labels: whether items carry integer labels.

    Returns:
        If has_labels: (data, masks, labels, transforms, ordering)
        If not: (data, masks, transforms, ordering)
        data/masks are per-batch Long/Float tensors of shape
        (batch, max_subtoks); labels (batch, max_words) padded with -100
        (the ignore_index used by the model's CrossEntropyLoss);
        transforms (batch, max_words, max_subtoks). `ordering` is the
        length-sort permutation (np.argsort) so callers can restore the
        original sentence order.
    """
    all_data = []
    all_masks = []
    all_labels = [] if has_labels else None
    all_transforms = []
    for sentence in sentences:
        tok_ids = []
        input_mask = []
        labels = [] if has_labels else None
        transform = []
        # First pass: get subtokens for each word
        # (n = total subtoken count, needed to size each transform row).
        all_toks = []
        n = 0
        for item in sentence:
            word = item[0] if has_labels else item
            toks = word_to_subtokens(tokenizer, word)
            all_toks.append(toks)
            n += len(toks)
        # Second pass: build transform matrix and collect IDs
        cur = 0
        for idx, item in enumerate(sentence):
            toks = all_toks[idx]
            # One row per word: averaging weights over this word's subtokens.
            ind = list(np.zeros(n))
            for j in range(cur, cur + len(toks)):
                ind[j] = 1.0 / len(toks)
            cur += len(toks)
            transform.append(ind)
            tok_ids.extend(tokenizer.convert_tokens_to_ids(toks))
            input_mask.extend(np.ones(len(toks)))
            if has_labels:
                labels.append(int(item[1]))
        all_data.append(tok_ids)
        all_masks.append(input_mask)
        if has_labels:
            all_labels.append(labels)
        all_transforms.append(transform)
    # Sort sentences by subtoken length so each batch pads to a similar size.
    lengths = np.array([len(l) for l in all_data])
    ordering = np.argsort(lengths)
    ordered_data = [None] * len(all_data)
    ordered_masks = [None] * len(all_data)
    ordered_labels = [None] * len(all_data) if has_labels else None
    ordered_transforms = [None] * len(all_data)
    for i, ind in enumerate(ordering):
        ordered_data[i] = all_data[ind]
        ordered_masks[i] = all_masks[ind]
        if has_labels:
            ordered_labels[i] = all_labels[ind]
        ordered_transforms[i] = all_transforms[ind]
    batched_data = []
    batched_mask = []
    batched_labels = [] if has_labels else None
    batched_transforms = []
    i = 0
    current_batch = max_batch
    while i < len(ordered_data):
        # Slices share the inner lists, so the padding below mutates them
        # in place before tensor conversion.
        bd = ordered_data[i:i + current_batch]
        bm = ordered_masks[i:i + current_batch]
        bl = ordered_labels[i:i + current_batch] if has_labels else None
        bt = ordered_transforms[i:i + current_batch]
        # ml: longest subtoken sequence in batch; max_words: most words.
        ml = max(len(s) for s in bd)
        max_words = max(len(t) for t in bt)
        for j in range(len(bd)):
            blen = len(bd[j])
            # Pad ids/mask to ml; extend every transform row in lockstep
            # so rows stay ml columns wide.
            for _k in range(blen, ml):
                bd[j].append(0)
                bm[j].append(0)
                for z in range(len(bt[j])):
                    bt[j][z].append(0)
            if has_labels:
                blab = len(bl[j])
                # -100 matches CrossEntropyLoss(ignore_index=-100) downstream.
                for _k in range(blab, max_words):
                    bl[j].append(-100)
            # Pad the transform with all-zero rows up to max_words.
            for _k in range(len(bt[j]), max_words):
                bt[j].append(np.zeros(ml))
        batched_data.append(torch.LongTensor(bd))
        batched_mask.append(torch.FloatTensor(bm))
        if has_labels:
            batched_labels.append(torch.LongTensor(bl))
        batched_transforms.append(torch.FloatTensor(bt))
        i += current_batch
        # Shrink subsequent batches once sequences get long (sentences are
        # length-sorted, so later batches only get longer).
        if ml > 100:
            current_batch = 12
        if ml > 200:
            current_batch = 6
    if has_labels:
        return batched_data, batched_mask, batched_labels, batched_transforms, ordering
    return batched_data, batched_mask, batched_transforms, ordering
# ---------------------------------------------------------------------------
# Sequence labeling model (used by POS and WSD)
# ---------------------------------------------------------------------------
class BertForSequenceLabeling(nn.Module):
    """BERT encoder + linear classifier for sequence labeling.

    Used by the POS tagging and WSD case studies. The encoder is
    optionally frozen (``freeze_bert=True``) so only the linear head
    trains; by default its parameters remain trainable.
    """

    def __init__(self, tokenizer, bert_model, freeze_bert=False,
                 num_labels=2, hidden_size=BERT_DIM):
        """
        Args:
            tokenizer: tokenizer passed through to get_batches().
            bert_model: encoder returning a tuple whose first element is
                the sequence output (subtoken, hidden_size) per sentence.
            freeze_bert: if True, disable gradients for all encoder params.
            num_labels: size of the label inventory.
            hidden_size: encoder hidden width (defaults to BERT_DIM).
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.num_labels = num_labels
        self.bert = bert_model
        # Start the encoder in eval mode; callers switch modes as needed.
        self.bert.eval()
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
        self.dropout = nn.Dropout(DROPOUT_RATE)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, transforms=None,
                labels=None):
        """Run encoder + head; return loss if labels given, else logits.

        transforms is required in practice: it maps subword vectors back
        to word-level vectors via matmul. labels padded with -100 are
        ignored by the loss.
        """
        device = input_ids.device
        # Move companion tensors onto the same device as the input ids.
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        if transforms is not None:
            transforms = transforms.to(device)
        if labels is not None:
            labels = labels.to(device)
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        # Average subword representations back to word level.
        out = torch.matmul(transforms, sequence_output)
        # Fix: the dropout layer was constructed but never applied;
        # regularize the word-level representations before classification
        # (identity in eval mode, so inference behavior is unchanged).
        out = self.dropout(out)
        logits = self.classifier(out)
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            return loss_fct(
                logits.view(-1, self.num_labels), labels.view(-1)
            )
        return logits

    def get_batches(self, sentences, max_batch):
        """Tokenize and batch with subword-to-word transform matrices.

        Delegates to the module-level get_batches() function.
        """
        return get_batches(self.tokenizer, sentences, max_batch,
                           has_labels=True)