#!/usr/bin/env python3
"""
SBERT-Only Ablation Baseline.

This script answers the most important question about Tensegrity:
"Does the cognitive layer add value above SBERT-alone?"

It runs the same benchmark tasks but uses ONLY SBERT cosine similarity
to score choices — no NGC, no causal arena, no Hopfield memory, no
belief updates, no falsification. Just:

    score(choice_i) = cosine_sim(sbert(prompt), sbert(prompt + choice_i))

This is the honest baseline the cognitive layer must beat. If the
cognitive layer's Δ over SBERT-alone is positive, the manifold is
doing real work. If it's zero, the manifold is expensive SBERT.

Usage:
    python scripts/ablation_sbert_only.py --max-samples 100
    python scripts/ablation_sbert_only.py --tasks copa,boolq,sciq
"""
import sys
import time
import json
import argparse
import hashlib

import numpy as np


def main():
    parser = argparse.ArgumentParser(description="SBERT-only ablation baseline")
    parser.add_argument("--tasks", default=None, help="Comma-separated task names")
    parser.add_argument("--max-samples", type=int, default=None, help="Max samples per task")
    parser.add_argument("--sbert-model", default="all-MiniLM-L6-v2", help="SBERT model name")
    parser.add_argument("--output", default=None, help="Save JSON results to file")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

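    # Imported inside main() so argument parsing (and --help) works even
    # when the tensegrity package isn't importable.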
    from tensegrity.bench.tasks import TASK_REGISTRY, load_task_samples

    # Load SBERT
    try:
        from sentence_transformers import SentenceTransformer
        sbert = SentenceTransformer(args.sbert_model)
        print(f"Loaded SBERT: {args.sbert_model}")
    except Exception as e:
        print(f"FATAL: Could not load SBERT: {e}")
        sys.exit(1)

    tasks = args.tasks.split(",") if args.tasks else list(TASK_REGISTRY.keys())

    print(f"\n{'█' * 60}")
    print(f"  SBERT-ONLY ABLATION BASELINE")
    print(f"  Model: {args.sbert_model}")
    print(f"  Tasks: {len(tasks)}")
    print(f"  N/task: {args.max_samples or 'all'}")
    print(f"{'█' * 60}")

    t_start = time.time()
    all_results = []
    total_correct_sbert = 0
    total_correct_random = 0
    total_n = 0

    for task_name in tasks:
        config = TASK_REGISTRY[task_name]
        samples = load_task_samples(task_name, args.max_samples)
        print(f"\n  ▸ {task_name}: {config.description} ({len(samples)} samples)")

        task_correct_sbert = 0
        task_correct_random = 0
        task_n = len(samples)

        for sample in samples:
            n = len(sample.choices)
            if n == 0:
                continue

            # SBERT-only scoring: cosine(prompt, prompt+choice). Scoring the
            # concatenation rather than the bare choice keeps the shared prompt
            # prefix in both embeddings, so the ranking is driven by how
            # naturally each choice continues the prompt.
            texts = [sample.prompt] + [f"{sample.prompt} {c}" for c in sample.choices]
            embs = sbert.encode(texts, show_progress_bar=False)
            pe = embs[0]                    # prompt embedding
            pn = np.linalg.norm(pe)
            choice_embs = embs[1:]          # one row per choice
            cns = np.linalg.norm(choice_embs, axis=1)
            scores = np.zeros(n)
            if pn > 1e-8:
                valid = cns > 1e-8          # guard against degenerate embeddings
                scores[valid] = choice_embs[valid] @ pe / (cns[valid] * pn)

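            # np.argmax breaks ties in favor of the lowest-index choice.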
            sbert_pred = int(np.argmax(scores))
            if sbert_pred == sample.gold:
                task_correct_sbert += 1

            # Random baseline for comparison. The per-sample seed is derived
            # from the sample id, so it is stable across runs and iteration
            # order; --seed is folded in so the baseline can be re-rolled.
            seed_bytes = hashlib.sha256(sample.id.encode("utf-8")).digest()
            sample_seed = (int.from_bytes(seed_bytes[:8], "big", signed=False) ^ args.seed) % (2**31)
            rng = np.random.RandomState(sample_seed)
            random_pred = int(np.argmax(rng.randn(n)))
            if random_pred == sample.gold:
                task_correct_random += 1

        sbert_acc = task_correct_sbert / max(task_n, 1)
        random_acc = task_correct_random / max(task_n, 1)
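        # Chance = uniform-guess accuracy implied by the task's nominal choice count.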
        chance = 1.0 / config.n_choices if config.n_choices > 0 else 0.25

        total_correct_sbert += task_correct_sbert
        total_correct_random += task_correct_random
        total_n += task_n

        result = {
            "task": task_name, "domain": config.domain, "n": task_n,
            "sbert_accuracy": round(sbert_acc, 4),
            "random_accuracy": round(random_acc, 4),
            "chance": round(chance, 4),
            "sbert_over_chance": round(sbert_acc - chance, 4),
        }
        all_results.append(result)
        print(f"    SBERT={sbert_acc:.1%}  random={random_acc:.1%}  "
              f"chance={chance:.1%}  SBERT-chance={sbert_acc-chance:+.1%}")

    total_time = time.time() - t_start
    overall_sbert = total_correct_sbert / max(total_n, 1)
    overall_random = total_correct_random / max(total_n, 1)

    print(f"\n{'═' * 75}")
    print(f"  SBERT-only overall: {overall_sbert:.1%}  (random: {overall_random:.1%})")
    print(f"  Total: {total_n} samples, {total_time:.1f}s")
    print(f"{'═' * 75}")

    # Print comparison table
    print(f"\n{'Task':<22} {'N':>5} {'SBERT':>7} {'Random':>7} {'Chance':>7} {'SBERT-Chance':>12}")
    print("─" * 65)
    for r in sorted(all_results, key=lambda x: x["sbert_over_chance"], reverse=True):
        print(f"{r['task']:<22} {r['n']:>5} {r['sbert_accuracy']:>6.1%} "
              f"{r['random_accuracy']:>6.1%} {r['chance']:>6.1%} "
              f"{r['sbert_over_chance']:>+11.1%}")
    print("─" * 65)
    print(f"{'OVERALL':<22} {total_n:>5} {overall_sbert:>6.1%} {overall_random:>6.1%}")

    output = {
        "mode": "sbert_only_ablation",
        "sbert_model": args.sbert_model,
        "overall_sbert_accuracy": round(overall_sbert, 4),
        "overall_random_accuracy": round(overall_random, 4),
        "total_samples": total_n,
        "wall_time_s": round(total_time, 1),
        "tasks": all_results,
    }

    if args.output:
        with open(args.output, "w") as f:
            json.dump(output, f, indent=2)
        print(f"\nResults saved to {args.output}")
    else:
        print(f"\n{json.dumps(output, indent=2)}")


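# Hedged, illustrative helper (not called by main()): one way the docstring's
# Δ over SBERT-alone could be computed from two result files. The cognitive-
# layer schema assumed here (a "tasks" list whose entries carry "task" and
# "accuracy" keys) is a guess; adjust the keys to the full benchmark's real
# output.
def delta_vs_cognitive(sbert_json_path: str, cognitive_json_path: str) -> dict:
    """Per-task Δ = cognitive accuracy minus SBERT-only accuracy."""
    with open(sbert_json_path) as f:
        sbert_acc = {t["task"]: t["sbert_accuracy"] for t in json.load(f)["tasks"]}
    with open(cognitive_json_path) as f:
        cog_acc = {t["task"]: t["accuracy"] for t in json.load(f)["tasks"]}  # assumed key
    return {
        task: round(cog_acc[task] - acc, 4)
        for task, acc in sbert_acc.items()
        if task in cog_acc
    }

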
if __name__ == "__main__":
    main()