/**
 * BIO-tag run-length decoder.
 *
 * The trained model emits per-wordpiece logits over three classes
 * ``[O, B, I]``. We pick one label per wordpiece, then run-length-decode
 * the label sequence into ``[start, end]`` span tuples (inclusive on both
 * ends, in wordpiece coordinates).
 *
 * Mirrors ``_decode_bio`` in
 * ``infon/scripts/train_coref_pointer.py`` — we keep ``validOnly``
 * semantics intact so JS predictions and Python predictions decode
 * identically.
 */

/** BIO class indices, matching the trained model's head order. */
export const BIO_O = 0;
export const BIO_B = 1;
export const BIO_I = 2;

/**
 * Label each wordpiece from its BIO logits and run-length-decode the
 * labels into wordpiece spans.
 *
 * @param logits    Flat ``Float32Array`` of length ``T * 3``,
 *                  row-major over wordpieces. Class ordering must be
 *                  ``[O, B, I]``. A trailing partial row is ignored.
 * @param attention Optional ``BigInt64Array`` mask ``(T,)`` — positions
 *                  whose value is not ``1n`` are ignored (always ``O``).
 *                  When the tokenizer pads to a fixed length pass this
 *                  so we don't decode spans inside padding. If the mask
 *                  is longer than the number of logits rows, only the
 *                  positions that have logits are decoded.
 * @param threshold If set, the larger of the softmax probabilities of
 *                  ``B`` and ``I`` is compared with this value and the
 *                  wordpiece gets that label only when the probability
 *                  is at least ``threshold``; otherwise it is ``O``.
 *                  The probability of ``O`` itself is not compared, so
 *                  this is not "argmax, then filter". ``undefined`` =
 *                  pure argmax over the three logits.
 * @returns Spans as ``[start, end]`` *inclusive* wordpiece indices,
 *          in document order. Drops orphan ``I`` (no preceding ``B``)
 *          — same convention as Python's ``valid_only=True``.
 */
export function decodeBio(
  logits: Float32Array,
  attention?: BigInt64Array,
  threshold?: number,
): [number, number][] {
  // Never decode more positions than there are complete logits rows, even
  // when the mask was padded to a longer fixed length.
  const rows = Math.floor(logits.length / 3);
  const T = attention ? Math.min(attention.length, rows) : rows;

  // Int32Array is zero-initialised, i.e. every position starts as BIO_O.
  const labels = new Int32Array(T);

  for (let t = 0; t < T; t++) {
    // Anything other than an explicit 1 in the mask is padding / ignored.
    if (attention && attention[t] !== 1n) {
      labels[t] = BIO_O;
      continue;
    }

    const base = t * 3;
    const o = logits[base + BIO_O];
    const b = logits[base + BIO_B];
    const i = logits[base + BIO_I];

    if (threshold !== undefined) {
      // Max-shifted softmax over the three classes, then compare the
      // stronger of the two non-O classes against ``threshold``.
      const m = Math.max(o, b, i);
      const eo = Math.exp(o - m);
      const eb = Math.exp(b - m);
      const ei = Math.exp(i - m);
      const z = eo + eb + ei;
      const pb = eb / z;
      const pi = ei / z;

      if (pb > pi && pb >= threshold) {
        labels[t] = BIO_B;
      } else if (pi >= pb && pi >= threshold) {
        // An exact B/I tie resolves to I on this path.
        labels[t] = BIO_I;
      } else {
        labels[t] = BIO_O;
      }
    } else {
      // Pure argmax; ties resolve in favour of B, then I, then O.
      labels[t] = b >= o && b >= i ? BIO_B : i >= o ? BIO_I : BIO_O;
    }
  }

  // Run-length decode: each B opens a span that extends over the run of
  // I labels immediately following it.
  const spans: [number, number][] = [];
  let t = 0;
  while (t < T) {
    if (labels[t] === BIO_B) {
      let j = t + 1;
      while (j < T && labels[j] === BIO_I) j++;
      spans.push([t, j - 1]);
      t = j;
    } else {
      // Orphan I (no preceding B) is silently dropped — matches Python
      // ``valid_only=True``. O just advances.
      t++;
    }
  }
  return spans;
}