# abpt/src/data/anchor_semantic_cases.py
# feat: add src/ module for script imports (commit 8125804)
from __future__ import annotations
from dataclasses import dataclass
import torch
from src.data.anchor_cases import AnchorProbeCase
@dataclass(frozen=True)
class SemanticTokenLegend:
    """Immutable legend entry naming one synthetic probe token.

    Instances are built positionally as ``(token_id, label, role)``.
    """

    # Integer id used for the token inside the probe sequences.
    token_id: int
    # Short upper-case mnemonic, e.g. 'FORALL'.
    label: str
    # Structural role tag, e.g. 'quantifier_root'.
    role: str
def semantic_token_legend() -> list[SemanticTokenLegend]:
    """Return the legend of semantic probe tokens.

    A new list is built on every call. Entries appear in ascending
    ``token_id`` order, grouped by the tens digit of the id.

    Returns:
        One ``SemanticTokenLegend`` per known token id.
    """
    rows = (
        # 1x: quantifier tokens plus the generic claim/step tokens.
        (11, 'FORALL', 'quantifier_root'),
        (12, 'EXISTS', 'conflicting_quantifier'),
        (13, 'VAR_N', 'bound_variable'),
        (14, 'CLAIM', 'assertion'),
        (15, 'STEP', 'proof_step'),
        (16, 'UNIFORM', 'uniform_descendant'),
        (17, 'WITNESS', 'existential_descendant'),
        # 2x: proof-by-contradiction tokens.
        (21, 'ASSUME_NOT', 'contradiction_root'),
        (22, 'DERIVE', 'contradiction_step'),
        (23, 'CONTRADICTION', 'contradiction_close'),
        (24, 'DIRECT', 'direct_mode_conflict'),
        # 3x: complexity tokens.
        (31, 'CONST', 'complexity_root'),
        (32, 'LOOKUP', 'constant_time_descendant'),
        (33, 'CACHE', 'constant_time_support'),
        (34, 'LOOP', 'linear_conflict'),
        (35, 'SCAN', 'linear_descendant'),
        # 4x: induction tokens.
        (41, 'INDUCT', 'induction_root'),
        (42, 'BASE', 'induction_base_case'),
        (43, 'STEP_K', 'induction_step'),
        (44, 'STEP_K1', 'induction_successor'),
        (45, 'EXAMPLE', 'finite_example_conflict'),
        # 5x: epsilon-delta tokens.
        (51, 'EPS', 'epsilon_root'),
        (52, 'DELTA', 'delta_descendant'),
        (53, 'BOUND', 'formal_bound_descendant'),
        (54, 'CLOSE', 'intuitive_closeness_conflict'),
    )
    return [
        SemanticTokenLegend(token_id, label, role)
        for token_id, label, role in rows
    ]
def _make_targets(input_ids: torch.Tensor) -> torch.Tensor:
return torch.roll(input_ids, shifts=-1, dims=0)
def make_semantic_anchor_cases(seq_len: int = 24) -> list[AnchorProbeCase]:
    """Build the fixed suite of semantic anchor probe cases.

    Every case is a 24-token sequence made of six 4-token groups: an opening
    group repeated three times followed by a closing group repeated three
    times (the closing group may equal the opening group). Target ids are
    produced from the input ids by ``_make_targets``.

    Args:
        seq_len: Requested sequence length; only 24 is supported.

    Returns:
        Nine ``AnchorProbeCase`` objects, in a fixed order.

    Raises:
        ValueError: If ``seq_len`` is not 24.
    """
    if seq_len != 24:
        raise ValueError('semantic probe cases currently require seq_len=24')

    # Each half of a sequence is one 4-token group repeated this many times
    # (2 halves * 3 repeats * 4 tokens = 24 tokens).
    repeats_per_half = 3

    # Columns: name, description, opening group, closing group,
    # expected anchor zone, expected failure mode.
    specs = (
        (
            'forall_stable',
            'FORALL root followed by uniform descendants that stay semantically consistent.',
            [11, 13, 11, 16],
            [11, 13, 11, 16],
            (0, 11),
            'stable_quantifier_tree',
        ),
        (
            'forall_exists_conflict',
            'FORALL root later flips into EXISTS-style descendants.',
            [11, 13, 11, 16],
            [12, 17, 12, 17],
            (0, 11),
            'quantifier_flip',
        ),
        (
            'contradiction_stable',
            'ASSUME_NOT mode stays in contradiction style until closure.',
            [21, 14, 22, 15],
            [21, 14, 23, 15],
            # NOTE(review): this zone differs from the (0, 11) used by every
            # other case; reproduced as-is -- confirm it is intended.
            (0, 15),
            'stable_contradiction_tree',
        ),
        (
            'contradiction_direct_conflict',
            'ASSUME_NOT root later drifts into direct-proof style descendants.',
            [21, 14, 22, 15],
            [24, 14, 15, 15],
            (0, 11),
            'proof_mode_flip',
        ),
        (
            'const_vs_loop_conflict',
            'Constant-time root later drifts into loop/scan descendants.',
            [31, 32, 33, 15],
            [34, 35, 35, 15],
            (0, 11),
            'complexity_flip',
        ),
        (
            'induction_stable',
            'INDUCT root followed by repeated base/step/successor structure.',
            [41, 42, 43, 44],
            [41, 42, 43, 44],
            (0, 11),
            'stable_induction_tree',
        ),
        (
            'induction_example_conflict',
            'INDUCT root later collapses into repeated finite examples instead of inductive structure.',
            [41, 42, 43, 44],
            [45, 45, 45, 15],
            (0, 11),
            'induction_to_examples',
        ),
        (
            'epsilon_delta_stable',
            'EPS root followed by DELTA/BOUND descendants that remain formal.',
            [51, 52, 53, 15],
            [51, 52, 53, 15],
            (0, 11),
            'stable_formal_limit_tree',
        ),
        (
            'epsilon_close_conflict',
            'EPS root later drifts into intuitive closeness language instead of formal bound structure.',
            [51, 52, 53, 15],
            [54, 54, 54, 15],
            (0, 11),
            'formal_to_intuitive_drift',
        ),
    )

    cases: list[AnchorProbeCase] = []
    for name, description, opening, closing, zone, failure_mode in specs:
        token_values = opening * repeats_per_half + closing * repeats_per_half
        token_ids = torch.tensor(token_values, dtype=torch.long)
        cases.append(
            AnchorProbeCase(
                name=name,
                description=description,
                input_ids=token_ids,
                target_ids=_make_targets(token_ids),
                expected_anchor_zone=zone,
                expected_failure_mode=failure_mode,
            )
        )
    return cases