File size: 2,599 Bytes
42c0d23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env bash
# Post-processing pipeline: run after all Camelyon17 experiments complete
# Usage: bash scripts/post_process.sh
#
# Runs three steps in order: figure generation (experiments.plot_results),
# M1 analysis (experiments.mechinterp_m1), and a summary table built from
# experiments/runs/*camelyon*/results/summary.json.

# Strict mode: exit on any failing command (-e), treat unset variables as
# errors (-u), and make a pipeline fail if any of its stages fails (pipefail).
set -euo pipefail

# ROOT is the directory one level above the directory containing this script.
# Changing into it makes the relative paths used by the rest of the script
# (scripts/, paper_figures, data/wilds, experiments/runs) independent of the
# caller's working directory.
ROOT="$(cd "$(dirname "$0")/.." && pwd)"
cd "${ROOT}"
# NOTE(review): scripts/lib/env.sh is not visible from this file. It presumably
# defines resolve_python, which presumably sets PYTHON (the interpreter invoked
# by every later step) -- confirm against env.sh.
source scripts/lib/env.sh
resolve_python

# Opening banner for the pipeline.
echo "════════════════════════════════════════════════════════════"
echo "  CausalGrok Post-Processing Pipeline"
echo "════════════════════════════════════════════════════════════"

# Step 1: Generate all figures from run histories
# Runs the experiments.plot_results module with --save_dir paper_figures.
# Because the script runs under `set -e`, a non-zero exit from the module
# stops the pipeline here and the success line below is never printed.
# PYTHON is not assigned in this file; it is expected to be provided by the
# earlier resolve_python call.
echo ""
echo "[1/3] Generating paper figures from run histories..."
"${PYTHON}" -m experiments.plot_results --save_dir paper_figures
# The check mark was stored as mojibake ("βœ“"); restored to the intended U+2713.
echo "  ✓ Figures saved to paper_figures/"

# Step 2: Run M1 mechanistic interpretability on all runs
# Invokes the experiments.mechinterp_m1 module with --all_runs,
# --data_root data/wilds and --latest_only. Under `set -e`, a non-zero exit
# aborts the script before the success line is printed.
# PYTHON is not assigned in this file; it is expected to be provided by the
# earlier resolve_python call.
echo ""
echo "[2/3] Running M1 analysis (layer-wise probing)..."
"${PYTHON}" -m experiments.mechinterp_m1 \
    --all_runs \
    --data_root data/wilds \
    --latest_only
# The check mark was stored as mojibake ("βœ“"); restored to the intended U+2713.
echo "  ✓ M1 analysis complete"

# Step 3: Print summary table
# Feeds an inline Python program to the interpreter on stdin. The heredoc
# delimiter is quoted ('EOF'), so the shell performs no expansion inside it.
# The program reads every experiments/runs/*camelyon*/results/summary.json
# (relative to the current directory), prints one table row per readable
# summary on stdout, and finally names the run with the highest best_ood.
# Unreadable or malformed summaries are reported on stderr and skipped.
echo ""
echo "[3/3] Final results summary:"
echo ""
"${PYTHON}" << 'EOF'
import glob
import json
import sys
from pathlib import Path

SUMMARY_GLOB = "experiments/runs/*camelyon*/results/summary.json"

print(f"{'Run ID':<50} | {'OOD Best':>8} | {'Improvement':>11} | {'Grokking Epoch':>15}")
print("-" * 100)

results = []
for f in sorted(glob.glob(SUMMARY_GLOB)):
    try:
        # Context manager so the file handle is closed promptly.
        with open(f) as fh:
            s = json.load(fh)

        # Run id is the directory two levels above summary.json, truncated
        # so it fits the 50-character column.
        rid = Path(f).parent.parent.name[:48]

        # `or 0` treats a key that is present but null the same as a missing
        # key; float() rejects non-numeric values here instead of letting
        # them break the max() call after the loop.
        best_ood = float(s.get("best_ood") or 0)
        impr = float(s.get("ood_improvement") or 0)

        grok_ep = s.get("grokking_epoch")
        if grok_ep is None:
            grok_ep = -1
        # Non-positive epochs are shown as a dash placeholder.
        ep_str = str(int(grok_ep)) if grok_ep > 0 else "—"

        # Improvement column is 11 wide to line up with the header.
        print(f"{rid:<50} | {best_ood:>8.3f} | {impr:>+11.3f} | {ep_str:>15}")
    except (OSError, ValueError, TypeError, AttributeError) as e:
        # Best-effort: one bad summary must not abort the table, but it should
        # be visible. ValueError also covers json.JSONDecodeError and the
        # Unicode encode/decode errors.
        print(f"warning: skipping {f}: {e}", file=sys.stderr)
        continue

    # Only rows that were successfully printed take part in the ranking.
    results.append((rid, best_ood, impr, grok_ep))

print("")
if results:
    # max() returns the first maximal element, so ties go to the earliest
    # path in sorted order.
    best_rid, best_ood, _, _ = max(results, key=lambda row: row[1])
    print(f"Best OOD performance: {best_rid} ({best_ood:.3f})")
else:
    print(f"warning: no readable summaries matched {SUMMARY_GLOB}", file=sys.stderr)
EOF

# Closing banner: a blank line, then the completion message framed by two
# identical horizontal rules, emitted with a single printf call.
rule="════════════════════════════════════════════════════════════"
printf '\n%s\n  Post-processing complete!\n%s\n' "${rule}" "${rule}"