Spaces:
Sleeping
Sleeping
File size: 2,930 Bytes
8125804 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 | from __future__ import annotations
from dataclasses import dataclass
import torch
@dataclass(frozen=True)
class AnchorProbeCase:
    """Immutable record describing one synthetic anchor-probe scenario.

    Instances are frozen: assigning to a field after construction raises.

    Attributes:
        name: Short identifier of the case.
        description: Human-readable summary of the token pattern.
        input_ids: Tensor of token ids for the probe sequence.
        target_ids: Tensor of token ids paired with ``input_ids`` as targets.
        expected_anchor_zone: Pair of integer positions delimiting the zone
            associated with the case. NOTE(review): whether the second value
            is inclusive is not fixed by this class -- confirm with the
            consumer.
        expected_failure_mode: Label naming the failure mode the case is
            associated with.
    """

    name: str
    description: str
    input_ids: torch.Tensor
    target_ids: torch.Tensor
    expected_anchor_zone: tuple[int, int]
    expected_failure_mode: str
def _make_targets(input_ids: torch.Tensor) -> torch.Tensor:
return torch.roll(input_ids, shifts=-1, dims=0)
def make_anchor_probe_cases(
    seq_len: int = 24,
    vocab_size: int = 512,
) -> list[AnchorProbeCase]:
    """Build the four synthetic anchor-probe cases.

    Every sequence consists of six consecutive segments. Each segment is
    filled by repeating a short token pattern, and every token is reduced
    modulo ``vocab_size``. Targets are produced by ``_make_targets``.

    Segment boundaries are derived from ``seq_len`` (boundary ``i`` sits at
    ``i * seq_len // 6``), so any length of at least 16 yields sequences of
    exactly ``seq_len`` tokens. With the default ``seq_len=24`` every
    segment has four tokens and the result matches the previously
    hard-coded sequences and anchor zones exactly. Earlier, any other
    ``seq_len`` failed the final length check.

    Args:
        seq_len: Number of tokens in every generated sequence (>= 16).
        vocab_size: Positive modulus applied to every token id.

    Returns:
        The cases ``stable_regime``, ``quantifier_conflict``,
        ``proof_mode_conflict`` and ``complexity_conflict``, in that order.

    Raises:
        ValueError: If ``seq_len`` is below 16 or ``vocab_size`` is not
            positive.
        AssertionError: If a generated sequence does not have ``seq_len``
            tokens (guards the spec table against future edits).
    """
    if seq_len < 16:
        raise ValueError('seq_len must be at least 16 for anchor probe cases')
    if vocab_size < 1:
        # Zero would divide by zero below; negatives would give negative ids.
        raise ValueError('vocab_size must be a positive integer')

    num_segments = 6
    # Segment k covers positions bounds[k] .. bounds[k + 1] - 1.
    bounds = [k * seq_len // num_segments for k in range(num_segments + 1)]

    def build_ids(patterns: tuple[tuple[int, ...], ...]) -> torch.Tensor:
        # Fill each segment by cycling through its pattern from the start.
        tokens: list[int] = []
        for start, stop, pattern in zip(bounds, bounds[1:], patterns):
            tokens.extend(
                pattern[offset % len(pattern)] % vocab_size
                for offset in range(stop - start)
            )
        return torch.tensor(tokens, dtype=torch.long)

    def zone(first_segment: int, last_segment: int) -> tuple[int, int]:
        # NOTE(review): the zone is treated as an inclusive position range
        # aligned to whole segments; this reproduces the original values
        # (4, 11), (4, 15) and (8, 15) at seq_len=24 -- confirm intent.
        return (bounds[first_segment], bounds[last_segment + 1] - 1)

    ramp = (41, 42, 43, 44)
    # (name, description, per-segment patterns, zone segments, failure mode)
    specs = (
        (
            'stable_regime',
            'Smooth regime shifts without a strong late contradiction pattern.',
            ((3,), (17,), (17,), (22,), (22,), (29,)),
            (1, 2),
            'none_or_low_pressure',
        ),
        (
            'quantifier_conflict',
            'Early root stays coherent, then a later incompatible regime appears sharply.',
            ((5,), (111,), (111,), (7,), (221,), (221,)),
            (1, 2),
            'late_conflict',
        ),
        (
            'proof_mode_conflict',
            'Alternating motif with a late contradictory mode spike.',
            ((9,), (87,), (12,), (87,), (240,), (12,)),
            (1, 3),
            'mode_flip',
        ),
        (
            'complexity_conflict',
            'Structured mid-sequence pattern followed by a heavy late disruption.',
            ((14,), (14,), ramp, ramp, (250,), ramp),
            (2, 3),
            'late_disruption',
        ),
    )

    cases: list[AnchorProbeCase] = []
    for name, description, patterns, (first, last), failure_mode in specs:
        input_ids = build_ids(patterns)
        cases.append(
            AnchorProbeCase(
                name=name,
                description=description,
                input_ids=input_ids,
                target_ids=_make_targets(input_ids),
                expected_anchor_zone=zone(first, last),
                expected_failure_mode=failure_mode,
            )
        )

    # Holds by construction because bounds[-1] == seq_len.
    if any(case.input_ids.numel() != seq_len for case in cases):
        raise AssertionError('All anchor probe cases must match seq_len')
    return cases
|