infon-coref-pointer / js /src /pairs.ts
cp500's picture
Upload js/src/pairs.ts with huggingface_hub
060b8ea verified
Raw
History Blame Contribute Delete
4.67 kB
/**
* Pair-index builder + per-mention argmax + cluster grouping.
*
* Pure JS, no ORT, no DOM. Mirrors the Python helpers in
* ``infon/scripts/coref_onnx_experiment.py`` (``build_pairs`` /
* ``split_pairs_by_mention``) so the JS pipeline produces
* bit-identical pair tensors.
*/
/**
 * Build the flat candidate-pair index tensors for ``nMentions`` mentions.
 *
 * Mentions are numbered from 1 because slot 0 is reserved for DUMMY.
 * Mention ``m`` contributes ``m`` consecutive pairs: ``(m, 0)`` (DUMMY)
 * followed by ``(m, 1), …, (m, m-1)`` (every earlier mention). This
 * triangular layout matches the Python ``build_pairs`` exactly; the scorer
 * ONNX depends on it because its in-graph ``index_select`` over the
 * prepended DUMMY row treats j = 0 as "no antecedent."
 *
 * @param nMentions number of mentions in the doc
 * @returns ``[pairI, pairJ]``: two BigInt64Arrays of equal length
 *   ``M*(M+1)/2`` (``M = nMentions``).
 */
export function buildPairs(nMentions: number): [BigInt64Array, BigInt64Array] {
  const anaphors: bigint[] = [];
  const candidates: bigint[] = [];
  for (let mention = 1; mention <= nMentions; mention++) {
    const anaphor = BigInt(mention);
    // Candidate 0 is DUMMY; candidates 1..mention-1 are the earlier mentions.
    for (let cand = 0; cand < mention; cand++) {
      anaphors.push(anaphor);
      candidates.push(BigInt(cand));
    }
  }
  return [BigInt64Array.from(anaphors), BigInt64Array.from(candidates)];
}
/**
 * Reduce flat pair scores to one argmax antecedent decision per mention.
 * Mirrors ``split_pairs_by_mention`` in the Python harness.
 *
 * Works in a single pass over the pairs (O(P) for P pairs) rather than
 * rescanning every pair once per mention.
 *
 * @param nMentions number of mentions in the doc; mentions are numbered
 *   ``1..nMentions`` (index 0 in ``pairJ`` is DUMMY)
 * @param pairI 1-based anaphor mention index of each pair
 * @param pairJ candidate antecedent index of each pair (``0`` = DUMMY)
 * @param scores score of each pair, aligned index-for-index with
 *   ``pairI`` / ``pairJ``
 * @returns one entry per mention, in mention order: ``out[i]`` describes
 *   mention ``i + 1``. ``out[i].antecedent`` is the 1-based index of the
 *   chosen antecedent mention, or ``0`` for DUMMY (no antecedent); subtract
 *   1 to get a 0-based mention index. ``out[i].score`` is the winning
 *   pair's score. Ties go to the earliest pair and NaN scores never win; a
 *   mention with no winning pair gets ``{ antecedent: 0, score: -Infinity }``.
 *   Pairs whose ``pairI`` is outside ``1..nMentions`` are ignored.
 */
export function pickAntecedents(
  nMentions: number,
  pairI: BigInt64Array,
  pairJ: BigInt64Array,
  scores: Float32Array,
): { antecedent: number; score: number }[] {
  // Number of integers m with 1 <= m <= nMentions (0 for NaN or < 1).
  const count = nMentions >= 1 ? Math.floor(nMentions) : 0;
  const bestIdx = new Array<number>(count).fill(-1);
  const bestScore = new Array<number>(count).fill(-Infinity);
  for (let k = 0; k < pairI.length; k++) {
    const slot = Number(pairI[k]) - 1; // mention m is tracked in slot m-1
    if (slot < 0 || slot >= count) continue;
    const s = scores[k];
    // Strict `>` keeps the first pair on ties and never selects NaN.
    if (s > bestScore[slot]) {
      bestScore[slot] = s;
      bestIdx[slot] = k;
    }
  }
  return bestIdx.map((k, slot) => ({
    antecedent: k >= 0 ? Number(pairJ[k]) : 0,
    score: bestScore[slot],
  }));
}
/**
 * Group antecedent decisions into clusters using union-find.
 *
 * Each mention either points to DUMMY (no link) or to another mention
 * (joins that mention's cluster). Cluster IDs are dense and 0-based,
 * assigned in order of each cluster's first mention; singletons are not
 * assigned a cluster (returned as ``-1``) so callers can render them
 * differently.
 *
 * @param decisions ``decisions[i].antecedent`` is the *1-based*
 *   mention index this mention links to, or ``0`` for
 *   DUMMY. (Same convention as the model output.) Any value
 *   that is not an integer in ``1..decisions.length`` is
 *   treated as DUMMY.
 * @returns
 *   - ``cluster[i]`` — cluster id for mention i, or -1 if singleton
 *   - ``clusters`` — list of multi-mention clusters, each a list of
 *     mention indices (0-based) in document order
 */
export function groupClusters(
  decisions: { antecedent: number }[],
): { cluster: number[]; clusters: number[][] } {
  const n = decisions.length;
  // Union-find. parent[i] points to a smaller-or-equal mention index.
  const parent = Array.from({ length: n }, (_, i) => i);
  const find = (x: number): number => {
    while (parent[x] !== x) {
      parent[x] = parent[parent[x]]; // path halving
      x = parent[x];
    }
    return x;
  };
  const union = (a: number, b: number) => {
    const ra = find(a);
    const rb = find(b);
    if (ra !== rb) {
      // Always attach the higher-index root under the lower-index root
      // so cluster representatives are first-mention.
      if (ra < rb) parent[rb] = ra;
      else parent[ra] = rb;
    }
  };
  for (let i = 0; i < n; i++) {
    const ant = decisions[i].antecedent;
    // ant is 1-based; the mention it points to is ant - 1. Skip DUMMY (0)
    // and anything that is not a valid mention index: an out-of-range or
    // fractional index would read past `parent` and corrupt the forest.
    if (Number.isInteger(ant) && ant >= 1 && ant <= n) {
      union(i, ant - 1);
    }
  }
  // Bucket by root; iterating i in order keeps each bucket in document
  // order and orders buckets by their first mention.
  const roots: number[][] = [];
  const rootIdx = new Map<number, number>();
  for (let i = 0; i < n; i++) {
    const r = find(i);
    let idx = rootIdx.get(r);
    if (idx === undefined) {
      idx = roots.length;
      roots.push([]);
      rootIdx.set(r, idx);
    }
    roots[idx].push(i);
  }
  // Collapse: only multi-mention clusters get a stable id; singletons
  // get -1.
  const cluster = new Array<number>(n).fill(-1);
  const clusters: number[][] = [];
  for (const group of roots) {
    if (group.length < 2) continue;
    const cid = clusters.length;
    clusters.push(group);
    for (const m of group) cluster[m] = cid;
  }
  return { cluster, clusters };
}