/**
 * Pair-index builder + per-mention argmax + cluster grouping.
 *
 * Pure JS, no ORT, no DOM. Mirrors the Python helpers in
 * ``infon/scripts/coref_onnx_experiment.py`` (``build_pairs`` /
 * ``split_pairs_by_mention``) so the JS pipeline produces
 * bit-identical pair tensors.
 */

/**
 * Normalise a caller-supplied mention count to the number of mentions
 * actually processed: ``floor(nMentions)`` for finite values >= 1, and 0
 * otherwise (NaN, negatives, fractions below 1, ±Infinity). This matches
 * the bound of a ``for (m = 1; m <= nMentions; m++)`` loop for all finite
 * inputs, and is shared so buildPairs and pickAntecedents agree.
 */
function mentionCount(nMentions: number): number {
  return Number.isFinite(nMentions) && nMentions >= 1
    ? Math.floor(nMentions)
    : 0;
}

/**
 * Enumerate ``(i, j)`` candidate pairs for ``M`` mentions.
 *
 * For mention ``m`` (1-indexed because index 0 is DUMMY) we emit
 * ``(m, 0), (m, 1), …, (m, m-1)`` — DUMMY first, then every earlier
 * mention. This is the same triangular shape the Python
 * ``build_pairs`` returns; the scorer ONNX expects this exact layout
 * because the in-graph ``index_select`` over the prepended DUMMY
 * relies on j=0 meaning "no antecedent."
 *
 * @param nMentions number of mentions in the doc
 * @returns ``[pairI, pairJ]`` BigInt64 typed arrays of equal length.
 *          Lengths: ``M*(M+1)/2``.
 */
export function buildPairs(nMentions: number): [BigInt64Array, BigInt64Array] {
  const count = mentionCount(nMentions);
  // Mention m contributes exactly m pairs, so the total is known up front
  // and the typed arrays can be filled in place (no intermediate bigint[]).
  const total = (count * (count + 1)) / 2;
  const pairI = new BigInt64Array(total);
  const pairJ = new BigInt64Array(total);
  let k = 0;
  for (let m = 1; m <= count; m++) {
    const bm = BigInt(m);
    // j = 0 is DUMMY, then every earlier mention 1..m-1.
    for (let j = 0; j < m; j++) {
      pairI[k] = bm;
      pairJ[k] = BigInt(j);
      k++;
    }
  }
  return [pairI, pairJ];
}

/**
 * Group flat pair scores back into per-mention argmax decisions.
 * Mirrors ``split_pairs_by_mention`` in the Python harness.
 *
 * Runs in a single pass over the pairs. Pairs whose ``pairI`` lies
 * outside ``1..nMentions`` are ignored. Ties keep the earliest pair
 * (first-max, like ``argmax``). NaN or missing scores never win.
 *
 * @param nMentions number of mentions in the doc
 * @param pairI     1-based mention index of each pair (see buildPairs)
 * @param pairJ     candidate antecedent of each pair (0 = DUMMY)
 * @param scores    score of each pair, aligned with pairI / pairJ
 * @returns one entry per mention (index i ↔ mention i+1):
 *   - ``antecedent`` — the chosen antecedent's 1-based mention index, or
 *     ``0`` for DUMMY (no antecedent, or no candidate pair at all).
 *     Translate to a 0-based mention index with ``antecedent - 1``.
 *   - ``score`` — the winning pair's score, or ``-Infinity`` when the
 *     mention had no usable candidate pair.
 */
export function pickAntecedents(
  nMentions: number,
  pairI: BigInt64Array,
  pairJ: BigInt64Array,
  scores: Float32Array,
): { antecedent: number; score: number }[] {
  const count = mentionCount(nMentions);
  const bestIdx = new Array<number>(count).fill(-1);
  const bestScore = new Array<number>(count).fill(-Infinity);

  for (let k = 0; k < pairI.length; k++) {
    const m = Number(pairI[k]);
    if (m < 1 || m > count) continue;
    const s = scores[k];
    // Strict ">" keeps the first maximum and rejects NaN / undefined.
    if (s > bestScore[m - 1]) {
      bestScore[m - 1] = s;
      bestIdx[m - 1] = k;
    }
  }

  return bestIdx.map((idx, i) => ({
    antecedent: idx >= 0 ? Number(pairJ[idx]) : 0,
    score: bestScore[i],
  }));
}

/**
 * Group antecedent decisions into clusters using union-find.
 *
 * Each mention either points to DUMMY (starts its own cluster) or to
 * another mention (joins that mention's cluster). Cluster IDs are
 * dense 0-based and assigned in order of each cluster's first mention;
 * singletons are not assigned a cluster (returned as ``-1``) so callers
 * can render them differently.
 *
 * @param decisions ``decisions[i].antecedent`` is the *1-based*
 *                  mention index this mention links to, or ``0`` (or any
 *                  non-positive value) for DUMMY. (Same convention as the
 *                  model output.)
 * @returns
 *   - ``cluster[i]`` — cluster id for mention i, or -1 if singleton
 *   - ``clusters``  — list of multi-mention clusters, each a list of
 *                     mention indices in document order
 * @throws RangeError if a positive antecedent is not an integer or
 *         exceeds ``decisions.length``.
 */
export function groupClusters(
  decisions: { antecedent: number }[],
): { cluster: number[]; clusters: number[][] } {
  const n = decisions.length;

  // Union-find. parent[i] points to a smaller-or-equal mention index.
  const parent = Array.from({ length: n }, (_, i) => i);
  const find = (x: number): number => {
    while (parent[x] !== x) {
      parent[x] = parent[parent[x]]; // path halving
      x = parent[x];
    }
    return x;
  };
  const union = (a: number, b: number): void => {
    const ra = find(a);
    const rb = find(b);
    if (ra !== rb) {
      // Always attach the higher-index root under the lower-index root
      // so cluster representatives are first-mention.
      if (ra < rb) parent[rb] = ra;
      else parent[ra] = rb;
    }
  };

  for (let i = 0; i < n; i++) {
    const ant = decisions[i].antecedent;
    if (ant > 0) {
      // Without this check, an out-of-range index reads parent[] out of
      // bounds and find() returns undefined, silently corrupting clusters.
      if (!Number.isInteger(ant) || ant > n) {
        throw new RangeError(
          `groupClusters: mention ${i} has antecedent ${ant}, expected an integer in 0..${n}`,
        );
      }
      // ant is 1-based; the mention it points to is ant - 1.
      union(i, ant - 1);
    }
  }

  // Bucket by root. Roots are each cluster's smallest index, so iterating
  // i ascending yields buckets ordered by first mention, members in
  // document order.
  const roots: number[][] = [];
  const rootIdx = new Map<number, number>();
  for (let i = 0; i < n; i++) {
    const r = find(i);
    let idx = rootIdx.get(r);
    if (idx === undefined) {
      idx = roots.length;
      roots.push([]);
      rootIdx.set(r, idx);
    }
    roots[idx].push(i);
  }

  // Collapse: only multi-mention clusters get a stable id; singletons
  // get -1.
  const cluster = new Array<number>(n).fill(-1);
  const clusters: number[][] = [];
  for (const group of roots) {
    if (group.length < 2) continue;
    const cid = clusters.length;
    clusters.push(group);
    for (const m of group) cluster[m] = cid;
  }
  return { cluster, clusters };
}