#!/usr/bin/env python3 # /// script # requires-python = ">=3.10" # dependencies = [ # "sentence-transformers[train]>=5.5.0", # "datasets>=2.19.0", # "numpy", # "tqdm", # ] # /// """Memory-bounded hard-negative miner. Custom impl (not sentence-transformers util) because the SE function tries to hold the full anchor × corpus similarity matrix, which OOMs at 327k anchors × 327k positives on M4. Algorithm: 1. Encode all unique positives once -> N x dim float32 (~670MB at 327k x 512). 2. Encode all unique anchors once -> M x dim float32. 3. For each anchor batch (size B): - scores = batch_emb @ positives_emb.T -> B x N - per anchor: argpartition for top RANGE_MAX, exclude actual positive, sample NUM_NEGATIVES from rank [RANGE_MIN, RANGE_MAX). 4. Stream triplets to parquet. Peak memory: B * N * 4 bytes for scores. With B=500, N=327k: 650MB. Run: SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 mine_hard_negs_v2.py uv run --exclude-newer=2026-05-12 mine_hard_negs_v2.py """ from __future__ import annotations import os import random import re import sys from collections import defaultdict # Force unbuffered stdout so progress is visible when piped sys.stdout.reconfigure(line_buffering=True) import numpy as np import torch from datasets import Dataset, load_dataset from sentence_transformers import SentenceTransformer from tqdm import tqdm V3_MODEL_PATH = "models/static-embedding-chess/final" OUTPUT_PATH = "models/hard_negatives.parquet" SMOKE_TEST = os.environ.get("SMOKE_TEST") == "1" HELDOUT_FREQ_MIN = 3 HELDOUT_FREQ_MAX = 30 EVAL_QUERIES = 200 NUM_NEGATIVES = 5 RANGE_MIN = 10 RANGE_MAX = 50 ANCHOR_BATCH_SIZE = 500 # 500 * 327k * 4 = ~650MB scratch per batch def _join_tags(tags): return " ".join(t.replace("_", " ") for t in tags) if tags else "" def _bigram_token_str(moves): toks = moves.split() if len(toks) < 2: return moves bigrams = " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:])) return f"{moves} {bigrams}" def build_puzzle_pairs(batch): anchors, positives = [], [] for themes, op, moves in zip(batch["Themes"], batch["OpeningTags"], batch["Moves"]): themes_txt = _join_tags(themes) op_txt = _join_tags(op) if not themes_txt: continue anchor = themes_txt + (f" {op_txt}" if op_txt else "") positive = f"themes {themes_txt}" if op_txt: positive += f" opening {op_txt}" positive += f" moves {_bigram_token_str(moves)}" anchors.append(anchor) positives.append(positive) return {"anchor": anchors, "positive": positives} def main(): print(f"Loading v3 model from {V3_MODEL_PATH}") model = SentenceTransformer(V3_MODEL_PATH) print("Loading puzzles...") puzzles = load_dataset("Lichess/chess-puzzles", split="train") if SMOKE_TEST: puzzles = puzzles.select(range(100_000)) pair_puzzles = puzzles.map( build_puzzle_pairs, batched=True, batch_size=20_000, remove_columns=puzzles.column_names, num_proc=4, ) # Materialize columns ONCE as Python lists (HF Dataset random access is # O(N) per call due to Arrow buffer slicing -- 5.8M iterations would take # forever otherwise). print("Materializing columns...") anchors_list = pair_puzzles["anchor"] positives_list = pair_puzzles["positive"] print(f" done ({len(anchors_list):,} rows)") # Remove held-out anchors freq = defaultdict(int) for a in anchors_list: freq[a] += 1 rare_pool = sorted( ((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX), key=lambda kv: kv[1], ) heldout = {a for a, _ in rare_pool[:EVAL_QUERIES]} # Build one-per-anchor (use as both the anchor source AND the corpus source) by_anchor = defaultdict(list) for a, p in zip(anchors_list, positives_list): if a not in heldout: by_anchor[a].append(p) print(f" unique anchors (post-heldout-strip): {len(by_anchor):,}") rng = random.Random(12) unique_anchors = list(by_anchor.keys()) if SMOKE_TEST: unique_anchors = unique_anchors[:200] print(f" SMOKE_TEST=1: trimmed to {len(unique_anchors)}") # For each anchor, pick ONE random positive (skip the O(n^2) filter -- just # iterate unique_anchors directly). print(f" Sampling one positive per anchor...") positives = [rng.choice(by_anchor[a]) for a in unique_anchors] print(f" done") # Encode anchors and positives print(f"\nEncoding {len(unique_anchors):,} anchors...") anchor_emb = model.encode( unique_anchors, batch_size=512, show_progress_bar=True, convert_to_numpy=True ) anchor_emb = anchor_emb / np.linalg.norm(anchor_emb, axis=1, keepdims=True) print(f" anchor shape: {anchor_emb.shape}, mem: {anchor_emb.nbytes / 1e6:.1f}MB") print(f"\nEncoding {len(positives):,} positives...") positive_emb = model.encode( positives, batch_size=512, show_progress_bar=True, convert_to_numpy=True ) positive_emb = positive_emb / np.linalg.norm(positive_emb, axis=1, keepdims=True) print(f" positive shape: {positive_emb.shape}, mem: {positive_emb.nbytes / 1e6:.1f}MB") # Mine hard negs in chunks print(f"\nMining hard negs (range={RANGE_MIN}..{RANGE_MAX}, num={NUM_NEGATIVES}, batch={ANCHOR_BATCH_SIZE})...") out_anchors, out_positives, out_negatives = [], [], [] pos_scores_acc, neg_scores_acc = [], [] n_anchors = len(unique_anchors) for start in tqdm(range(0, n_anchors, ANCHOR_BATCH_SIZE)): end = min(start + ANCHOR_BATCH_SIZE, n_anchors) ab = anchor_emb[start:end] # B x D # scores: B x N. Each row i is anchor[start+i] vs all positives. scores = ab @ positive_emb.T # B x N (float32) # For each anchor i in batch, sort scores desc, get top RANGE_MAX # excluding the actual positive (which is at column start+i). # We use argpartition for efficiency. for i in range(end - start): anchor_idx = start + i row = scores[i].copy() # Mask out the actual positive (anchor's own positive is at anchor_idx) row[anchor_idx] = -np.inf # Take top RANGE_MAX indices top_idx = np.argpartition(-row, RANGE_MAX)[:RANGE_MAX] # Sort them by score top_idx = top_idx[np.argsort(-row[top_idx])] # Sample NUM_NEGATIVES from rank [RANGE_MIN, RANGE_MAX) mid_range = top_idx[RANGE_MIN:RANGE_MAX] sampled = rng.sample(list(mid_range), min(NUM_NEGATIVES, len(mid_range))) for neg_idx in sampled: out_anchors.append(unique_anchors[anchor_idx]) out_positives.append(positives[anchor_idx]) out_negatives.append(positives[neg_idx]) pos_scores_acc.append(float(scores[i, anchor_idx])) neg_scores_acc.append(float(scores[i, neg_idx])) print(f"\n output triplets: {len(out_anchors):,}") print(f" positive scores: mean={np.mean(pos_scores_acc):.3f} std={np.std(pos_scores_acc):.3f}") print(f" hard-neg scores: mean={np.mean(neg_scores_acc):.3f} std={np.std(neg_scores_acc):.3f}") print(f" margin (pos - neg): mean={np.mean(np.array(pos_scores_acc) - np.array(neg_scores_acc)):.3f}") # Save os.makedirs(os.path.dirname(OUTPUT_PATH) or ".", exist_ok=True) Dataset.from_dict({ "anchor": out_anchors, "positive": out_positives, "negative": out_negatives, }).to_parquet(OUTPUT_PATH) print(f" saved to {OUTPUT_PATH} ({os.path.getsize(OUTPUT_PATH) / 1e6:.1f} MB)") # Sample print("\n=== Sample triplets ===") for i in [0, len(out_anchors)//2, len(out_anchors)-1]: print(f" ANCHOR: {out_anchors[i]!r}") print(f" POSITIVE:{out_positives[i][:100]!r}") print(f" NEGATIVE:{out_negatives[i][:100]!r}") print() if __name__ == "__main__": main()