cp500's picture
Upload js/src/bio.ts with huggingface_hub
56fcc3f verified
Raw
History Blame Contribute Delete
3.2 kB
/**
* BIO-tag run-length decoder.
*
* The trained model emits per-wordpiece logits over three classes
* ``[O, B, I]``. We argmax + run-length-decode into ``[start, end]``
* span tuples (inclusive on both ends, in wordpiece coordinates).
*
* Mirrors ``_decode_bio`` in
* ``infon/scripts/train_coref_pointer.py`` — we keep its ``valid_only``
* semantics intact so JS predictions and Python predictions decode
* identically.
*/
/** BIO class indices, matching the trained model's head order. */
export const BIO_O = 0;
export const BIO_B = 1;
export const BIO_I = 2;

/** Number of logits per wordpiece (one per BIO class). */
const NUM_BIO_CLASSES = 3;

/**
 * Label one wordpiece by plain argmax over its ``[O, B, I]`` logits.
 *
 * Ties resolve to the lowest class index (O before B before I). This is the
 * first-occurrence convention of ``numpy.argmax`` / ``torch.argmax``, so exact
 * ties (e.g. all-equal logits) decode the same way as a Python-side argmax.
 */
function argmaxLabel(o: number, b: number, i: number): number {
  let label = BIO_O;
  let best = o;
  if (b > best) {
    label = BIO_B;
    best = b;
  }
  if (i > best) {
    label = BIO_I;
  }
  return label;
}

/**
 * Label one wordpiece by thresholded softmax probability.
 *
 * Picks the more probable of B and I (ties -> B, matching the argmax order)
 * and keeps it only if its softmax probability is ``>= threshold``; otherwise
 * the wordpiece is O. The O logit only enters through the softmax normaliser.
 */
function thresholdLabel(o: number, b: number, i: number, threshold: number): number {
  // Standard numerically stable 3-class softmax: subtracting the max logit
  // keeps Math.exp from overflowing to Infinity on large logits.
  const m = Math.max(o, b, i);
  const eo = Math.exp(o - m);
  const eb = Math.exp(b - m);
  const ei = Math.exp(i - m);
  const z = eo + eb + ei;
  const pb = eb / z;
  const pi = ei / z;
  if (pb >= pi) {
    return pb >= threshold ? BIO_B : BIO_O;
  }
  return pi >= threshold ? BIO_I : BIO_O;
}

/**
 * Argmax + run-length-decode BIO logits into wordpiece spans.
 *
 * @param logits Flat ``Float32Array`` of length ``T * 3``, row-major over
 *   wordpieces. Class ordering must be ``[O, B, I]``. Trailing values that do
 *   not fill a whole row are ignored.
 * @param attention Optional ``BigInt64Array`` mask ``(T,)``. Positions whose
 *   mask value is ``0n`` are forced to ``O``. When the tokenizer pads to a
 *   fixed length, pass this so no spans are decoded inside padding. If the
 *   mask is longer than the logits, only positions covered by ``logits`` are
 *   decoded.
 * @param threshold If set, a wordpiece gets the more probable of ``B``/``I``
 *   only when that class's softmax probability is ``>= threshold``; otherwise
 *   it is ``O``. ``undefined`` = pure argmax (ties go to the lowest class
 *   index, like numpy/torch ``argmax``). Raising the threshold reduces
 *   false-positive spans; note that a threshold below 0.5 can label a
 *   wordpiece ``B``/``I`` even when ``O`` is its most probable class.
 * @returns Spans as ``[start, end]`` *inclusive* wordpiece indices, in
 *   document order. Orphan ``I`` (no preceding ``B``) is dropped — the same
 *   convention as Python's ``valid_only=True``.
 */
export function decodeBio(
  logits: Float32Array,
  attention?: BigInt64Array,
  threshold?: number,
): [number, number][] {
  // Clamp to the number of complete logit rows so a longer mask can never
  // make us read past the end of ``logits``.
  const rows = Math.floor(logits.length / NUM_BIO_CLASSES);
  const T = attention ? Math.min(attention.length, rows) : rows;

  // Pass 1: one BIO label per wordpiece.
  const labels = new Int32Array(T);
  for (let t = 0; t < T; t++) {
    if (attention && attention[t] === 0n) {
      labels[t] = BIO_O;
      continue;
    }
    const base = t * NUM_BIO_CLASSES;
    const o = logits[base + BIO_O];
    const b = logits[base + BIO_B];
    const i = logits[base + BIO_I];
    labels[t] =
      threshold === undefined ? argmaxLabel(o, b, i) : thresholdLabel(o, b, i, threshold);
  }

  // Pass 2: run-length decode. Each B opens a span that absorbs the run of
  // I labels directly after it.
  const spans: [number, number][] = [];
  let t = 0;
  while (t < T) {
    if (labels[t] !== BIO_B) {
      // O just advances; an orphan I (no preceding B) is silently dropped,
      // matching Python's ``valid_only=True``.
      t++;
      continue;
    }
    let end = t;
    while (end + 1 < T && labels[end + 1] === BIO_I) end++;
    spans.push([t, end]);
    t = end + 1;
  }
  return spans;
}