|
|
""" |
|
|
Statistical significance tests for speculative decoding experiment. |
|
|
|
|
|
Performs chi-square, ANOVA, and t-tests to validate documented findings. |
|
|
|
|
|
Author: Claude Code |
|
|
Date: 2025-11-30 |
|
|
""" |
|
|
|
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
from scipy import stats |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Tuple |
|
|
|
|
|
|
|
|
# Repository-relative locations: inputs live in <repo>/data, outputs in
# <repo>/results/statistics (paths resolved relative to this file).
DATA_DIR = Path(__file__).parent.parent / "data"

RESULTS_DIR = Path(__file__).parent.parent / "results" / "statistics"

# NOTE: side effect at import time — the output directory is created as soon
# as this module is imported, not only when main() runs.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
def chi_square_domain_independence(df: pd.DataFrame) -> Dict: |
|
|
"""Test if rejection rate is independent of domain.""" |
|
|
|
|
|
print("\n" + "=" * 60) |
|
|
print("Chi-Square Test: Domain Independence") |
|
|
print("=" * 60) |
|
|
|
|
|
|
|
|
contingency = pd.crosstab(df['domain'], df['is_rejected']) |
|
|
|
|
|
|
|
|
chi2, p_value, dof, expected = stats.chi2_contingency(contingency) |
|
|
|
|
|
print(f"\nContingency Table:") |
|
|
print(contingency) |
|
|
print(f"\nChi-square statistic: {chi2:.2f}") |
|
|
print(f"Degrees of freedom: {dof}") |
|
|
print(f"p-value: {p_value:.2e}") |
|
|
|
|
|
if p_value < 0.001: |
|
|
print("✅ Result: HIGHLY SIGNIFICANT (p < 0.001)") |
|
|
print(" Rejection rate is strongly domain-dependent") |
|
|
else: |
|
|
print("⚠️ Result: Not significant") |
|
|
|
|
|
return { |
|
|
'test': 'chi_square_domain', |
|
|
'chi2': chi2, |
|
|
'dof': dof, |
|
|
'p_value': p_value, |
|
|
'significant': p_value < 0.05 |
|
|
} |
|
|
|
|
|
|
|
|
def anova_position_effect(df: pd.DataFrame) -> Dict:
    """Test if rejection rate varies by token position.

    Bins token positions into early (1-20), mid (21-100), and late (>100)
    groups and runs a one-way ANOVA on 'is_rejected' across the bins.
    Prints per-bin rejection rates and the test statistics, then returns a
    summary dict with keys 'test', 'f_statistic', 'p_value', 'significant'.

    Fix vs. the original: the position bins are computed into a local
    Series instead of being written into the caller's DataFrame as a new
    'position_bin' column (the original mutated its argument as a side
    effect).
    """
    print("\n" + "=" * 60)
    print("ANOVA: Position Effect")
    print("=" * 60)

    # Local binning — does NOT add a column to the caller's DataFrame.
    # pd.cut intervals are right-closed: (0, 20], (20, 100], (100, inf).
    position_bin = pd.cut(
        df['token_position'],
        bins=[0, 20, 100, np.inf],
        labels=['early', 'mid', 'late']
    )

    groups = []
    for position in ['early', 'mid', 'late']:
        group_data = df[position_bin == position]['is_rejected']
        groups.append(group_data)
        print(f"{position:8s}: {group_data.mean():.3f} (n={len(group_data):,})")

    f_stat, p_value = stats.f_oneway(*groups)

    print(f"\nF-statistic: {f_stat:.2f}")
    print(f"p-value: {p_value:.2e}")

    if p_value < 0.001:
        print("✅ Result: HIGHLY SIGNIFICANT (p < 0.001)")
        print(" Position significantly affects rejection rate")
    else:
        print("⚠️ Result: Not significant")

    return {
        'test': 'anova_position',
        'f_statistic': f_stat,
        'p_value': p_value,
        'significant': p_value < 0.05
    }
|
|
|
|
|
|
|
|
def ttest_frequency_effect(df: pd.DataFrame) -> Dict:
    """Test if rare tokens are rejected more than common tokens.

    Splits tokens into rare (frequency < 0.01%) and common (frequency > 1%)
    groups and runs an independent two-sample t-test on 'is_rejected'.
    Prints group means and test statistics, then returns a summary dict
    with keys 'test', 't_statistic', 'p_value', 'significant'.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("T-Test: Frequency Effect")
    print(banner)

    # Tokens with frequency strictly between 0.01% and 1% are excluded.
    freq = df['token_frequency_pct']
    rare = df.loc[freq < 0.01, 'is_rejected']
    common = df.loc[freq > 1.0, 'is_rejected']

    print(f"Rare tokens (<0.01%): {rare.mean():.3f} (n={len(rare):,})")
    print(f"Common tokens (>1%): {common.mean():.3f} (n={len(common):,})")
    print(f"Difference: {rare.mean() - common.mean():.3f}")

    t_stat, p_value = stats.ttest_ind(rare, common)

    print(f"\nT-statistic: {t_stat:.3f}")
    print(f"p-value: {p_value:.3f}")

    if p_value < 0.05:
        print("✅ Result: SIGNIFICANT (p < 0.05)")
        print(" Frequency effect exists but is small")
    else:
        print("⚠️ Result: Not significant")

    return {
        'test': 'ttest_frequency',
        't_statistic': t_stat,
        'p_value': p_value,
        'significant': p_value < 0.05,
    }
|
|
|
|
|
|
|
|
def ablation_mask_comparisons(df: pd.DataFrame) -> List[Dict]:
    """Pairwise t-tests comparing each mask to causal baseline.

    For each domain (code, math, translation), compares the 'is_accepted'
    rate of every non-causal mask type against the causal baseline with an
    independent two-sample t-test. Masks with no data in a domain are
    skipped. Returns one summary dict per comparison performed.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("T-Tests: Mask Comparisons vs Causal Baseline")
    print(banner)

    comparisons: List[Dict] = []

    for domain in ['code', 'math', 'translation']:
        print(f"\n--- {domain.upper()} ---")

        # Filter once per domain, then slice by mask type.
        in_domain = df[df['domain'] == domain]
        causal = in_domain[in_domain['mask_type'] == 'causal']['is_accepted']

        for mask in ['tidar', 'bidirectional', 'windowed', 'strided']:
            mask_data = in_domain[in_domain['mask_type'] == mask]['is_accepted']
            if mask_data.empty:
                continue

            t_stat, p_value = stats.ttest_ind(mask_data, causal)

            sig_marker = "✅" if p_value < 0.05 else " "
            better_worse = "better" if mask_data.mean() > causal.mean() else "worse"
            print(f"{sig_marker} {mask:15s}: t={t_stat:6.3f}, p={p_value:.3f} ({better_worse})")

            comparisons.append({
                'domain': domain,
                'mask': mask,
                'baseline': 'causal',
                't_statistic': t_stat,
                'p_value': p_value,
                'significant': p_value < 0.05,
            })

    return comparisons
|
|
|
|
|
|
|
|
def main():
    """Run all statistical tests.

    Loads the cross-domain and ablation CSVs from DATA_DIR, runs every
    significance test, writes the combined results to
    RESULTS_DIR/significance_tests.csv, and prints a summary.
    """
    banner = "=" * 60
    print(banner)
    print("Statistical Significance Testing")
    print(banner)

    print("\nLoading data...")
    cross_domain_df = pd.read_csv(DATA_DIR / "phase1_cross_domain.csv")
    ablation_df = pd.read_csv(DATA_DIR / "phase3_ablation.csv")
    print(f"✅ Cross-domain: {len(cross_domain_df):,} tokens")
    print(f"✅ Ablation: {len(ablation_df):,} tokens")

    # The three cross-domain tests each return one summary dict; the
    # ablation comparison returns a list of dicts.
    all_results = [
        chi_square_domain_independence(cross_domain_df),
        anova_position_effect(cross_domain_df),
        ttest_frequency_effect(cross_domain_df),
    ]
    all_results.extend(ablation_mask_comparisons(ablation_df))

    output_path = RESULTS_DIR / "significance_tests.csv"
    pd.DataFrame(all_results).to_csv(output_path, index=False)

    print("\n" + banner)
    print(f"✅ All tests complete! Results saved to:")
    print(f" {output_path}")
    print(banner)

    print("\n=== Summary ===")
    significant_count = sum(1 for r in all_results if r.get('significant', False))
    print(f"Total tests: {len(all_results)}")
    print(f"Significant (p < 0.05): {significant_count}")
    print(f"Not significant: {len(all_results) - significant_count}")
|
|
|
|
|
|
|
|
# Script entry point: run the full significance-testing suite.
if __name__ == "__main__":
    main()
|
|
|