static-embedding-chess / scripts /mine_hard_negs_v2.py
oneryalcin's picture
Add files using upload-large-folder tool
f8392aa verified
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "sentence-transformers[train]>=5.5.0",
# "datasets>=2.19.0",
# "numpy",
# "tqdm",
# ]
# ///
"""Memory-bounded hard-negative miner. Custom impl (not sentence-transformers
util) because that utility tries to hold the full anchor × corpus similarity
matrix, which OOMs at 327k anchors × 327k positives on M4.
Algorithm:
1. Encode all unique positives once -> N x dim float32 (~670MB at 327k x 512).
2. Encode all unique anchors once -> M x dim float32.
3. For each anchor batch (size B):
- scores = batch_emb @ positives_emb.T -> B x N
- per anchor: argpartition for top RANGE_MAX, exclude actual positive,
sample NUM_NEGATIVES from rank [RANGE_MIN, RANGE_MAX).
4. Collect the triplets in memory and write them to parquet once at the end.
Peak memory: B * N * 4 bytes for scores. With B=500, N=327k: 650MB.
Run:
SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 mine_hard_negs_v2.py
uv run --exclude-newer=2026-05-12 mine_hard_negs_v2.py
"""
from __future__ import annotations
import os
import random
import re
import sys
from collections import defaultdict
# Force unbuffered stdout so progress is visible when piped
sys.stdout.reconfigure(line_buffering=True)
import numpy as np
import torch
from datasets import Dataset, load_dataset
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
# --- Paths ------------------------------------------------------------------
# Model used to embed anchors and positives for similarity scoring.
V3_MODEL_PATH = "models/static-embedding-chess/final"
# Destination of the mined (anchor, positive, negative) triplets.
OUTPUT_PATH = "models/hard_negatives.parquet"

# --- Run mode ---------------------------------------------------------------
# True only when the environment variable SMOKE_TEST is exactly "1".
SMOKE_TEST = os.getenv("SMOKE_TEST") == "1"

# --- Held-out anchors -------------------------------------------------------
# Anchors whose occurrence count lies in [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX]
# are candidates for the held-out set; at most EVAL_QUERIES of them are
# removed before mining.
HELDOUT_FREQ_MIN = 3
HELDOUT_FREQ_MAX = 30
EVAL_QUERIES = 200

# --- Mining -----------------------------------------------------------------
# Negatives drawn per anchor, sampled from similarity ranks
# [RANGE_MIN, RANGE_MAX) after the anchor's own positive is excluded.
NUM_NEGATIVES = 5
RANGE_MIN = 10
RANGE_MAX = 50
# Anchors scored per matmul: 500 * 327k * 4 = ~650MB scratch per batch.
ANCHOR_BATCH_SIZE = 500
def _join_tags(tags):
return " ".join(t.replace("_", " ") for t in tags) if tags else ""
def _bigram_token_str(moves):
toks = moves.split()
if len(toks) < 2:
return moves
bigrams = " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:]))
return f"{moves} {bigrams}"
def build_puzzle_pairs(batch):
    """Build parallel anchor/positive text columns from a batch of puzzles.

    *batch* is a column-oriented mapping with "Themes", "OpeningTags" and
    "Moves" sequences of equal length. Rows whose themes render to an empty
    string are dropped. For every kept row:

    * anchor   -- the theme text, followed by the opening text when present;
    * positive -- "themes <themes>", then " opening <opening>" when present,
      then " moves <moves plus bigrams>".

    Returns a dict with "anchor" and "positive" lists of equal length.
    """
    result = {"anchor": [], "positive": []}
    rows = zip(batch["Themes"], batch["OpeningTags"], batch["Moves"])
    for theme_tags, opening_tags, move_str in rows:
        theme_text = _join_tags(theme_tags)
        opening_text = _join_tags(opening_tags)
        if not theme_text:
            continue

        anchor_parts = [theme_text]
        positive_parts = [f"themes {theme_text}"]
        if opening_text:
            anchor_parts.append(opening_text)
            positive_parts.append(f"opening {opening_text}")
        positive_parts.append(f"moves {_bigram_token_str(move_str)}")

        result["anchor"].append(" ".join(anchor_parts))
        result["positive"].append(" ".join(positive_parts))
    return result
def _l2_normalize(emb):
    """Scale each row of the 2-D array *emb* to unit L2 norm."""
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    # An all-zero row would otherwise become NaN (0/0) and corrupt every
    # score it takes part in; clamping leaves such rows as zeros instead.
    return emb / np.maximum(norms, 1e-12)


def _rank_window(row, own_idx, range_min, range_max):
    """Return corpus indices at similarity ranks [range_min, range_max).

    *row* holds one anchor's score against every corpus entry. The entry at
    *own_idx* (the anchor's own positive) is excluded, the rest are ranked by
    descending score (rank 0 = most similar), and the requested rank slice is
    returned as an index array. The slice is shorter than requested, possibly
    empty, when the corpus is small.
    """
    masked = row.copy()
    masked[own_idx] = -np.inf
    n_corpus = masked.shape[0]
    if n_corpus > range_max:
        # argpartition leaves the range_max best entries (unordered) in the
        # first range_max slots. The masked positive cannot be among them
        # because at least range_max other entries outrank -inf.
        top_idx = np.argpartition(-masked, range_max)[:range_max]
    else:
        # argpartition requires kth < len(row); with a small corpus rank
        # everything except the anchor's own positive instead.
        top_idx = np.flatnonzero(np.arange(n_corpus) != own_idx)
    top_idx = top_idx[np.argsort(-masked[top_idx])]
    return top_idx[range_min:range_max]


def main():
    """Mine hard negatives for puzzle anchors and write them to parquet.

    Steps:
      1. Load the model and the puzzle dataset, and build anchor/positive
         text pairs with ``build_puzzle_pairs``.
      2. Drop the held-out anchors (the lowest-count anchors whose count is
         within [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX], at most EVAL_QUERIES).
      3. Keep one randomly chosen positive per remaining unique anchor; these
         positives double as the negative corpus.
      4. Encode and L2-normalise anchors and positives, then score anchors in
         batches of ANCHOR_BATCH_SIZE against the whole corpus.
      5. For each anchor sample up to NUM_NEGATIVES corpus entries from ranks
         [RANGE_MIN, RANGE_MAX), excluding the anchor's own positive.
      6. Print score statistics, write OUTPUT_PATH, and print three samples.

    Raises:
        SystemExit: if no anchors remain after filtering or no triplet could
            be mined (corpus too small for the rank window).
    """
    print(f"Loading v3 model from {V3_MODEL_PATH}")
    model = SentenceTransformer(V3_MODEL_PATH)

    print("Loading puzzles...")
    puzzles = load_dataset("Lichess/chess-puzzles", split="train")
    if SMOKE_TEST:
        # Clamp so a split shorter than 100k rows does not make select() fail.
        puzzles = puzzles.select(range(min(100_000, len(puzzles))))
    pair_puzzles = puzzles.map(
        build_puzzle_pairs,
        batched=True,
        batch_size=20_000,
        remove_columns=puzzles.column_names,
        num_proc=4,
    )

    # Materialize both columns ONCE as plain Python lists so the loops below
    # never go back through the Dataset row by row. list() makes this explicit:
    # column access is not guaranteed to hand back a list on every `datasets`
    # release (newer ones may return a lazy column object -- TODO confirm).
    print("Materializing columns...")
    anchors_list = list(pair_puzzles["anchor"])
    positives_list = list(pair_puzzles["positive"])
    print(f" done ({len(anchors_list):,} rows)")

    # Remove held-out anchors: count occurrences, keep the anchors in the
    # configured frequency band, and hold out the EVAL_QUERIES rarest of them.
    freq = defaultdict(int)
    for a in anchors_list:
        freq[a] += 1
    rare_pool = sorted(
        ((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX),
        key=lambda kv: kv[1],
    )
    heldout = {a for a, _ in rare_pool[:EVAL_QUERIES]}

    # Group positives by anchor; the groups are both the anchor source and
    # the corpus source.
    by_anchor = defaultdict(list)
    for a, p in zip(anchors_list, positives_list):
        if a not in heldout:
            by_anchor[a].append(p)
    print(f" unique anchors (post-heldout-strip): {len(by_anchor):,}")

    rng = random.Random(12)
    unique_anchors = list(by_anchor)
    if SMOKE_TEST:
        unique_anchors = unique_anchors[:200]
        print(f" SMOKE_TEST=1: trimmed to {len(unique_anchors)}")
    if not unique_anchors:
        raise SystemExit("No anchors left after held-out filtering; nothing to mine.")

    # One random positive per anchor. positives[k] belongs to
    # unique_anchors[k], which is what lets the mining loop mask column k.
    print(" Sampling one positive per anchor...")
    positives = [rng.choice(by_anchor[a]) for a in unique_anchors]
    print(" done")

    # Encode anchors and positives; unit-norm rows make the dot product below
    # a cosine similarity.
    print(f"\nEncoding {len(unique_anchors):,} anchors...")
    anchor_emb = _l2_normalize(
        model.encode(
            unique_anchors, batch_size=512, show_progress_bar=True, convert_to_numpy=True
        )
    )
    print(f" anchor shape: {anchor_emb.shape}, mem: {anchor_emb.nbytes / 1e6:.1f}MB")
    print(f"\nEncoding {len(positives):,} positives...")
    positive_emb = _l2_normalize(
        model.encode(
            positives, batch_size=512, show_progress_bar=True, convert_to_numpy=True
        )
    )
    print(f" positive shape: {positive_emb.shape}, mem: {positive_emb.nbytes / 1e6:.1f}MB")

    # Mine hard negs in chunks so only a B x N score block is alive at a time.
    print(f"\nMining hard negs (range={RANGE_MIN}..{RANGE_MAX}, num={NUM_NEGATIVES}, batch={ANCHOR_BATCH_SIZE})...")
    out_anchors, out_positives, out_negatives = [], [], []
    pos_scores_acc, neg_scores_acc = [], []
    n_anchors = len(unique_anchors)
    for start in tqdm(range(0, n_anchors, ANCHOR_BATCH_SIZE)):
        end = min(start + ANCHOR_BATCH_SIZE, n_anchors)
        # scores: B x N. Row i is anchor[start + i] against every positive.
        scores = anchor_emb[start:end] @ positive_emb.T
        for i in range(end - start):
            anchor_idx = start + i
            window = _rank_window(scores[i], anchor_idx, RANGE_MIN, RANGE_MAX)
            candidates = window.tolist()
            sampled = rng.sample(candidates, min(NUM_NEGATIVES, len(candidates)))
            # Same for every negative of this anchor, so read it once.
            pos_score = float(scores[i, anchor_idx])
            for neg_idx in sampled:
                out_anchors.append(unique_anchors[anchor_idx])
                out_positives.append(positives[anchor_idx])
                out_negatives.append(positives[neg_idx])
                pos_scores_acc.append(pos_score)
                neg_scores_acc.append(float(scores[i, neg_idx]))

    if not out_anchors:
        # Happens when the corpus has no entries at ranks >= RANGE_MIN.
        # Stop here rather than write an empty file and fail on the
        # statistics and sample printing below.
        raise SystemExit(
            f"No triplets mined: corpus of {len(positives)} positives is too small "
            f"for rank window [{RANGE_MIN}, {RANGE_MAX})."
        )

    pos_scores = np.array(pos_scores_acc)
    neg_scores = np.array(neg_scores_acc)
    print(f"\n output triplets: {len(out_anchors):,}")
    print(f" positive scores: mean={np.mean(pos_scores):.3f} std={np.std(pos_scores):.3f}")
    print(f" hard-neg scores: mean={np.mean(neg_scores):.3f} std={np.std(neg_scores):.3f}")
    print(f" margin (pos - neg): mean={np.mean(pos_scores - neg_scores):.3f}")

    # Save
    os.makedirs(os.path.dirname(OUTPUT_PATH) or ".", exist_ok=True)
    Dataset.from_dict({
        "anchor": out_anchors,
        "positive": out_positives,
        "negative": out_negatives,
    }).to_parquet(OUTPUT_PATH)
    print(f" saved to {OUTPUT_PATH} ({os.path.getsize(OUTPUT_PATH) / 1e6:.1f} MB)")

    # Show the first, middle and last triplet as a sanity check.
    print("\n=== Sample triplets ===")
    for i in [0, len(out_anchors) // 2, len(out_anchors) - 1]:
        print(f" ANCHOR: {out_anchors[i]!r}")
        print(f" POSITIVE:{out_positives[i][:100]!r}")
        print(f" NEGATIVE:{out_negatives[i][:100]!r}")
        print()
# Run the miner only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()