File size: 2,824 Bytes
57f9808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""Reversed-digit addition data. The reversed format (least-significant digit
first) is the Lee et al. 2023 trick: the model emits the LSB first, matching the
direction carries actually propagate, which is what makes exact addition
learnable at tiny scale. Every number is a fixed-width S-cell strip, zero-padded
in the high positions, so cell i is always digit i (10^i place)."""
import numpy as np

def digits_rev(x, S):
    """Return ``x`` as an S-cell int64 strip, least-significant digit first.

    Cell i holds the 10**i digit. Exactly S digits are produced, so the strip
    encodes ``x mod 10**S``: digits above position S-1 are dropped, and (by
    Python's floor division) a negative x comes out as its mod-10**S value.
    """
    strip = np.zeros(S, dtype=np.int64)
    remaining = x
    for place in range(S):
        # divmod peels off the lowest digit and shifts the rest down one place.
        remaining, strip[place] = divmod(remaining, 10)
    return strip

def make_batch(Bn, S, rng, maxdig=None):
    """A,B,Y int grids (Bn,S), reversed digits. Operand lengths ~Uniform[1,maxdig].

    Each operand first draws a length ``len`` in [1, maxdig], then a value
    uniformly from [0, 10**len). The realised digit count can therefore be
    shorter than ``len``.

    Args:
        Bn: number of rows.
        S: strip width (cells per number).
        rng: random generator exposing ``integers`` (a numpy.random.Generator).
        maxdig: maximum operand length. ``None`` (or any falsy value) means
            S-1, the largest length whose sums still fit in S cells. Note that
            ``rng.integers`` needs ``10**maxdig`` to fit in int64, so in
            practice maxdig <= 18.

    Returns:
        A, B, Y: int64 arrays of shape (Bn, S) with Y == A + B.

    Raises:
        ValueError: if maxdig is outside [1, S-1]. Two maxdig-digit operands
            can sum to maxdig+1 digits. With maxdig >= S that overflows the
            strip, and digits_rev would silently drop the carry, corrupting Y.
    """
    maxdig = maxdig or (S - 1)
    if not 1 <= maxdig <= S - 1:
        raise ValueError(f"maxdig must be in [1, S-1] = [1, {S - 1}], got {maxdig}")
    la = rng.integers(1, maxdig + 1, Bn)
    lb = rng.integers(1, maxdig + 1, Bn)
    A = np.empty((Bn, S), np.int64); B = np.empty((Bn, S), np.int64); Y = np.empty((Bn, S), np.int64)
    for n in range(Bn):
        # Python ints so a + b cannot wrap before digit extraction.
        a = int(rng.integers(0, 10 ** int(la[n])))
        b = int(rng.integers(0, 10 ** int(lb[n])))
        A[n] = digits_rev(a, S); B[n] = digits_rev(b, S); Y[n] = digits_rev(a + b, S)
    return A, B, Y

def gen_hard(Bn, S, rng):
    """Carry-heavy adversarial pairs the uniform sampler almost never produces:
    long ripple chains. These teach the rare 'receive a carry and emit a new
    most-significant digit' behaviour that full-length ripples (e.g. 9999999+1)
    depend on. All sums are constructed to fit in S cells.

    Args:
        Bn: number of rows to generate.
        S: strip width. Chain lengths are drawn with ``rng.integers(2, S)``,
            so S must be at least 3 (numpy raises otherwise when Bn > 0).
        rng: random generator exposing ``integers`` and ``random`` (a
            numpy.random.Generator).

    Returns:
        A, B, Y: int64 grids of shape (Bn, S) in reversed-digit order with
        Y == A + B. The shape holds even when Bn == 0.

    Why the sums fit: each chain has L <= S-1 digits, so a, b < 10**L. The
    optional +1 trigger makes b <= 10**L, hence a + b < 2 * 10**L <= 10**S.
    """
    def _as_int(row):
        # Exact Python-int value of a reversed digit row. int64 place values
        # (10 ** np.arange(S)) would wrap silently once S >= 20.
        return sum(int(d) * 10 ** i for i, d in enumerate(row))

    A = np.zeros((Bn, S), np.int64); B = np.zeros((Bn, S), np.int64)
    for n in range(Bn):
        L = int(rng.integers(2, S))               # chain length, <= S-1
        t = rng.random()
        if t < 0.45:                              # sum-to-9 chain: any carry ripples the whole way
            a = rng.integers(0, 10, L); b = 9 - a
        elif t < 0.75:                            # all-nines operand(s)
            a = np.full(L, 9)
            b = np.full(L, 9) if rng.random() < 0.5 else rng.integers(0, 10, L)
        else:                                     # random short
            a = rng.integers(0, 10, L); b = rng.integers(0, 10, L)
        A[n, :L] = a; B[n, :L] = b
    # +1 trigger sets off the ripple. It is drawn after the loop, in the same
    # order as before, so a given seed yields the same data.
    trigger = rng.random(Bn) < 0.6
    ai = [_as_int(row) for row in A]
    bi = [_as_int(row) + int(fire) for row, fire in zip(B, trigger)]
    # Preallocate so Bn == 0 still returns (0, S) int64 grids.
    Y = np.zeros((Bn, S), np.int64)
    for n in range(Bn):
        B[n] = digits_rev(bi[n], S)               # the +1 may carry through B's digits
        Y[n] = digits_rev(ai[n] + bi[n], S)
    # A already holds canonical digits 0-9 and needs no re-derivation.
    return A, B, Y

def to_int(grid):
    """Decode a reversed digit grid of shape (Bn, S) into a list of Python ints.

    Cell i of each row is weighted by 10**i. Each cell is passed through
    ``int()`` before use, and the results are arbitrary-precision Python ints.
    """
    n_rows, _ = grid.shape
    values = []
    for r in range(n_rows):
        acc = 0
        # Horner's rule: walk from the most-significant cell down to cell 0.
        for cell in grid[r, ::-1]:
            acc = acc * 10 + int(cell)
        values.append(acc)
    return values

def exact_match(pred, Y):
    """fraction of rows where ALL digits match (true exact-match accuracy).

    Args:
        pred: predicted digit grid, shape (Bn, S) (array-like).
        Y: target digit grid, same shape as ``pred``.

    Returns:
        Python float in [0, 1], or ``nan`` when there are no rows.

    Raises:
        ValueError: if ``pred`` and ``Y`` differ in shape. Without this check
            NumPy would broadcast (e.g. one (S,) row against every row of a
            (Bn, S) grid) and report a meaningless score.
    """
    pred = np.asarray(pred); Y = np.asarray(Y)
    if pred.shape != Y.shape:
        raise ValueError(f"pred shape {pred.shape} != Y shape {Y.shape}")
    if Y.shape[0] == 0:
        # Accuracy over zero rows is undefined. Return nan explicitly instead
        # of letting NumPy emit a mean-of-empty-slice RuntimeWarning.
        return float("nan")
    return float((pred == Y).all(axis=1).mean())