Spaces:
Sleeping
Sleeping
from __future__ import annotations

from dataclasses import dataclass

import torch

from src.data.anchor_cases import AnchorProbeCase
@dataclass
class SemanticTokenLegend:
    """One entry of the semantic probe vocabulary.

    The ``@dataclass`` decorator is required: entries are constructed
    positionally as ``SemanticTokenLegend(token_id, label, role)``, which a
    plain class with only annotations cannot accept.
    """

    # Integer token id as it appears in probe input sequences.
    token_id: int
    # Human-readable token name, e.g. 'FORALL'.
    label: str
    # Short tag describing the token's part in a probe tree, e.g. 'quantifier_root'.
    role: str
def semantic_token_legend() -> list[SemanticTokenLegend]:
    """Return the legend describing every token id used by the semantic probes.

    Token ids are grouped by decade, one family per decade:
    1x quantifiers, 2x proof-by-contradiction, 3x complexity,
    4x induction, 5x epsilon-delta.

    Returns:
        A freshly built list of ``SemanticTokenLegend`` entries, ordered by
        token id.
    """
    rows = (
        (11, 'FORALL', 'quantifier_root'),
        (12, 'EXISTS', 'conflicting_quantifier'),
        (13, 'VAR_N', 'bound_variable'),
        (14, 'CLAIM', 'assertion'),
        (15, 'STEP', 'proof_step'),
        (16, 'UNIFORM', 'uniform_descendant'),
        (17, 'WITNESS', 'existential_descendant'),
        (21, 'ASSUME_NOT', 'contradiction_root'),
        (22, 'DERIVE', 'contradiction_step'),
        (23, 'CONTRADICTION', 'contradiction_close'),
        (24, 'DIRECT', 'direct_mode_conflict'),
        (31, 'CONST', 'complexity_root'),
        (32, 'LOOKUP', 'constant_time_descendant'),
        (33, 'CACHE', 'constant_time_support'),
        (34, 'LOOP', 'linear_conflict'),
        (35, 'SCAN', 'linear_descendant'),
        (41, 'INDUCT', 'induction_root'),
        (42, 'BASE', 'induction_base_case'),
        (43, 'STEP_K', 'induction_step'),
        (44, 'STEP_K1', 'induction_successor'),
        (45, 'EXAMPLE', 'finite_example_conflict'),
        (51, 'EPS', 'epsilon_root'),
        (52, 'DELTA', 'delta_descendant'),
        (53, 'BOUND', 'formal_bound_descendant'),
        (54, 'CLOSE', 'intuitive_closeness_conflict'),
    )
    return [SemanticTokenLegend(*row) for row in rows]
| def _make_targets(input_ids: torch.Tensor) -> torch.Tensor: | |
| return torch.roll(input_ids, shifts=-1, dims=0) | |
def make_semantic_anchor_cases(seq_len: int = 24) -> list[AnchorProbeCase]:
    """Build the fixed suite of semantic anchor probe cases.

    Every case is a 24-token sequence assembled from two halves:
    an opening 4-token motif repeated three times (positions 0-11),
    then a closing 4-token motif repeated three times (positions 12-23).
    "Stable" cases keep (or lightly vary) the opening motif in the closing
    half; "conflict" cases switch to a different motif.

    Args:
        seq_len: Requested sequence length. Only 24 is supported.

    Returns:
        A new list of ``AnchorProbeCase`` objects in a fixed order. Each has
        ``target_ids`` derived from its ``input_ids`` via ``_make_targets``.

    Raises:
        ValueError: If ``seq_len`` is not 24.
    """
    if seq_len != 24:
        raise ValueError('semantic probe cases currently require seq_len=24')

    # Each row: (name, description, opening motif, closing motif,
    #            expected anchor zone, expected failure mode).
    specs = (
        (
            'forall_stable',
            'FORALL root followed by uniform descendants that stay semantically consistent.',
            [11, 13, 11, 16],
            [11, 13, 11, 16],
            (0, 11),
            'stable_quantifier_tree',
        ),
        (
            'forall_exists_conflict',
            'FORALL root later flips into EXISTS-style descendants.',
            [11, 13, 11, 16],
            [12, 17, 12, 17],
            (0, 11),
            'quantifier_flip',
        ),
        (
            'contradiction_stable',
            'ASSUME_NOT mode stays in contradiction style until closure.',
            [21, 14, 22, 15],
            [21, 14, 23, 15],
            # NOTE(review): the only case whose anchor zone extends past
            # position 11 — confirm (0, 15) is intentional.
            (0, 15),
            'stable_contradiction_tree',
        ),
        (
            'contradiction_direct_conflict',
            'ASSUME_NOT root later drifts into direct-proof style descendants.',
            [21, 14, 22, 15],
            [24, 14, 15, 15],
            (0, 11),
            'proof_mode_flip',
        ),
        (
            'const_vs_loop_conflict',
            'Constant-time root later drifts into loop/scan descendants.',
            [31, 32, 33, 15],
            [34, 35, 35, 15],
            (0, 11),
            'complexity_flip',
        ),
        (
            'induction_stable',
            'INDUCT root followed by repeated base/step/successor structure.',
            [41, 42, 43, 44],
            [41, 42, 43, 44],
            (0, 11),
            'stable_induction_tree',
        ),
        (
            'induction_example_conflict',
            'INDUCT root later collapses into repeated finite examples instead of inductive structure.',
            [41, 42, 43, 44],
            [45, 45, 45, 15],
            (0, 11),
            'induction_to_examples',
        ),
        (
            'epsilon_delta_stable',
            'EPS root followed by DELTA/BOUND descendants that remain formal.',
            [51, 52, 53, 15],
            [51, 52, 53, 15],
            (0, 11),
            'stable_formal_limit_tree',
        ),
        (
            'epsilon_close_conflict',
            'EPS root later drifts into intuitive closeness language instead of formal bound structure.',
            [51, 52, 53, 15],
            [54, 54, 54, 15],
            (0, 11),
            'formal_to_intuitive_drift',
        ),
    )

    cases: list[AnchorProbeCase] = []
    for name, description, opening, closing, zone, failure_mode in specs:
        # 3 x 4 opening tokens + 3 x 4 closing tokens = 24 tokens.
        ids = torch.tensor(opening * 3 + closing * 3, dtype=torch.long)
        cases.append(
            AnchorProbeCase(
                name=name,
                description=description,
                input_ids=ids,
                target_ids=_make_targets(ids),
                expected_anchor_zone=zone,
                expected_failure_mode=failure_mode,
            )
        )
    return cases