File size: 3,196 Bytes
56fcc3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
/**
 * BIO-tag run-length decoder.
 *
 * The trained model emits per-wordpiece logits over three classes
 * ``[O, B, I]``. We argmax + run-length-decode into ``[start, end]``
 * span tuples (inclusive on both ends, in wordpiece coordinates).
 *
 * Mirrors ``_decode_bio`` in
 * ``infon/scripts/train_coref_pointer.py`` with ``valid_only=True``
 * semantics (orphan ``I`` tags without a preceding ``B`` are dropped),
 * so JS predictions and Python predictions decode identically.
 */

/** BIO class indices, matching the trained model's head order. */
export const BIO_O = 0;
export const BIO_B = 1;
export const BIO_I = 2;

/**
 * Argmax + run-length-decode BIO logits into wordpiece spans.
 *
 * @param logits     Flat ``Float32Array`` of length ``T * 3``,
 *                   row-major over wordpieces. Class ordering must be
 *                   ``[O, B, I]``. A trailing partial row (length not a
 *                   multiple of 3) is ignored.
 * @param attention  Optional ``BigInt64Array`` mask ``(T,)``. Positions
 *                   whose mask value is ``0n`` are forced to ``O``; every
 *                   other value is decoded normally. When the tokenizer
 *                   pads to a fixed length pass this so we don't decode
 *                   spans inside padding. At most
 *                   ``min(attention.length, floor(logits.length / 3))``
 *                   positions are decoded.
 * @param threshold  If set, take the more probable of ``B``/``I`` (a tie
 *                   goes to ``B``) and keep it only when its softmax
 *                   probability over all three classes is at least
 *                   ``threshold``; otherwise the position is ``O``. With
 *                   a low threshold this can emit ``B``/``I`` even where
 *                   ``O`` is the argmax. ``undefined`` = pure argmax, where
 *                   ties resolve to the lowest class index (``O``, then
 *                   ``B``, then ``I``) like numpy/torch ``argmax``.
 *                   Stricter thresholds reduce false-positive spans.
 * @returns Spans as ``[start, end]`` *inclusive* wordpiece indices,
 *          in document order. Each span is a ``B`` followed by its run
 *          of ``I``. Orphan ``I`` (no preceding ``B``) is dropped, the
 *          same convention as Python's ``valid_only=True``.
 */
export function decodeBio(
  logits: Float32Array,
  attention?: BigInt64Array,
  threshold?: number,
): [number, number][] {
  // Never index past the last complete logits row, even if the mask is longer.
  const rows = Math.floor(logits.length / 3);
  const T = Math.min(attention?.length ?? rows, rows);
  const labels = new Int32Array(T);
  for (let t = 0; t < T; t++) {
    if (attention && attention[t] === 0n) {
      labels[t] = BIO_O;
      continue;
    }
    const base = t * 3;
    const o = logits[base + BIO_O];
    const b = logits[base + BIO_B];
    const i = logits[base + BIO_I];

    if (threshold === undefined) {
      // First maximum wins (O, then B, then I), matching numpy/torch
      // ``argmax``. Strict ``>`` also keeps an all-NaN row at ``O``.
      let label = BIO_O;
      let best = o;
      if (b > best) {
        label = BIO_B;
        best = b;
      }
      if (i > best) label = BIO_I;
      labels[t] = label;
    } else {
      // Softmax over all three classes; subtracting the max keeps exp()
      // from overflowing on large logits.
      const m = Math.max(o, b, i);
      const eo = Math.exp(o - m);
      const eb = Math.exp(b - m);
      const ei = Math.exp(i - m);
      const z = eo + eb + ei;
      const pb = eb / z;
      const pi = ei / z;
      // Threshold only the more probable non-O class; B wins a B/I tie,
      // the same tie-break as the argmax path above.
      if (pb >= pi) {
        labels[t] = pb >= threshold ? BIO_B : BIO_O;
      } else {
        labels[t] = pi >= threshold ? BIO_I : BIO_O;
      }
    }
  }

  const spans: [number, number][] = [];
  let pos = 0;
  while (pos < T) {
    if (labels[pos] !== BIO_B) {
      // O just advances; an orphan I (no preceding B) is silently
      // dropped — matches Python ``valid_only=True``.
      pos++;
      continue;
    }
    let end = pos + 1;
    while (end < T && labels[end] === BIO_I) end++;
    spans.push([pos, end - 1]);
    pos = end;
  }
  return spans;
}