File size: 2,930 Bytes
8125804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from __future__ import annotations

from dataclasses import dataclass

import torch


@dataclass(frozen=True, eq=False)
class AnchorProbeCase:
    """One synthetic probe sequence together with its expected annotations.

    Equality is by value. String and tuple fields are compared with ``==``;
    tensor fields are compared with ``torch.equal``. The dataclass-generated
    ``__eq__`` is disabled because it compares tensor fields with ``==``.
    For two distinct multi-element tensors that yields a tensor whose truth
    value is ambiguous, so the comparison raises ``RuntimeError``.
    """

    # Identifier of the case.
    name: str
    # Free-text explanation of what the sequence is meant to exercise.
    description: str
    # Token ids of the probe sequence.
    input_ids: torch.Tensor
    # Target ids aligned with input_ids.
    target_ids: torch.Tensor
    # (start, end) position pair; whether `end` is inclusive is decided by
    # the consumer of these cases — TODO confirm.
    expected_anchor_zone: tuple[int, int]
    # Label of the failure mode the case is designed around.
    expected_failure_mode: str

    def __eq__(self, other: object) -> bool:
        """Return True when every field of ``other`` equals this case's."""
        # Match the generated dataclass behavior: only the exact same class compares.
        if other.__class__ is not self.__class__:
            return NotImplemented
        return (
            self.name == other.name
            and self.description == other.description
            and self.expected_anchor_zone == other.expected_anchor_zone
            and self.expected_failure_mode == other.expected_failure_mode
            # torch.equal returns a plain bool (same shape and elements),
            # unlike `==`, which returns an element-wise tensor.
            and torch.equal(self.input_ids, other.input_ids)
            and torch.equal(self.target_ids, other.target_ids)
        )

    def __hash__(self) -> int:
        """Hash the non-tensor fields only.

        Cases that compare equal share all of these fields, so the hash stays
        consistent with ``__eq__``.
        """
        return hash(
            (
                self.name,
                self.description,
                self.expected_anchor_zone,
                self.expected_failure_mode,
            )
        )


def _make_targets(input_ids: torch.Tensor) -> torch.Tensor:
    return torch.roll(input_ids, shifts=-1, dims=0)


def make_anchor_probe_cases(
    seq_len: int = 24,
    vocab_size: int = 512,
) -> list[AnchorProbeCase]:
    """Build the fixed suite of synthetic anchor-probe cases.

    Every case is a 1-D ``torch.long`` tensor assembled from six 4-token
    segments (24 tokens), with ``target_ids`` produced by ``_make_targets``.

    Args:
        seq_len: Length every case is required to have. Values below 16 are
            rejected up front. Because the layouts are fixed at 24 tokens,
            any other value different from 24 fails the final length check.
        vocab_size: Modulus applied to every raw token id; must be >= 1.
            Raw ids go up to 250, so ids >= ``vocab_size`` wrap around and
            may collide with smaller ids.

    Returns:
        The cases ``stable_regime``, ``quantifier_conflict``,
        ``proof_mode_conflict`` and ``complexity_conflict``, in that order.

    Raises:
        ValueError: If ``seq_len < 16`` or ``vocab_size < 1``.
        AssertionError: If a built case's length differs from ``seq_len``.
    """
    if seq_len < 16:
        raise ValueError('seq_len must be at least 16 for anchor probe cases')
    # `% vocab_size` below would raise ZeroDivisionError for 0 and yield
    # negative token ids for a negative modulus, so reject both explicitly.
    if vocab_size < 1:
        raise ValueError('vocab_size must be a positive integer for anchor probe cases')

    def clip(values: list[int]) -> torch.Tensor:
        # Fold raw ids into [0, vocab_size).
        return torch.tensor([v % vocab_size for v in values], dtype=torch.long)

    ramp = [41, 42, 43, 44]
    # (name, description, raw token ids, expected_anchor_zone, expected_failure_mode)
    specs: list[tuple[str, str, list[int], tuple[int, int], str]] = [
        (
            'stable_regime',
            'Smooth regime shifts without a strong late contradiction pattern.',
            [3] * 4 + [17] * 4 + [17] * 4 + [22] * 4 + [22] * 4 + [29] * 4,
            (4, 11),
            'none_or_low_pressure',
        ),
        (
            'quantifier_conflict',
            'Early root stays coherent, then a later incompatible regime appears sharply.',
            [5] * 4 + [111] * 4 + [111] * 4 + [7] * 4 + [221] * 4 + [221] * 4,
            (4, 11),
            'late_conflict',
        ),
        (
            'proof_mode_conflict',
            'Alternating motif with a late contradictory mode spike.',
            [9] * 4 + [87] * 4 + [12] * 4 + [87] * 4 + [240] * 4 + [12] * 4,
            (4, 15),
            'mode_flip',
        ),
        (
            'complexity_conflict',
            'Structured mid-sequence pattern followed by a heavy late disruption.',
            [14] * 4 + [14] * 4 + ramp + ramp + [250] * 4 + ramp,
            (8, 15),
            'late_disruption',
        ),
    ]

    cases: list[AnchorProbeCase] = []
    for name, description, raw_ids, anchor_zone, failure_mode in specs:
        input_ids = clip(raw_ids)
        cases.append(
            AnchorProbeCase(
                name=name,
                description=description,
                input_ids=input_ids,
                target_ids=_make_targets(input_ids),
                expected_anchor_zone=anchor_zone,
                expected_failure_mode=failure_mode,
            )
        )

    # Report which cases disagree with seq_len so the mismatch is actionable.
    mismatched = [
        f'{case.name}={case.input_ids.numel()}'
        for case in cases
        if case.input_ids.numel() != seq_len
    ]
    if mismatched:
        raise AssertionError(
            f'All anchor probe cases must match seq_len ({seq_len}); '
            f'got {", ".join(mismatched)}'
        )

    return cases