diyclassics Claude Opus 4.6 committed on
Commit
3510517
·
1 Parent(s): 8af2caa

test: add contextual nearest neighbors case study (Bamman & Burns §4.4)

Browse files

Three tests:
- test_embedding_parity: fast CPU test verifying word-level embeddings
- test_generate_embeddings: generates embeddings for Latin Library corpus
- test_contextual_nn_queries: runs paper's example queries with soft assertions

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. pyproject.toml +6 -0
  2. tests/test_contextual_nn.py +662 -0
pyproject.toml CHANGED
@@ -14,6 +14,12 @@ dependencies = [
14
  dev = [
15
  "pytest>=7.0",
16
  ]
 
 
 
 
 
 
17
 
18
  [build-system]
19
  requires = ["hatchling"]
 
14
  dev = [
15
  "pytest>=7.0",
16
  ]
17
+ benchmark = [
18
+ "pytest>=7.0",
19
+ "cltk",
20
+ "joblib",
21
+ "gdown",
22
+ ]
23
 
24
  [build-system]
25
  requires = ["hatchling"]
tests/test_contextual_nn.py ADDED
@@ -0,0 +1,662 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Contextual nearest neighbors case study — Bamman & Burns (2020) §4.4.
2
+
3
+ Reproduces the contextual nearest neighbors experiment: generate BERT
4
+ embeddings for a corpus of Latin texts, then query for contextually
5
+ similar uses of a word.
6
+
7
+ Three tests:
8
+ 1. test_embedding_parity — fast, CPU: verify our HF tokenizer produces
9
+ identical word-level embeddings to the original pipeline
10
+ 2. test_generate_embeddings — slow, GPU: generate embeddings for the
11
+ full Latin Library corpus
12
+ 3. test_contextual_nn_queries — slow, GPU: run example queries from
13
+ the paper and verify results
14
+ """
15
+
16
+ import os
17
+ import tarfile
18
+ from pathlib import Path
19
+ from typing import List, Tuple
20
+
21
+ import numpy as np
22
+ from numpy import linalg as LA
23
+ import pytest
24
+ import torch
25
+ from torch import nn
26
+ from transformers import AutoTokenizer, BertModel
27
+
28
+ BERT_DIM = 768
29
+ BATCH_SIZE = 32
30
+
31
+ # Special tokens that should not go through subword encoding
32
+ _SPECIAL_TOKENS = {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"}
33
+
34
+ # Data paths
35
+ DATA_DIR = Path(__file__).parent.parent / "data"
36
+ CORPUS_TEXT_DIR = DATA_DIR / "latin_library_text"
37
+ CORPUS_BERT_DIR = DATA_DIR / "latin_library_bert"
38
+ CORPUS_ARCHIVE = DATA_DIR / "latin_library_text.tar.gz"
39
+
40
+ # Google Drive download URL for Latin Library texts
41
+ CORPUS_DOWNLOAD_ID = "1GRe3eFmQBDdF1kIT9T75aPTdquaf8Z8s"
42
+
43
+
44
+ # ── Shared helpers ──────────────────────────────────────────────────────
45
+
46
+
47
def _word_to_subtokens(tokenizer, word):
    """Split one word into its subtoken strings.

    Words in the special-token vocabulary ([CLS], [SEP], ...) bypass the
    subword pipeline and pass through unchanged; everything else is run
    through ``tokenizer.tokenize``.
    """
    return [word] if word in _SPECIAL_TOKENS else tokenizer.tokenize(word)
56
+
57
+
58
def _get_batches(tokenizer, sentences, max_batch):
    """Tokenize and batch sentences with subword-to-word transform matrices.

    Each word is tokenized individually (matching original behavior).
    The transform matrix averages subword representations back to
    word-level representations.

    sentences: list of lists of words (including [CLS]/[SEP])

    Returns (batched_data, batched_mask, batched_transforms, ordering),
    where ``ordering`` records the length-sort permutation so callers can
    restore the original sentence order afterwards.
    """
    all_data = []        # per-sentence subtoken id sequences
    all_masks = []       # per-sentence attention masks (one 1 per subtoken)
    all_transforms = []  # per-sentence (word x subtoken) averaging matrices

    for sentence in sentences:
        tok_ids = []
        input_mask = []
        transform = []

        # First pass: get subtokens for each word
        all_toks = []
        n = 0  # total subtoken count for this sentence
        for word in sentence:
            toks = _word_to_subtokens(tokenizer, word)
            all_toks.append(toks)
            n += len(toks)

        # Second pass: build transform matrix and collect IDs.
        # Row idx of the transform carries weight 1/len(toks) over word
        # idx's subtoken span, so (transform @ subtoken_reps) averages the
        # subword vectors into one word vector.
        cur = 0
        for idx, word in enumerate(sentence):
            toks = all_toks[idx]
            ind = list(np.zeros(n))
            for j in range(cur, cur + len(toks)):
                ind[j] = 1.0 / len(toks)
            cur += len(toks)
            transform.append(ind)
            tok_ids.extend(tokenizer.convert_tokens_to_ids(toks))
            input_mask.extend(np.ones(len(toks)))

        all_data.append(tok_ids)
        all_masks.append(input_mask)
        all_transforms.append(transform)

    # Sort sentences by subtoken length so each batch needs minimal padding.
    lengths = np.array([len(l) for l in all_data])
    ordering = np.argsort(lengths)

    ordered_data = [None] * len(all_data)
    ordered_masks = [None] * len(all_data)
    ordered_transforms = [None] * len(all_data)

    for i, ind in enumerate(ordering):
        ordered_data[i] = all_data[ind]
        ordered_masks[i] = all_masks[ind]
        ordered_transforms[i] = all_transforms[ind]

    batched_data = []
    batched_mask = []
    batched_transforms = []

    i = 0
    current_batch = max_batch

    while i < len(ordered_data):
        batch_data = ordered_data[i:i + current_batch]
        batch_mask = ordered_masks[i:i + current_batch]
        batch_transforms = ordered_transforms[i:i + current_batch]

        # Pad every row in the batch out to the batch's longest subtoken
        # sequence (ml) and largest word count (max_words).
        ml = max(len(s) for s in batch_data)
        max_words = max(len(t) for t in batch_transforms)

        # NOTE: the appends below mutate the inner lists in place. The
        # inner lists are shared with ordered_* via the slice above, but
        # each sentence lands in exactly one batch, so this is safe.
        for j in range(len(batch_data)):
            blen = len(batch_data[j])
            for _k in range(blen, ml):
                batch_data[j].append(0)
                batch_mask[j].append(0)
                # widen every existing transform row by one padded column
                for z in range(len(batch_transforms[j])):
                    batch_transforms[j][z].append(0)
            # add all-zero rows for "missing" words up to max_words
            for _k in range(len(batch_transforms[j]), max_words):
                batch_transforms[j].append(np.zeros(ml))

        batched_data.append(torch.LongTensor(batch_data))
        batched_mask.append(torch.FloatTensor(batch_mask))
        batched_transforms.append(torch.FloatTensor(batch_transforms))

        i += current_batch
        # Sentences are length-sorted ascending, so once sequences get
        # long, shrink subsequent batches to bound per-batch memory.
        if ml > 100:
            current_batch = 12
        if ml > 200:
            current_batch = 6

    return batched_data, batched_mask, batched_transforms, ordering
148
+
149
+
150
def _get_word_embeddings(tokenizer, model, sentences, device):
    """Compute word-level BERT embeddings for each sentence.

    Returns one list per input sentence, each holding (word, embedding)
    pairs in the original word order. Mirrors the original
    LatinBERT.get_berts() method.
    """
    data, masks, transforms, ordering = _get_batches(
        tokenizer, sentences, BATCH_SIZE
    )

    # Forward pass per batch; collapse subtokens to words via the
    # averaging transform, then pull results back to numpy on CPU.
    sorted_preds = []
    for ids, mask, tr in zip(data, masks, transforms):
        n_rows = tr.shape[0]
        with torch.no_grad():
            seq_out = model(
                ids.to(device), attention_mask=mask.to(device)
            )[0]
            word_reps = torch.matmul(tr.to(device), seq_out)
        word_reps = word_reps.cpu()
        for r in range(n_rows):
            sorted_preds.append([np.array(v) for v in word_reps[r]])

    # Undo the length-sort applied inside _get_batches
    restored = [None] * len(sentences)
    for pos, orig_idx in enumerate(ordering):
        restored[orig_idx] = sorted_preds[pos]

    # Pair each word with its embedding (padding rows are never indexed
    # because we only enumerate the sentence's real words)
    return [
        [(word, restored[s_idx][w_idx]) for w_idx, word in enumerate(sent)]
        for s_idx, sent in enumerate(sentences)
    ]
192
+
193
+
194
+ # ── Test 1: Embedding parity ───────────────────────────────────────────
195
+
196
+
197
def test_embedding_parity(model_path):
    """Sanity-check word-level embeddings from the HF pipeline.

    Verifies that (a) subword averaging yields exactly one BERT_DIM,
    non-zero embedding per input word, (b) two independent forward
    passes over the same tokenization agree (cosine > 0.9999, i.e.
    inference is deterministic), and (c) embeddings are not degenerate:
    different words in different contexts are not near-identical.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(
        model_path, trust_remote_code=True
    )
    model = BertModel.from_pretrained(model_path)
    model.to(device)
    model.eval()

    test_sentences_raw = [
        "arma virumque cano",
        "gallia est omnis divisa in partes tres",
        "omnia vincit amor",
    ]

    # Build word lists with [CLS]/[SEP], lowercased
    sentences = []
    for raw in test_sentences_raw:
        words = ["[CLS]"] + raw.lower().split() + ["[SEP]"]
        sentences.append(words)

    # Get embeddings via our HF pipeline
    bert_sents = _get_word_embeddings(tokenizer, model, sentences, device)

    # Verify we get one BERT_DIM-sized, non-zero embedding per word
    for sent_idx, (raw, bert_sent) in enumerate(
        zip(test_sentences_raw, bert_sents)
    ):
        expected_words = ["[CLS]"] + raw.lower().split() + ["[SEP]"]
        assert len(bert_sent) == len(expected_words), (
            f"Sentence {sent_idx}: expected {len(expected_words)} embeddings, "
            f"got {len(bert_sent)}"
        )
        for (word, emb), expected in zip(bert_sent, expected_words):
            assert word == expected, f"Expected '{expected}', got '{word}'"
            assert emb.shape == (BERT_DIM,), (
                f"Expected ({BERT_DIM},), got {emb.shape}"
            )
            # Embedding should not be all zeros
            assert LA.norm(emb) > 0.1, f"Zero embedding for '{word}'"

    # Run a second forward pass and verify cosine similarity ≈ 1.0
    bert_sents_2 = _get_word_embeddings(tokenizer, model, sentences, device)

    for sent_idx in range(len(sentences)):
        for tok_idx in range(len(bert_sents[sent_idx])):
            word = bert_sents[sent_idx][tok_idx][0]
            emb1 = bert_sents[sent_idx][tok_idx][1]
            emb2 = bert_sents_2[sent_idx][tok_idx][1]
            cos = np.dot(emb1, emb2) / (LA.norm(emb1) * LA.norm(emb2))
            assert cos > 0.9999, (
                f"Cosine similarity for '{word}' in sentence {sent_idx}: "
                f"{cos:.6f} (expected > 0.9999)"
            )

    # Verify the transform matrix produces different embeddings for the
    # same word in different contexts (contextual, not static)
    # "in" appears in sentence 1 ("gallia est omnis divisa in partes tres")
    in_emb = None
    for word, emb in bert_sents[1]:
        if word == "in":
            in_emb = emb
            break
    assert in_emb is not None, "'in' not found in sentence 1"

    # "omnia" from sentence 2 should have a different embedding than "in"
    omnia_emb = None
    for word, emb in bert_sents[2]:
        if word == "omnia":
            omnia_emb = emb
            break
    assert omnia_emb is not None

    cos_diff = np.dot(in_emb, omnia_emb) / (
        LA.norm(in_emb) * LA.norm(omnia_emb)
    )
    assert cos_diff < 0.95, (
        f"'in' and 'omnia' should have different embeddings, "
        f"but cosine = {cos_diff:.4f}"
    )

    print("\nEmbedding parity: PASS")
    print(f" Tested {len(sentences)} sentences")
    for sent_idx, bert_sent in enumerate(bert_sents):
        words = [w for w, _ in bert_sent if w not in {"[CLS]", "[SEP]"}]
        print(f" Sentence {sent_idx}: {' '.join(words)}")
        for word, emb in bert_sent:
            if word in {"[CLS]", "[SEP]"}:
                continue
            # BERT_DIM - 1 rather than a hard-coded 767 so this report
            # stays correct if the dimensionality constant ever changes
            print(f" {word}: norm={LA.norm(emb):.3f}, "
                  f"first/last=({emb[0]:.4f}, {emb[BERT_DIM - 1]:.4f})")
296
+
297
+
298
+ # ── Test 2: Generate embeddings ─────────────────────────────────────────
299
+
300
+
301
def _read_file_cltk(filename):
    """Read *filename* and tokenize it with CLTK, as the original pipeline did.

    Returns a list of sentences; each sentence is a list of lowercased
    word tokens wrapped in [CLS]/[SEP].
    """
    from cltk.tokenizers.lat.lat import (
        LatinWordTokenizer as WordTokenizer,
        LatinPunktSentenceTokenizer as SentenceTokenizer,
    )
    sent_tok = SentenceTokenizer()
    word_tok = WordTokenizer()

    with open(filename, encoding="utf-8") as handle:
        text = handle.read().lower()

    result = []
    for sent in sent_tok.tokenize(text):
        words = ["[CLS]"]
        # drop empty tokens the word tokenizer sometimes emits
        words.extend(tok for tok in word_tok.tokenize(sent) if tok != "")
        words.append("[SEP]")
        result.append(words)

    return result
329
+
330
+
331
def _download_corpus():
    """Download and extract the Latin Library texts if not already present.

    Idempotent: returns immediately when the text directory is populated,
    and reuses a previously downloaded archive.
    """
    import subprocess
    import sys

    if CORPUS_TEXT_DIR.exists() and any(CORPUS_TEXT_DIR.iterdir()):
        return  # Already downloaded

    DATA_DIR.mkdir(parents=True, exist_ok=True)

    if not CORPUS_ARCHIVE.exists():
        # Download via gdown (handles Google Drive large files).
        # Install through *this* interpreter (`python -m pip`): a bare
        # "pip" on PATH may belong to a different Python environment.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-q", "gdown"],
            check=True, capture_output=True,
        )
        subprocess.run(
            [
                "gdown",
                f"https://drive.google.com/uc?id={CORPUS_DOWNLOAD_ID}",
                "-O", str(CORPUS_ARCHIVE),
            ],
            check=True,
        )

    # Extract, rejecting members that would escape DATA_DIR (path traversal
    # in an untrusted archive).
    with tarfile.open(CORPUS_ARCHIVE, "r:gz") as tar:
        try:
            tar.extractall(path=DATA_DIR, filter="data")
        except TypeError:
            # Python without the filter= parameter (pre 3.11.4/3.12)
            tar.extractall(path=DATA_DIR)

    assert CORPUS_TEXT_DIR.exists(), (
        f"Expected {CORPUS_TEXT_DIR} after extraction"
    )
362
+
363
+
364
def _generate_embeddings_for_file(
    tokenizer, model, input_file, output_file, device
):
    """Generate BERT embeddings for a single text file.

    Reads the file with CLTK tokenization, computes word-level embeddings,
    and writes them in the original format:
        word\\tspace-separated 768 floats
    (blank line between sentences)

    Returns the number of sentences written (0 when the file yields none).
    """
    sents = _read_file_cltk(input_file)
    if not sents:
        return 0

    bert_sents = _get_word_embeddings(tokenizer, model, sents, device)

    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when output_file actually has one.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(output_file, "w", encoding="utf-8") as out:
        for bert_sent in bert_sents:
            for word, emb in bert_sent:
                out.write(
                    "%s\t%s\n" % (word, " ".join("%.5f" % x for x in emb))
                )
            out.write("\n")

    return len(sents)
390
+
391
+
392
@pytest.mark.slow
def test_generate_embeddings(model_path):
    """Generate BERT embeddings for the Latin Library corpus.

    Downloads the corpus on first run, then streams every text file
    through the model and writes word-level embeddings to disk,
    skipping files whose output already exists.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(
        model_path, trust_remote_code=True
    )
    model = BertModel.from_pretrained(model_path)
    model.to(device)
    model.eval()

    _download_corpus()

    sources = sorted(CORPUS_TEXT_DIR.glob("*.txt"))
    assert len(sources) > 0, f"No text files found in {CORPUS_TEXT_DIR}"

    CORPUS_BERT_DIR.mkdir(parents=True, exist_ok=True)

    new_sentences = 0
    files_seen = 0
    for file_idx, source in enumerate(sources):
        destination = CORPUS_BERT_DIR / source.name
        if not destination.exists():
            new_sentences += _generate_embeddings_for_file(
                tokenizer, model, str(source), str(destination), device
            )
            files_seen += 1
            # periodic progress report while generating
            if (file_idx + 1) % 50 == 0:
                print(f" Processed {file_idx + 1}/{len(sources)} files "
                      f"({new_sentences} sentences)")
        else:
            files_seen += 1

    print(f"\nGeneration complete: {files_seen} files, "
          f"{new_sentences} new sentences")
    print(f" Output: {CORPUS_BERT_DIR}")
436
+
437
+
438
+ # ── Test 3: Contextual nearest neighbor queries ─────────────────────────
439
+
440
+
441
+ def _load_embedding_file(filename):
442
+ """Load pre-generated embeddings from a TSV file.
443
+
444
+ Returns (matrix, sents, sent_ids, toks, position_in_sent).
445
+ Mirrors the original proc_doc().
446
+ """
447
+ berts = []
448
+ toks = []
449
+ sent_ids = []
450
+ sentid = 0
451
+ position_in_sent = []
452
+ p = 0
453
+
454
+ with open(filename) as f:
455
+ for line in f:
456
+ cols = line.rstrip().split("\t")
457
+ if len(cols) == 2:
458
+ word = cols[0]
459
+ bert = np.array([float(x) for x in cols[1].split(" ")])
460
+ bert = bert / LA.norm(bert)
461
+ toks.append(word)
462
+ berts.append(bert)
463
+ sent_ids.append(sentid)
464
+ position_in_sent.append(p)
465
+ p += 1
466
+ else:
467
+ sentid += 1
468
+ p = 0
469
+
470
+ sents = []
471
+ lastid = 0
472
+ current_sent = []
473
+ for s, t in zip(sent_ids, toks):
474
+ if s != lastid:
475
+ sents.append(current_sent)
476
+ current_sent = []
477
+ lastid = s
478
+ current_sent.append(t)
479
+ if current_sent:
480
+ sents.append(current_sent)
481
+
482
+ matrix = np.asarray(berts) if berts else np.empty((0, BERT_DIM))
483
+ return matrix, sents, sent_ids, toks, position_in_sent
484
+
485
+
486
def _load_all_embeddings(bert_dir):
    """Load every embedding file under *bert_dir* in parallel via joblib.

    Returns per-document parallel lists matching the original proc():
    (matrices, sentences, sentence ids, tokens, positions, doc ids).
    """
    from joblib import Parallel, delayed

    files = sorted(
        str(f)
        for f in Path(bert_dir).glob("*.txt")
        if f.stat().st_size > 0
    )
    assert len(files) > 0, f"No embedding files found in {bert_dir}"

    print(f" Loading {len(files)} embedding files...")

    loaded = Parallel(n_jobs=min(10, len(files)))(
        delayed(_load_embedding_file)(f) for f in files
    )

    # Transpose the per-file tuples into per-field lists
    matrix_all, sents_all, sent_ids_all, toks_all, position_in_sent_all = (
        [doc[field] for doc in loaded] for field in range(5)
    )

    return (
        matrix_all,
        sents_all,
        sent_ids_all,
        toks_all,
        position_in_sent_all,
        list(files),
    )
523
+
524
+
525
+ def _query_nearest_neighbors(
526
+ target_bert, matrix_all, sents_all, sent_ids_all, toks_all,
527
+ position_in_sent_all, doc_ids, top_n=25
528
+ ):
529
+ """Find the top-N contextually similar tokens across the corpus.
530
+
531
+ Returns list of (cosine_score, context_window, doc_id) tuples.
532
+ """
533
+ all_vals = []
534
+
535
+ for idx in range(len(doc_ids)):
536
+ c_matrix = matrix_all[idx]
537
+ c_sents = sents_all[idx]
538
+ c_sent_ids = sent_ids_all[idx]
539
+ c_toks = toks_all[idx]
540
+ c_pos = position_in_sent_all[idx]
541
+
542
+ if len(c_matrix) == 0:
543
+ continue
544
+
545
+ similarity = np.dot(c_matrix, target_bert)
546
+ argsort = np.argsort(-similarity)
547
+ len_s = len(similarity)
548
+
549
+ for i in range(min(100, len_s)):
550
+ tid = argsort[i]
551
+ if (tid < len(c_sent_ids) and tid < len(c_pos)
552
+ and c_sent_ids[tid] < len(c_sents)):
553
+ pos = c_pos[tid]
554
+ sent = c_sents[c_sent_ids[tid]]
555
+ # Build context window (5 words each side)
556
+ start = max(0, pos - 5)
557
+ end = min(len(sent), pos + 6)
558
+ before = " ".join(sent[start:pos])
559
+ target = sent[pos]
560
+ after = " ".join(sent[pos + 1:end])
561
+ context = f"{before} **{target}** {after}".strip()
562
+ all_vals.append((
563
+ float(similarity[tid]),
564
+ context,
565
+ doc_ids[idx],
566
+ target,
567
+ ))
568
+
569
+ all_vals.sort(key=lambda x: x[0], reverse=True)
570
+ return all_vals[:top_n]
571
+
572
+
573
+ # Queries from the paper's README
574
+ QUERIES = [
575
+ ("in", "gallia est omnis divisa in partes tres"),
576
+ ("amor", "omnia vincit amor"),
577
+ ]
578
+
579
+
580
@pytest.mark.slow
def test_contextual_nn_queries(model_path):
    """Run contextual nearest neighbor queries from the paper.

    Encodes each query sentence, then ranks every token in the
    pre-generated corpus embeddings by cosine similarity to the query
    word's contextual embedding.

    Soft assertions:
    - Query word in its own sentence appears with cosine > 0.8
    - At least 10 of top-25 results contain the query word
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    assert CORPUS_BERT_DIR.exists(), (
        f"Embeddings not found at {CORPUS_BERT_DIR}. "
        f"Run test_generate_embeddings first."
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_path, trust_remote_code=True
    )
    model = BertModel.from_pretrained(model_path)
    model.to(device)
    model.eval()

    # Load all pre-generated embeddings once, up front
    (matrix_all, sents_all, sent_ids_all, toks_all,
     position_in_sent_all, doc_ids) = _load_all_embeddings(CORPUS_BERT_DIR)

    for query_word, query_sent in QUERIES:
        print(f"\n{'=' * 60}")
        print(f"Query: '{query_word}' in '{query_sent}'")
        print("=" * 60)

        # Encode the query sentence and pull out the query word's vector
        tokens = ["[CLS]"] + query_sent.lower().split() + ["[SEP]"]
        encoded = _get_word_embeddings(
            tokenizer, model, [tokens], device
        )[0]
        target_emb = next(
            (vec for tok, vec in encoded if tok == query_word), None
        )
        assert target_emb is not None, (
            f"Query word '{query_word}' not found in sentence"
        )

        # L2-normalize so dot products against the (normalized) corpus
        # vectors are cosine similarities
        target_emb = target_emb / LA.norm(target_emb)

        results = _query_nearest_neighbors(
            target_emb, matrix_all, sents_all, sent_ids_all, toks_all,
            position_in_sent_all, doc_ids, top_n=25
        )

        for rank, (score, context, doc, _word) in enumerate(results):
            doc_short = Path(doc).stem
            print(f" {rank + 1:2d}. {score:.3f} {context} [{doc_short}]")

        # Soft check 1: the query word itself ranks highly
        self_hits = [
            r for r in results if r[3] == query_word and r[0] > 0.8
        ]
        assert len(self_hits) > 0, (
            f"Expected '{query_word}' to appear in top-25 with cosine > 0.8"
        )

        # Soft check 2: the top of the list is dominated by the query word
        word_hits = [r for r in results if r[3] == query_word]
        assert len(word_hits) >= 10, (
            f"Expected at least 10 of top-25 to be '{query_word}', "
            f"got {len(word_hits)}"
        )

        print(f"\n Soft checks passed: {len(self_hits)} self-hits with "
              f"cosine > 0.8, {len(word_hits)}/25 contain '{query_word}'")