meridian-mcp monorepo sync committed on
Commit
e84fa2c
·
1 Parent(s): 066683d

Sync from meridian-mcp@bf1547266b87d49e4fef560bf851bd07c585f5ac

Browse files
.github/workflows/finish-line.yml ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Manually dispatched workflow that runs the B1 and B2 simulations, expands the
# relevance set, runs the A3-Real Fock sweep, and uploads all reports.
name: photon-route finish-line (B1 + B2 + A3-Real)

on:
  workflow_dispatch:
    inputs:
      seeds:
        description: "Sweep seeds for A3-Real (default: 1 2 3 4 5)"
        required: false
        default: "1 2 3 4 5"
      steps:
        description: "Training steps for A3-Real (default: 200)"
        required: false
        default: "200"

permissions:
  contents: read

# Several steps pipe their output through `tee`. Without an explicit shell,
# GitHub Actions runs `bash -e {0}` (no pipefail), so a crashing Python
# process would be masked by tee's zero exit status. Declaring `shell: bash`
# switches every run step to `bash --noprofile --norc -eo pipefail {0}`.
defaults:
  run:
    shell: bash

jobs:
  finish:
    runs-on: ubuntu-latest
    timeout-minutes: 90
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install numpy scipy thewalrus torch sentence-transformers

      - name: B1 — g^(1) coherence-time sim
        run: |
          mkdir -p reports
          PYTHONPATH=src:. python -m space.sim_b1_g1_coherence | tee reports/b1.txt

      - name: B2 — g^(2) classifier sim
        run: |
          PYTHONPATH=src:. python -m space.sim_b2_g2_classifier | tee reports/b2.txt

      - name: Expand relevance (title queries)
        run: |
          PYTHONPATH=src:. python -m eval.expand_titles

      - name: A3-Real — Fock-basis non-Gaussian sweep
        # Inputs go through the environment instead of being pasted into the
        # script text, so their content can never be parsed as shell syntax.
        env:
          SEEDS: ${{ github.event.inputs.seeds }}
          STEPS: ${{ github.event.inputs.steps }}
        run: |
          # SEEDS is left unquoted on purpose: it is a space-separated list
          # that must word-split into one argument per seed.
          PYTHONPATH=src:. python -m space.run_sweep_fock \
            --seeds $SEEDS \
            --steps "$STEPS" \
            --herald-ns 1 0 \
            --out-csv sweep_fock_results.csv \
            --log-dir sweep_fock_logs \
            | tee reports/a3.txt

      - name: Upload artifacts
        # Runs even after a failed step so partial reports are still collected.
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: photon-finish-line
          path: |
            reports/
            sweep_fock_results.csv
            sweep_fock_logs/
            eval/relevance_expanded.json
          if-no-files-found: warn
          retention-days: 30

      - name: Print summaries
        if: always()
        run: |
          # `|| true` keeps this best-effort: a report that was never written
          # must not fail the summary step.
          echo "=========================================="
          echo "B1 — g^(1) coherence sim"
          echo "=========================================="
          tail -25 reports/b1.txt || true
          echo
          echo "=========================================="
          echo "B2 — g^(2) classifier sim"
          echo "=========================================="
          tail -25 reports/b2.txt || true
          echo
          echo "=========================================="
          echo "A3-Real — Fock sweep CSV"
          echo "=========================================="
          cat sweep_fock_results.csv || true
.github/workflows/pages.yml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Publishes the contents of ./pages to GitHub Pages on pushes to main that
# touch the site or this workflow, and on manual dispatch.
name: Deploy photon-route UI to GitHub Pages

on:
  push:
    branches:
      - main
    paths:
      - "pages/**"
      - ".github/workflows/pages.yml"
  workflow_dispatch: {}

permissions:
  contents: read
  pages: write
  id-token: write

# One deployment at a time; a queued run waits instead of cancelling the
# deployment already in progress.
concurrency:
  group: pages
  cancel-in-progress: false

jobs:
  deploy:
    runs-on: ubuntu-latest
    environment:
      name: github-pages
      # Filled in from the output of the step with id `deployment` below.
      url: ${{ steps.deployment.outputs.page_url }}
    steps:
      - uses: actions/checkout@v4
      - uses: actions/configure-pages@v5
      - uses: actions/upload-pages-artifact@v3
        with:
          path: ./pages
      - id: deployment
        uses: actions/deploy-pages@v4
.github/workflows/photon-sweep.yml ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Manually dispatched workflow that expands the relevance set, runs the
# photon-route sweep, and uploads the CSV, logs and expanded relevance file.
name: photon-route sweep (5 splits × squeeze ablation)

on:
  workflow_dispatch:
    inputs:
      seeds:
        description: "Space-separated seeds (default: 1 2 3 4 5)"
        required: false
        default: "1 2 3 4 5"
      steps:
        description: "Training steps per run (default: 200)"
        required: false
        default: "200"

permissions:
  contents: read

jobs:
  sweep:
    runs-on: ubuntu-latest
    timeout-minutes: 60
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install numpy scipy thewalrus torch sentence-transformers

      - name: Expand relevance (title queries)
        run: |
          PYTHONPATH=src:. python -m eval.expand_titles

      - name: Run sweep
        # Inputs go through the environment instead of being pasted into the
        # script text, so their content can never be parsed as shell syntax.
        env:
          SEEDS: ${{ github.event.inputs.seeds }}
          STEPS: ${{ github.event.inputs.steps }}
        run: |
          # SEEDS is left unquoted on purpose: it is a space-separated list
          # that must word-split into one argument per seed.
          PYTHONPATH=src:. python -m space.run_sweep \
            --seeds $SEEDS \
            --steps "$STEPS" \
            --out-csv sweep_results.csv \
            --log-dir sweep_logs

      - name: Upload sweep results
        uses: actions/upload-artifact@v4
        with:
          name: photon-sweep-results
          path: |
            sweep_results.csv
            sweep_logs/
            eval/relevance_expanded.json
          if-no-files-found: error
          retention-days: 30

      - name: Print CSV summary
        # Best-effort: runs after failures too and tolerates a missing CSV.
        if: always()
        run: |
          echo "=== sweep_results.csv ==="
          cat sweep_results.csv || true
eval/expand_titles.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Expand the eval relevance set with title-as-query pairs.
2
+
3
+ For each arXiv ID in corpus_ids.json, fetch og:title from the abstract
4
+ page (same scrape pattern as eval.fetch but a different meta tag), and
5
+ emit one query whose only relevant document is that paper.
6
+
7
+ Output: eval/relevance_expanded.json — original 6 multi-positive queries
8
+ plus 20 single-positive title queries = 26 total. Increases trainer
9
+ signal 4× without any human labeling.
10
+
11
+ This script does NOT touch the existing relevance.json. It writes a
12
+ sibling file the trainer / eval harness opt into via --relevance.
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import argparse
17
+ import html
18
+ import json
19
+ import re
20
+ import time
21
+ import urllib.request
22
+ from pathlib import Path
23
+
24
+ ROOT = Path(__file__).resolve().parent.parent
25
+ ARXIV_ABS = "https://arxiv.org/abs/"
26
+ _OG_TITLE = re.compile(
27
+ r'<meta\s+(?:property|name)="og:title"\s+content="([^"]*)"',
28
+ re.IGNORECASE,
29
+ )
30
+ _BROWSER_HEADERS = {
31
+ "User-Agent": (
32
+ "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 "
33
+ "(KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
34
+ ),
35
+ "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
36
+ "Accept-Language": "en-US,en;q=0.9",
37
+ }
38
+
39
+
40
+ def _normalize(text: str) -> str:
41
+ return re.sub(r"\s+", " ", text).strip()
42
+
43
+
44
+ def _strip_arxiv_prefix(title: str) -> str:
45
+ """og:title comes back as '[2304.12717] Quantum natural language ...';
46
+ strip the '[id]' prefix so the query is just the paper title."""
47
+ return re.sub(r"^\s*\[[^\]]+\]\s*", "", title).strip()
48
+
49
+
50
def fetch_title(arxiv_id: str, timeout: float = 30.0) -> str:
    """Fetch the paper title for *arxiv_id* from its arXiv abstract page.

    The page is requested with browser-like headers, the ``og:title`` meta
    tag is extracted, HTML entities are unescaped, whitespace is collapsed
    and the leading ``[id]`` tag is removed.

    Args:
        arxiv_id: Identifier appended to ``ARXIV_ABS`` to form the URL.
        timeout: Socket timeout in seconds passed to ``urlopen``.

    Returns:
        The cleaned, non-empty title.

    Raises:
        RuntimeError: If the page has no ``og:title`` tag, or the tag is
            empty once the ``[id]`` prefix and whitespace are removed.
        urllib.error.URLError: Propagated from ``urlopen`` on network or
            HTTP failure.
    """
    url = ARXIV_ABS + arxiv_id
    req = urllib.request.Request(url, headers=_BROWSER_HEADERS)
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        body = resp.read().decode("utf-8", errors="replace")

    match = _OG_TITLE.search(body)
    if not match:
        raise RuntimeError(f"og:title not found for {arxiv_id}")

    title = _strip_arxiv_prefix(_normalize(html.unescape(match.group(1))))
    # The pattern accepts content="" and a title made only of the "[id]" tag
    # strips down to nothing; an empty string is useless as a query, so
    # treat it like a missing tag instead of returning it.
    if not title:
        raise RuntimeError(f"og:title is empty for {arxiv_id}")
    return title
59
+
60
+
61
def main():
    """Build the expanded relevance file from the corpus titles.

    Reads the corpus ID list and the base relevance file, obtains one title
    per ID (from the on-disk cache when present, otherwise via
    ``fetch_title``), and writes a JSON payload holding the base queries
    tagged ``"topical"`` followed by one ``"title"`` query per ID whose only
    relevant document is that ID.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--corpus", type=Path, default=ROOT / "eval" / "corpus_ids.json")
    ap.add_argument("--in-relevance", type=Path, default=ROOT / "eval" / "relevance.json")
    ap.add_argument("--out", type=Path, default=ROOT / "eval" / "relevance_expanded.json")
    ap.add_argument("--cache", type=Path, default=Path.home() / ".cache" / "photon-route" / "titles")
    ap.add_argument("--sleep", type=float, default=0.5)
    args = ap.parse_args()

    args.cache.mkdir(parents=True, exist_ok=True)
    ids = json.loads(args.corpus.read_text("utf-8"))["ids"]
    base = json.loads(args.in_relevance.read_text("utf-8"))

    total = len(ids)
    titles = {}
    for position, arxiv_id in enumerate(ids, start=1):
        # Old-style arXiv IDs (e.g. "quant-ph/0512345") contain "/", which
        # would point into a subdirectory that does not exist; flatten it so
        # the cache file always lives directly inside args.cache.
        cache_name = arxiv_id.replace("/", "_") + ".title"
        cache_path = args.cache / cache_name

        cached = cache_path.read_text("utf-8").strip() if cache_path.exists() else ""
        if cached:
            titles[arxiv_id] = cached
            continue
        # A missing or empty cache entry is a miss and is fetched again.

        title = fetch_title(arxiv_id)
        cache_path.write_text(title, encoding="utf-8")
        titles[arxiv_id] = title
        print(f"[{position:2d}/{total}] {arxiv_id}: {title[:60]}")
        # Pause between network fetches, but not after the last ID.
        if position < total:
            time.sleep(args.sleep)

    title_queries = [
        {"query": titles[arxiv_id], "relevant_ids": [arxiv_id], "kind": "title"}
        for arxiv_id in ids
    ]
    topical_queries = [{**query, "kind": "topical"} for query in base["queries"]]
    # Every other top-level key of the base file is carried over unchanged.
    out_payload = {**base, "queries": [*topical_queries, *title_queries]}

    args.out.write_text(json.dumps(out_payload, indent=2) + "\n", encoding="utf-8")
    print(f"\nwrote {len(out_payload['queries'])} queries → {args.out}")
100
+
101
+
102
+ if __name__ == "__main__":
103
+ main()
eval/run_bm25.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """BM25 baseline against the same eval set photon-route uses.
2
+
3
+ Drops in next to eval.run so apples-to-apples on Recall@k / nDCG@k.
4
+ Pure-stdlib BM25 — no external IR library — to keep the dependency
5
+ surface identical to the rest of eval/.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ import json
11
+ import math
12
+ from collections import Counter
13
+ from pathlib import Path
14
+
15
+ import numpy as np
16
+
17
+ from eval.fetch import fetch_all, verify_against_manifest
18
+
19
+
20
class BM25:
    """Whitespace-tokenised, lower-cased BM25 scorer over a fixed corpus.

    Attributes:
        k1, b: BM25 term-saturation and length-normalisation parameters.
        toks: Token list for each document, in corpus order.
        N: Number of documents.
        avgdl: Mean document length in tokens (0.0 for an empty corpus).
        tfs: Per-document term-frequency counters, built once at init.
        idf: Inverse document frequency for every term seen in the corpus,
            ``log(1 + (N - n + 0.5) / (n + 0.5))``.
    """

    def __init__(self, docs: list[str], k1: float = 1.5, b: float = 0.75):
        """Index *docs*; an empty list yields an index with no terms."""
        self.k1, self.b = k1, b
        self.toks = [d.lower().split() for d in docs]
        self.N = len(docs)
        # Guard the empty corpus: dividing by N == 0 would raise here.
        self.avgdl = sum(len(t) for t in self.toks) / self.N if self.N else 0.0
        # Term frequencies never change after construction, so count them
        # once instead of on every score() call.
        self.tfs = [Counter(t) for t in self.toks]
        df: Counter = Counter()
        for tf in self.tfs:
            # Iterating a Counter yields each distinct term once.
            df.update(tf.keys())
        self.idf = {
            w: math.log(1 + (self.N - n + 0.5) / (n + 0.5)) for w, n in df.items()
        }

    def score(self, query: str, doc_index: int) -> float:
        """Return the BM25 score of document *doc_index* for *query*.

        Query terms absent from the corpus are ignored; a repeated query
        term contributes once per occurrence.

        Raises:
            IndexError: If *doc_index* is not a valid corpus position.
        """
        tf = self.tfs[doc_index]
        doc_len = len(self.toks[doc_index])
        s = 0.0
        for w in query.lower().split():
            if w not in self.idf:
                continue
            # Reaching this point means some document contains w, so the
            # corpus is non-empty and avgdl is strictly positive.
            f = tf[w]
            denom = f + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl)
            s += self.idf[w] * f * (self.k1 + 1) / max(denom, 1e-9)
        return s
45
+
46
+
47
def recall_at_k(ranked_ids: list[str], relevant: set[str], k: int) -> float:
    """Return the fraction of *relevant* IDs found in the first *k* ranked IDs.

    Returns NaN when *relevant* is empty, since recall is undefined there.
    """
    if not relevant:
        return float("nan")
    top_k = set(ranked_ids[:k])
    hits = top_k & relevant
    return len(hits) / len(relevant)
51
+
52
+
53
def ndcg_at_k(ranked_ids: list[str], relevant: set[str], k: int) -> float:
    """Return binary-relevance nDCG over the first *k* ranked IDs.

    Each relevant hit at 1-based rank r contributes 1 / log2(r + 1); the
    total is divided by the score of an ideal ranking that places
    min(k, len(relevant)) relevant documents first. Returns NaN when
    *relevant* is empty or when the ideal score is zero.
    """
    if not relevant:
        return float("nan")

    gains = [
        1.0 / math.log2(rank + 1)
        for rank, doc_id in enumerate(ranked_ids[:k], start=1)
        if doc_id in relevant
    ]
    dcg = sum(gains)

    ideal_hits = min(k, len(relevant))
    ideal = sum(1.0 / math.log2(rank + 1) for rank in range(1, ideal_hits + 1))
    if ideal > 0:
        return dcg / ideal
    return float("nan")
63
+
64
+
65
def main():
    """Run the BM25 baseline over the eval set and print Recall@k / nDCG@k.

    Loads the corpus IDs and queries, fetches the abstracts and checks them
    against the manifest, ranks every document for every query with BM25,
    then prints one line of metrics per query followed by the aggregate.

    Raises:
        SystemExit: If the fetched abstracts do not match the manifest.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--corpus", type=Path, default=Path(__file__).parent / "corpus_ids.json")
    ap.add_argument("--relevance", type=Path, default=Path(__file__).parent / "relevance.json")
    ap.add_argument("--manifest", type=Path, default=Path(__file__).parent / "manifest.json")
    ap.add_argument("--ks", type=int, nargs="+", default=[1, 3, 5, 10])
    args = ap.parse_args()

    ids = json.loads(args.corpus.read_text("utf-8"))["ids"]
    queries = json.loads(args.relevance.read_text("utf-8"))["queries"]
    abstracts = fetch_all(ids)
    bad = verify_against_manifest(abstracts, args.manifest)
    if bad:
        raise SystemExit(f"manifest mismatch: {list(bad)[:3]}")

    docs_in_order = [abstracts[i] for i in ids]
    bm25 = BM25(docs_in_order)

    per_query = []
    for q in queries:
        # sorted() is stable, so documents with equal scores keep corpus order.
        scored = sorted(
            ((bm25.score(q["query"], pos), doc_id) for pos, doc_id in enumerate(ids)),
            key=lambda pair: -pair[0],
        )
        ranked_ids = [doc_id for _, doc_id in scored]
        rel = set(q["relevant_ids"])
        row = {"query": q["query"], "ranked": ranked_ids[: max(args.ks)]}
        for k in args.ks:
            row[f"recall@{k}"] = recall_at_k(ranked_ids, rel, k)
            row[f"ndcg@{k}"] = ndcg_at_k(ranked_ids, rel, k)
        per_query.append(row)

    # The metric helpers return NaN for a query with no relevant IDs.
    # nanmean skips those rows so one unlabeled query cannot turn the whole
    # aggregate into NaN. Key order stays recall@k first, then ndcg@k.
    aggregate = {
        f"{metric}@{k}": float(np.nanmean([row[f"{metric}@{k}"] for row in per_query]))
        for metric in ("recall", "ndcg")
        for k in args.ks
    }

    print(f"backend=bm25 corpus={len(ids)} queries={len(queries)}")
    for row in per_query:
        cells = " ".join(
            f"{m}={row[m]:.3f}" for m in row if m.startswith(("recall", "ndcg"))
        )
        print(f" {row['query'][:48]:<48s} {cells}")
    print("aggregate: " + " ".join(f"{m}={aggregate[m]:.3f}" for m in aggregate))
110
+
111
+
112
+ if __name__ == "__main__":
113
+ main()
eval/run_sbert.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SBERT (all-MiniLM-L6-v2) baseline against the same eval set.
2
+
3
+ Mean-pooled 384-d sentence embedding, cosine similarity. Establishes the
4
+ modern dense-retrieval ceiling for the photon-route eval. Runs entirely
5
+ on CPU in a few seconds for this corpus size.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ import json
11
+ import math
12
+ from pathlib import Path
13
+
14
+ import numpy as np
15
+
16
+ from eval.fetch import fetch_all, verify_against_manifest
17
+
18
+
19
def recall_at_k(ranked_ids: list[str], relevant: set[str], k: int) -> float:
    """Return the share of *relevant* IDs that appear among the top *k* results.

    Recall is undefined for an empty *relevant* set, so NaN is returned.
    """
    if not relevant:
        return float("nan")
    retrieved = set(ranked_ids[:k])
    found = retrieved & relevant
    return len(found) / len(relevant)
23
+
24
+
25
def ndcg_at_k(ranked_ids: list[str], relevant: set[str], k: int) -> float:
    """Return nDCG@k with binary relevance.

    A relevant document at 1-based rank r adds 1 / log2(r + 1) to the DCG.
    The ideal DCG assumes min(k, len(relevant)) relevant documents ranked
    first. NaN is returned when *relevant* is empty or the ideal DCG is zero.
    """
    if not relevant:
        return float("nan")

    discounted = [
        1.0 / math.log2(position + 1)
        for position, candidate in enumerate(ranked_ids[:k], start=1)
        if candidate in relevant
    ]
    dcg = sum(discounted)

    best_possible = min(k, len(relevant))
    ideal = sum(1.0 / math.log2(position + 1) for position in range(1, best_possible + 1))
    if ideal > 0:
        return dcg / ideal
    return float("nan")
35
+
36
+
37
def main():
    """Run the SBERT baseline over the eval set and print Recall@k / nDCG@k.

    Loads the corpus IDs and queries, fetches the abstracts and checks them
    against the manifest, embeds documents and queries with the selected
    sentence-transformers model, ranks documents by dot product of the
    normalized embeddings, then prints per-query and aggregate metrics.

    Raises:
        SystemExit: If the fetched abstracts do not match the manifest.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="sentence-transformers/all-MiniLM-L6-v2")
    ap.add_argument("--corpus", type=Path, default=Path(__file__).parent / "corpus_ids.json")
    ap.add_argument("--relevance", type=Path, default=Path(__file__).parent / "relevance.json")
    ap.add_argument("--manifest", type=Path, default=Path(__file__).parent / "manifest.json")
    ap.add_argument("--ks", type=int, nargs="+", default=[1, 3, 5, 10])
    args = ap.parse_args()

    # Imported after argument parsing so `--help` exits before the package
    # has to be loaded.
    from sentence_transformers import SentenceTransformer

    ids = json.loads(args.corpus.read_text("utf-8"))["ids"]
    queries = json.loads(args.relevance.read_text("utf-8"))["queries"]
    abstracts = fetch_all(ids)
    bad = verify_against_manifest(abstracts, args.manifest)
    if bad:
        raise SystemExit(f"manifest mismatch: {list(bad)[:3]}")

    print(f"loading {args.model}...")
    model = SentenceTransformer(args.model)

    docs_in_order = [abstracts[i] for i in ids]
    doc_emb = model.encode(docs_in_order, normalize_embeddings=True, show_progress_bar=False)
    q_emb = model.encode(
        [q["query"] for q in queries], normalize_embeddings=True, show_progress_bar=False
    )

    per_query = []
    for q, q_vec in zip(queries, q_emb):
        sims = doc_emb @ q_vec  # cosine since both are normalized
        # A stable sort keeps corpus order among tied scores, so the ranking
        # is deterministic from run to run.
        order = np.argsort(-sims, kind="stable")
        ranked_ids = [ids[i] for i in order]
        rel = set(q["relevant_ids"])
        row = {"query": q["query"], "ranked": ranked_ids[: max(args.ks)]}
        for k in args.ks:
            row[f"recall@{k}"] = recall_at_k(ranked_ids, rel, k)
            row[f"ndcg@{k}"] = ndcg_at_k(ranked_ids, rel, k)
        per_query.append(row)

    # The metric helpers return NaN for a query with no relevant IDs.
    # nanmean skips those rows so one unlabeled query cannot turn the whole
    # aggregate into NaN. Key order stays recall@k first, then ndcg@k.
    aggregate = {
        f"{metric}@{k}": float(np.nanmean([row[f"{metric}@{k}"] for row in per_query]))
        for metric in ("recall", "ndcg")
        for k in args.ks
    }

    print(f"backend=sbert/{args.model.split('/')[-1]} corpus={len(ids)} queries={len(queries)}")
    for row in per_query:
        cells = " ".join(f"{m}={row[m]:.3f}" for m in row if m.startswith(("recall", "ndcg")))
        print(f" {row['query'][:48]:<48s} {cells}")
    print("aggregate: " + " ".join(f"{m}={aggregate[m]:.3f}" for m in aggregate))
84
+
85
+
86
+ if __name__ == "__main__":
87
+ main()
pages/CNAME ADDED
@@ -0,0 +1 @@
 
 
1
+ photon.ask-meridian.uk
worker/proxy.js → pages/index.html RENAMED
@@ -1,57 +1,4 @@
1
- // photon-route Cloudflare Worker
2
- // Routes photon.ask-meridian.uk → HF Space (luuow-photon-route),
3
- // edge-caches /rank, and serves an inline interactive UI at /.
4
- //
5
- // Endpoints (worker-local):
6
- // / interactive HTML page
7
- // /health liveness JSON
8
- // /api service banner JSON
9
- // Endpoints (proxied to HF Space):
10
- // /rank?q=&top_k= edge-cached 24 h
11
- // /version /docs /openapi.json
12
- //
13
- // Cache key includes CACHE_VERSION; bump to invalidate after UI changes
14
- // or fixture updates upstream.
15
-
16
- const HF = 'https://luuow-photon-route.hf.space';
17
- const PROXY = new Set(['/rank', '/version', '/docs', '/openapi.json']);
18
- // CACHE_VERSION bump: /rank responses now include `backend` and the
19
- // payload differs per backend; the cache key already includes the
20
- // query string so backends are cached separately, but old (v3) entries
21
- // don't have the backend field — invalidate.
22
- const CACHE_VERSION = 'v4';
23
-
24
- const BANNER = {
25
- service: 'photon-route',
26
- proxy: 'cloudflare-worker',
27
- backend: 'huggingface-space',
28
- upstream: HF,
29
- repo: 'https://github.com/LuuOW/photon-route',
30
- sister: 'https://qrouter.ask-meridian.uk (DV / qubit-gate sister project)',
31
- endpoints: {
32
- ui: '/ (interactive HTML)',
33
- api: '/api (this banner)',
34
- health: '/health (worker-local)',
35
- rank: '/rank?q=<query>&top_k=N&backend=v1|sha_init|trained (proxied, edge-cached 24 h)',
36
- version: '/version (proxied)',
37
- docs: '/docs (proxied — FastAPI swagger)',
38
- },
39
- backends: ['v1 (SF)', 'sha_init (numpy)', 'trained (numpy + learned)'],
40
- note: 'CV photonic retrieval. Strawberry Fields Gaussian programs, thewalrus closed-form fidelity.',
41
- };
42
-
43
- const CSP = [
44
- "default-src 'self'",
45
- "style-src 'self' 'unsafe-inline'",
46
- "script-src 'self' 'unsafe-inline'",
47
- "connect-src 'self'",
48
- "img-src 'self' data:",
49
- "base-uri 'none'",
50
- "form-action 'none'",
51
- "frame-ancestors 'none'",
52
- ].join('; ');
53
-
54
- const HTML = `<!doctype html>
55
  <html lang="en">
56
  <head>
57
  <meta charset="utf-8">
@@ -161,7 +108,6 @@ const HTML = `<!doctype html>
161
  border-bottom:1px dotted var(--line)}
162
  footer a:hover{color:var(--fg);border-bottom-color:var(--cyan)}
163
  .empty{color:var(--muted);text-align:center;padding:32px 12px;font-size:12px}
164
- /* phase-space visualization */
165
  .viz{margin:14px 0 18px;display:grid;grid-template-columns:1fr 1fr;gap:10px}
166
  .modepanel{position:relative;background:var(--panel);border:1px solid var(--line);
167
  border-radius:var(--radius);overflow:hidden;aspect-ratio:5/4;min-height:200px}
@@ -198,6 +144,193 @@ const HTML = `<!doctype html>
198
  </style>
199
  </head>
200
  <body>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  <main>
202
  <header>
203
  <div>
@@ -249,28 +382,29 @@ const HTML = `<!doctype html>
249
  <div class="body">
250
  <p><strong>photon-route</strong> is a research artifact exploring whether semantic retrieval can run in the continuous-variable (CV) photonic regime — the regime that real photonic hardware (Xanadu Borealis, fiber-loop reservoirs, coherent Ising machines) actually operates in.</p>
251
  <p>Each document is encoded as a <em>Gaussian state</em> over N bosonic modes via a <a href="https://strawberryfields.ai/" target="_blank" rel="noopener">Strawberry Fields</a> program: words contribute squeezing and displacement operations, then a beam-splitter network mixes the modes. Query and document fidelity is computed in closed form using the <a href="https://the-walrus.readthedocs.io/" target="_blank" rel="noopener">thewalrus</a> implementation of the Banchi-Braunstein-Pirandola formula.</p>
252
- <p>Three swappable encoders share the same fidelity scoring. <strong>v1</strong> uses Strawberry Fields to run an N-mode Gaussian program with SHA-256-derived parameters. <strong>sha_init</strong> is a pure-numpy port of the same gates (no SF at deploy time). <strong>trained</strong> swaps the SHA-256 lookup for a learned table fit by InfoNCE + Bhattacharyya-coefficient surrogate fidelity on a small arXiv quant-ph eval set. Toggle <em>compare</em> to see the three side by side; the DV-qubit sister project is <a href="https://qrouter.ask-meridian.uk" target="_blank" rel="noopener">qrouter</a>.</p>
253
  <p>Source · <a href="https://github.com/LuuOW/photon-route" target="_blank" rel="noopener">github.com/LuuOW/photon-route</a></p>
254
  </div>
255
  </details>
256
 
257
  <footer>
258
- <span>CV photonic · gaussian backend · edge-cached at the Cloudflare boundary</span>
259
- <span><a href="https://qrouter.ask-meridian.uk" rel="noopener">qrouter (DV)</a> · <a href="/docs">/docs</a> · <a href="/api">json</a></span>
260
  </footer>
261
  </main>
262
 
263
  <script>
264
  (function(){
 
 
 
 
 
265
  function $(id){return document.getElementById(id)}
266
  var q=$('q'), k=$('k'), results=$('results'), status=$('status');
267
  var healthPill=$('health'), healthText=$('health-text');
268
  var abort=null, debounceT=0;
269
 
270
- // ============================================================
271
- // Gaussian-state encoding (mirrors src/photon_route/encode.py)
272
- // xpxp ordering: mu = [q0, p0, q1, p1]; sigma is 4x4
273
- // ============================================================
274
  var N_MODES = 2, MAX_SQUEEZE = 0.5, MAX_DISPLACE = 1.0;
275
  var _wcache = new Map();
276
 
@@ -282,7 +416,6 @@ const HTML = `<!doctype html>
282
  var hash = await crypto.subtle.digest('SHA-256', buf);
283
  bytes = new Uint8Array(hash);
284
  } catch(e){
285
- // fallback: not cryptographic but stable enough for viz when subtle unavailable
286
  bytes = new Uint8Array(32);
287
  var h = 2166136261;
288
  for(var i=0;i<word.length;i++){ h ^= word.charCodeAt(i); h = (h*16777619)>>>0; }
@@ -290,7 +423,6 @@ const HTML = `<!doctype html>
290
  }
291
  var parts = [];
292
  for(var i2=0;i2<4;i2++){
293
- // 8 bytes -> BigInt -> mod 1e9 / 1e9
294
  var big = 0n;
295
  for(var j2=0;j2<8;j2++) big = (big << 8n) + BigInt(bytes[i2*8 + j2]);
296
  parts.push(Number(big % 1000000000n) / 1e9);
@@ -322,7 +454,6 @@ const HTML = `<!doctype html>
322
  function tr4(A){var T=mat4(); for(var i=0;i<4;i++) for(var j=0;j<4;j++) T[i][j]=A[j][i]; return T;}
323
 
324
  function sgateMat(k, r, phi){
325
- // S22 = R(phi/2) @ diag(e^{-r}, e^r) @ R(-phi/2)
326
  var c=Math.cos(phi/2), s=Math.sin(phi/2), em=Math.exp(-r), ep=Math.exp(r);
327
  var a = c*c*em + s*s*ep;
328
  var b = c*s*(em - ep);
@@ -377,11 +508,7 @@ const HTML = `<!doctype html>
377
  };
378
  }
379
 
380
- // ============================================================
381
- // 2.5D Wigner-function rendering on a 2D canvas (no WebGL).
382
- // For Gaussian states: W(x) = (1 / (2π√det Σ)) exp(-½ (x-μ)ᵀΣ⁻¹(x-μ))
383
- // ============================================================
384
- var GRID = 26; // quads per side; 26² × 2 modes ~= 1352 quads/frame
385
  var DPR = Math.min(window.devicePixelRatio || 1, 2);
386
 
387
  function WignerView(canvas, coordEl){
@@ -390,8 +517,8 @@ const HTML = `<!doctype html>
390
  this.ctx = canvas.getContext('2d');
391
  this.mu = [0,0];
392
  this.sigma = [[1,0],[0,1]];
393
- this.yaw = 0.55; // initial ~31°
394
- this.pitch = 0.85; // tilt down
395
  this.userYaw = false;
396
  this.dragging = false;
397
  this.lastX = 0; this.lastY = 0;
@@ -440,29 +567,24 @@ const HTML = `<!doctype html>
440
  WignerView.prototype.draw = function(){
441
  var ctx = this.ctx, W = this.canvas.width, H = this.canvas.height;
442
  ctx.clearRect(0,0,W,H);
443
-
444
  var mu = this.mu, sg = this.sigma;
445
  var det = sg[0][0]*sg[1][1] - sg[0][1]*sg[1][0];
446
  if(det < 1e-12){ return; }
447
  var iv00 = sg[1][1]/det, iv01 = -sg[0][1]/det, iv10 = -sg[1][0]/det, iv11 = sg[0][0]/det;
448
  var norm = 1/(2*Math.PI*Math.sqrt(det));
449
-
450
- // auto-scale to fit ±3σ ellipse around μ
451
  var sQ = Math.sqrt(Math.abs(sg[0][0])), sP = Math.sqrt(Math.abs(sg[1][1]));
452
  var extent = Math.max(2.6, Math.abs(mu[0]) + 3*sQ, Math.abs(mu[1]) + 3*sP);
453
  var scaleXY = (Math.min(W, H) * 0.32) / extent;
454
- var scaleZ = (Math.min(W, H) * 0.55) * Math.sqrt(det); // visual height ~ peak amplitude
455
  var ox = W*0.5, oy = H*0.62;
456
-
457
  var cy = Math.cos(this.yaw), sy = Math.sin(this.yaw);
458
  var cp = Math.cos(this.pitch), sp = Math.sin(this.pitch);
459
-
460
  function project(q, p, w){
461
  var xr = q*cy - p*sy;
462
  var yr = q*sy + p*cy;
463
  var sx = ox + xr * scaleXY;
464
  var sy_ = oy - yr * scaleXY * cp - w * scaleZ * sp;
465
- var depth = yr * sp - w * cp; // larger = farther back
466
  return [sx, sy_, depth];
467
  }
468
  function projectFlat(q, p){
@@ -470,8 +592,6 @@ const HTML = `<!doctype html>
470
  var yr = q*sy + p*cy;
471
  return [ox + xr*scaleXY, oy - yr*scaleXY*cp];
472
  }
473
-
474
- // floor grid on (q, p) plane
475
  ctx.strokeStyle = 'rgba(28,39,66,0.7)';
476
  ctx.lineWidth = 1;
477
  var gN = 6;
@@ -484,7 +604,6 @@ const HTML = `<!doctype html>
484
  var dFlat = projectFlat( extent, t);
485
  ctx.beginPath(); ctx.moveTo(cFlat[0], cFlat[1]); ctx.lineTo(dFlat[0], dFlat[1]); ctx.stroke();
486
  }
487
- // axes: q (cyan), p (indigo) at origin
488
  var oFlat = projectFlat(0,0);
489
  var qAxis = projectFlat(extent, 0);
490
  var pAxis = projectFlat(0, extent);
@@ -492,8 +611,6 @@ const HTML = `<!doctype html>
492
  ctx.beginPath(); ctx.moveTo(oFlat[0], oFlat[1]); ctx.lineTo(qAxis[0], qAxis[1]); ctx.stroke();
493
  ctx.strokeStyle = 'rgba(129,140,248,0.45)';
494
  ctx.beginPath(); ctx.moveTo(oFlat[0], oFlat[1]); ctx.lineTo(pAxis[0], pAxis[1]); ctx.stroke();
495
-
496
- // sample Wigner on grid + project
497
  var N = GRID;
498
  var step = (2*extent)/N;
499
  var pts = new Array(N+1);
@@ -512,8 +629,6 @@ const HTML = `<!doctype html>
512
  pts[i][j] = {w:w, sx:pr[0], sy:pr[1], depth:pr[2]};
513
  }
514
  }
515
-
516
- // build quads + sort back-to-front
517
  var quads = [];
518
  for(var i2=0;i2<N;i2++){
519
  for(var j2=0;j2<N;j2++){
@@ -523,13 +638,10 @@ const HTML = `<!doctype html>
523
  quads.push({a:a, b:b, c:c, d:d, depth:depth, w:wAvg});
524
  }
525
  }
526
- quads.sort(function(x, y){ return y.depth - x.depth; }); // larger depth first (farther)
527
-
528
- // draw
529
  for(var qi2=0; qi2<quads.length; qi2++){
530
  var qd = quads[qi2];
531
  var t2 = wmax > 1e-12 ? Math.max(0, Math.min(1, qd.w / wmax)) : 0;
532
- // indigo (low) -> cyan (high)
533
  var rC = Math.round(0x81*(1-t2) + 0x22*t2);
534
  var gC = Math.round(0x8c*(1-t2) + 0xd3*t2);
535
  var bC = Math.round(0xf8*(1-t2) + 0xee*t2);
@@ -556,7 +668,6 @@ const HTML = `<!doctype html>
556
  v1 = new WignerView(c1, document.getElementById('c1'));
557
  var resize = function(){ v0.resize(); v1.resize(); };
558
  window.addEventListener('resize', resize);
559
- // vacuum initial
560
  v0.setState([0,0],[[1,0],[0,1]]);
561
  v1.setState([0,0],[[1,0],[0,1]]);
562
  var lastT = performance.now();
@@ -577,23 +688,21 @@ const HTML = `<!doctype html>
577
  if(text === lastQuery) return;
578
  lastQuery = text;
579
  var st = await encodeState(text);
580
- if(text !== lastQuery) return; // newer query came in
581
  currentState = st;
582
  var m0 = modeMarginal(st, 0), m1 = modeMarginal(st, 1);
583
  v0.setState(m0.mu, m0.sigma);
584
  v1.setState(m1.mu, m1.sigma);
585
  }
586
- // init after DOM ready (we're at end of body so DOM exists)
587
  initViz();
588
 
589
  var backendSel = document.getElementById('backend');
590
  var compareBox = document.getElementById('compare');
591
- fetch('/health',{cache:'no-store'}).then(function(r){return r.json()}).then(function(j){
592
- var ok = j && j.ok && j.upstream_ok;
593
  healthPill.classList.add(ok?'ok':'err');
594
  var backends = (j && j.backends_available) || [];
595
  healthText.textContent = ok ? (j.default_backend || 'ok') : 'offline';
596
- // hide options for backends that aren't actually live upstream
597
  Array.prototype.forEach.call(backendSel.options, function(opt){
598
  opt.disabled = backends.length>0 && backends.indexOf(opt.value) < 0;
599
  if (opt.disabled) opt.text = opt.value + ' (n/a)';
@@ -674,12 +783,11 @@ const HTML = `<!doctype html>
674
  }
675
 
676
  async function fetchRank(text, topk, backend, sig){
677
- var url='/rank?q='+encodeURIComponent(text)+'&top_k='+topk+'&backend='+encodeURIComponent(backend);
678
  var r=await fetch(url,{signal:sig});
679
  if(!r.ok) throw new Error('http '+r.status);
680
  var j=await r.json();
681
- return {backend:j.backend||backend, items:j.results||[],
682
- cache:r.headers.get('x-photon-route-cache')||''};
683
  }
684
 
685
  async function run(){
@@ -704,7 +812,7 @@ const HTML = `<!doctype html>
704
  var j = await fetchRank(text, topk, backendSel.value, abort.signal);
705
  var ms=(performance.now()-t0).toFixed(0);
706
  status.textContent = j.items.length+' result'+(j.items.length===1?'':'s')+
707
- ' · '+ms+' ms'+(j.cache?' · cache '+j.cache:'')+' · backend '+j.backend;
708
  render(j.items);
709
  return;
710
  }
@@ -745,114 +853,4 @@ const HTML = `<!doctype html>
745
  })();
746
  </script>
747
  </body>
748
- </html>`;
749
-
750
- addEventListener('fetch', (e) => e.respondWith(handle(e.request)));
751
-
752
- async function handle(req) {
753
- const url = new URL(req.url);
754
- const path = url.pathname;
755
-
756
- if (req.method === 'OPTIONS') {
757
- return cors(new Response(null, { status: 204 }));
758
- }
759
-
760
- if (path === '/' && (req.method === 'GET' || req.method === 'HEAD')) {
761
- return new Response(req.method === 'HEAD' ? null : HTML, {
762
- headers: {
763
- 'content-type': 'text/html; charset=utf-8',
764
- 'content-security-policy': CSP,
765
- 'referrer-policy': 'strict-origin-when-cross-origin',
766
- 'x-content-type-options': 'nosniff',
767
- 'cache-control': 'public, max-age=300',
768
- },
769
- });
770
- }
771
-
772
- if (path === '/api' || path === '/info') {
773
- return jsonResp(BANNER);
774
- }
775
-
776
- if (path === '/health') {
777
- // Merge upstream /health so the UI knows which backends are live.
778
- let upstream = null;
779
- try {
780
- const r = await fetch(HF + '/health', { cf: { cacheTtl: 30 } });
781
- if (r.ok) upstream = await r.json();
782
- } catch (_) {}
783
- return jsonResp({
784
- ok: true,
785
- proxy: 'cloudflare-worker',
786
- upstream_ok: upstream ? !!upstream.ok : false,
787
- backends_available: upstream && upstream.backends_available
788
- ? upstream.backends_available : ['stub'],
789
- default_backend: upstream && upstream.default_backend
790
- ? upstream.default_backend : 'stub',
791
- weights_loaded: upstream ? !!upstream.weights_loaded : false,
792
- });
793
- }
794
-
795
- if (PROXY.has(path) && (req.method === 'GET' || req.method === 'HEAD')) {
796
- return await proxied(req, url, path);
797
- }
798
-
799
- return jsonResp({ error: 'not found' }, 404);
800
- }
801
-
802
- function jsonResp(obj, status = 200) {
803
- return new Response(JSON.stringify(obj, null, 2), {
804
- status,
805
- headers: {
806
- 'content-type': 'application/json; charset=utf-8',
807
- 'access-control-allow-origin': '*',
808
- 'cache-control': 'no-store',
809
- },
810
- });
811
- }
812
-
813
- function cors(resp) {
814
- resp.headers.set('access-control-allow-origin', '*');
815
- resp.headers.set('access-control-allow-methods', 'GET, HEAD, OPTIONS');
816
- resp.headers.set('access-control-allow-headers', 'content-type');
817
- return resp;
818
- }
819
-
820
- async function proxied(req, url, path) {
821
- const upstream = new URL(HF);
822
- upstream.pathname = path;
823
- upstream.search = url.search;
824
-
825
- const isRank = path === '/rank';
826
- const cacheKey = new Request(upstream.toString() + '#' + CACHE_VERSION, req);
827
- const cache = caches.default;
828
-
829
- if (isRank) {
830
- const hit = await cache.match(cacheKey);
831
- if (hit) {
832
- const r = new Response(hit.body, hit);
833
- r.headers.set('x-photon-route-cache', 'hit');
834
- return cors(r);
835
- }
836
- }
837
-
838
- const fetched = await fetch(upstream.toString(), {
839
- method: req.method,
840
- headers: { accept: req.headers.get('accept') || '*/*' },
841
- cf: { cacheTtl: 0 },
842
- });
843
-
844
- if (!fetched.ok || fetched.status >= 500) {
845
- const r = new Response(fetched.body, fetched);
846
- r.headers.set('x-photon-route-cache', 'bypass');
847
- return cors(r);
848
- }
849
-
850
- const body = await fetched.arrayBuffer();
851
- const headers = new Headers(fetched.headers);
852
- headers.delete('set-cookie');
853
- if (isRank) headers.set('cache-control', 'public, s-maxage=86400, max-age=300');
854
- const ok = new Response(body, { status: fetched.status, headers });
855
- ok.headers.set('x-photon-route-cache', 'miss');
856
- if (isRank) await cache.put(cacheKey, ok.clone());
857
- return cors(ok);
858
- }
 
1
+ <!doctype html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  <html lang="en">
3
  <head>
4
  <meta charset="utf-8">
 
108
  border-bottom:1px dotted var(--line)}
109
  footer a:hover{color:var(--fg);border-bottom-color:var(--cyan)}
110
  .empty{color:var(--muted);text-align:center;padding:32px 12px;font-size:12px}
 
111
  .viz{margin:14px 0 18px;display:grid;grid-template-columns:1fr 1fr;gap:10px}
112
  .modepanel{position:relative;background:var(--panel);border:1px solid var(--line);
113
  border-radius:var(--radius);overflow:hidden;aspect-ratio:5/4;min-height:200px}
 
144
  </style>
145
  </head>
146
  <body>
147
+ <style>
148
+ .nav { position: fixed; inset: 0 0 auto 0;
149
+ display: flex; align-items: center; justify-content: space-between;
150
+ padding: 14px clamp(12px, 3vw, 24px);
151
+ z-index: 100; pointer-events: none;
152
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", system-ui, sans-serif; }
153
+ .nav > * { pointer-events: auto; }
154
+ /* Nav-scoped: the page body has its own <div class="brand"> gradient
155
+ for the page title — don't let nav .brand rules override it. */
156
+ .nav .brand { color: #e9eef7; font-weight: 600; letter-spacing: 0.04em;
157
+ text-decoration: none; font-size: 15px;
158
+ background: none; -webkit-text-fill-color: currentColor; }
159
+ .nav .brand .brand-glyph { color: #ffa276; margin-right: 6px; }
160
+ .burger { width: 40px; height: 40px; padding: 0;
161
+ background: rgba(167,139,250,0.05);
162
+ border: 1px solid rgba(167,139,250,0.18);
163
+ border-radius: 10px; cursor: pointer; position: relative;
164
+ transition: background 0.15s, border-color 0.15s, box-shadow 0.15s; }
165
+ .burger:hover { background: rgba(167,139,250,0.12); border-color: #a78bfa; box-shadow: 0 0 14px rgba(167,139,250,0.25); }
166
+ .burger span { position: absolute; left: 11px; right: 11px; height: 2px;
167
+ background: #c9d4ec; border-radius: 1px;
168
+ transition: top 0.18s, transform 0.18s, opacity 0.18s; }
169
+ .burger span:nth-child(1) { top: 13px; }
170
+ .burger span:nth-child(2) { top: 19px; }
171
+ .burger span:nth-child(3) { top: 25px; }
172
+ .burger.open { background: rgba(167,139,250,0.18); border-color: #a78bfa; }
173
+ .burger.open span:nth-child(1) { top: 19px; transform: rotate(45deg); }
174
+ .burger.open span:nth-child(2) { opacity: 0; }
175
+ .burger.open span:nth-child(3) { top: 19px; transform: rotate(-45deg); }
176
+ .nav-menu { position: fixed; top: 64px; right: clamp(12px, 3vw, 24px);
177
+ width: min(330px, calc(100vw - 24px));
178
+ display: flex; flex-direction: column; gap: 2px; padding: 14px;
179
+ background: rgba(10, 13, 20, 0.96); backdrop-filter: blur(16px);
180
+ border: 1px solid rgba(167, 139, 250, 0.3); border-radius: 14px;
181
+ box-shadow: 0 20px 50px rgba(0,0,0,0.5); transform-origin: top right;
182
+ transform: translateY(-12px) scale(0.97); opacity: 0; pointer-events: none;
183
+ transition: transform 0.22s cubic-bezier(0.16,1,0.3,1), opacity 0.2s ease;
184
+ z-index: 110; max-height: calc(100vh - 80px); overflow-y: auto; }
185
+ .nav-menu.open { transform: translateY(0) scale(1); opacity: 1; pointer-events: auto; }
186
+ .nav-menu a { display: block; padding: 10px 14px; border-radius: 8px;
187
+ font-size: 14px; color: #9bb6ea; text-decoration: none;
188
+ transition: background 0.15s, color 0.15s; }
189
+ .nav-menu a:hover, .nav-menu .current { background: rgba(167,139,250,0.10); color: #fff; }
190
+ .nav-group { border-top: 1px solid rgba(167, 139, 250, 0.10); margin-top: 6px; padding-top: 4px; }
191
+ .nav-group:first-of-type { border-top: 0; margin-top: 4px; padding-top: 0; }
192
+ .nav-group > summary { cursor: pointer; list-style: none; user-select: none; }
193
+ .nav-group > summary::-webkit-details-marker { display: none; }
194
+ .nav-section { font-size: 11px; letter-spacing: 0.12em; text-transform: uppercase;
195
+ color: rgba(167, 139, 250, 0.75); padding: 10px 14px 6px; font-weight: 600;
196
+ display: flex; align-items: center; justify-content: space-between; }
197
+ .nav-section::after { content: '+'; font-size: 14px;
198
+ color: rgba(167, 139, 250, 0.55);
199
+ transition: transform 0.18s, color 0.18s; }
200
+ .nav-group[open] > summary .nav-section::after { transform: rotate(45deg); color: #a78bfa; }
201
+ .nav-group > summary:hover .nav-section { color: #fff; }
202
+ .nav-app { background: rgba(167,139,250,0.04);
203
+ border: 1px solid rgba(167,139,250,0.10);
204
+ margin: 4px 0; line-height: 1.3;
205
+ padding: 10px 14px !important; border-radius: 10px !important; }
206
+ .nav-app:hover { background: rgba(167,139,250,0.12); border-color: rgba(167,139,250,0.32); }
207
+ .nav-app .nav-app-name { display: block; font-size: 14px; color: #e9eef7; font-weight: 500; }
208
+ .nav-app .nav-app-tag { display: block; font-size: 12px; color: #9bb6ea; margin-top: 2px; }
209
+ .nav-app .nav-app-emoji { margin-right: 6px; }
210
+ </style>
211
+
212
+ <nav class="nav" aria-label="Primary">
213
+ <a href="https://ask-meridian.uk/" class="brand" style="view-transition-name: brand">◎ Meridian</a>
214
+
215
+ <button id="burgerBtn" class="burger" type="button" aria-label="Toggle navigation menu" aria-expanded="false" aria-controls="navMenu">
216
+ <span aria-hidden="true"></span><span aria-hidden="true"></span><span aria-hidden="true"></span>
217
+ </button>
218
+
219
+ <div id="navMenu" class="nav-menu" role="menu">
220
+ <a href="https://ask-meridian.uk/">Home</a>
221
+
222
+ <details class="nav-group" open>
223
+ <summary class="nav-section">Showcase</summary>
224
+ <a href="https://meridian.ask-meridian.uk/helix/" class="nav-app" data-status="live">
225
+ <span class="nav-app-name"><span class="nav-app-emoji">🧬</span>helix · proteins</span>
226
+ <span class="nav-app-tag">Injury → top therapeutic protein candidates, each rendered as its own star system</span>
227
+ </a>
228
+ <a href="https://ask-meridian.uk/miniapp/" class="nav-app" data-status="live">
229
+ <span class="nav-app-name"><span class="nav-app-emoji">🛰️</span>Try it · Task orbit</span>
230
+ <span class="nav-app-tag">Browser miniapp · routes any task to candidates</span>
231
+ </a>
232
+ <a href="https://ask-meridian.uk/miniapp/vision-lab/" class="nav-app" data-status="live">
233
+ <span class="nav-app-name"><span class="nav-app-emoji">🔭</span>Vision Lab</span>
234
+ <span class="nav-app-tag">SmolVLM / Moondream in browser via WebGPU</span>
235
+ </a>
236
+ <a href="https://photon.ask-meridian.uk" class="nav-app" data-status="live">
237
+ <span class="nav-app-name"><span class="nav-app-emoji">⚛︎</span>Photon Router</span>
238
+ <span class="nav-app-tag">CV photonic retrieval · trained on HF Space</span>
239
+ </a>
240
+ <a href="https://lens.ask-meridian.uk" class="nav-app" data-status="live">
241
+ <span class="nav-app-name"><span class="nav-app-emoji">◎</span>Lens · WebXR</span>
242
+ <span class="nav-app-tag">Vision Lab in VR · controllers, raycaster, orbit</span>
243
+ </a>
244
+ </details>
245
+
246
+ <details class="nav-group">
247
+ <summary class="nav-section">Resources</summary>
248
+ <a href="https://ask-meridian.uk/blog/">Blog</a>
249
+ <a href="https://ask-meridian.uk/docs/">Docs</a>
250
+ <a href="https://ask-meridian.uk/#pricing">Pricing</a>
251
+ </details>
252
+
253
+ <details class="nav-group">
254
+ <summary class="nav-section">Source</summary>
255
+ <a href="https://github.com/LuuOW/meridian-mcp">GitHub · meridian-mcp</a>
256
+ <a href="https://github.com/LuuOW/photon-route">GitHub · photon-route</a>
257
+ </details>
258
+ </div>
259
+
260
+ <!-- Self-contained burger + current-link behaviour. Lives inside <nav>
261
+ so sync-nav.py treats it as part of the synced block — every page
262
+ gets the same wiring without needing a separate <script> import.
263
+ Re-runs on DOMContentLoaded too so it survives module-script
264
+ races and weird parse orders. -->
265
+ <script>
266
+ (function () {
267
+ function setup() {
268
+ var btn = document.getElementById('burgerBtn')
269
+ var menu = document.getElementById('navMenu')
270
+ if (!btn || !menu) return
271
+ if (btn.dataset.wired) return
272
+ btn.dataset.wired = '1'
273
+
274
+ function toggle(open) {
275
+ var isOpen = open === undefined ? !menu.classList.contains('open') : open
276
+ menu.classList.toggle('open', isOpen)
277
+ btn.classList.toggle('open', isOpen)
278
+ btn.setAttribute('aria-expanded', String(isOpen))
279
+ }
280
+ btn.addEventListener('click', function () { toggle() })
281
+ menu.querySelectorAll('a').forEach(function (a) {
282
+ a.addEventListener('click', function () { toggle(false) })
283
+ })
284
+ document.addEventListener('click', function (e) {
285
+ if (!menu.classList.contains('open')) return
286
+ if (!menu.contains(e.target) && !btn.contains(e.target)) toggle(false)
287
+ })
288
+ document.addEventListener('keydown', function (e) {
289
+ if (e.key === 'Escape') toggle(false)
290
+ })
291
+
292
+ // Cross-host aware: same host+path → current (protocol is ignored). Lets the
293
+ // highlight follow you whether you're on ask-meridian.uk,
294
+ // meridian.ask-meridian.uk, or photon.ask-meridian.uk.
295
+ var here = location.host + location.pathname.replace(/\/index\.html$/, '/')
296
+ menu.querySelectorAll('a').forEach(function (a) {
297
+ var href = a.getAttribute('href')
298
+ if (!href) return
299
+ try {
300
+ var u = new URL(href, location.href)
301
+ var target = u.host + u.pathname.replace(/\/index\.html$/, '/')
302
+ if (target === here) a.classList.add('current')
303
+ } catch (_) {}
304
+ })
305
+ }
306
+ if (document.readyState === 'loading') {
307
+ document.addEventListener('DOMContentLoaded', setup, { once: true })
308
+ } else {
309
+ setup()
310
+ }
311
+ })()
312
+ </script>
313
+ </nav>
314
+
315
+ <script>
316
(function () {
  // Fallback burger-menu wiring for pages whose <nav> block does not ship its
  // own script. The synced <nav> script marks the button with data-wired="1"
  // once it has attached its handlers; attaching a second click handler on
  // top of that toggles the menu twice per click, so it would never open.
  function wire() {
    const btn = document.querySelector('.burger');
    const menu = document.getElementById('navMenu');
    if (!btn || !menu) return;
    // Someone else already owns the button: do nothing.
    if (btn.dataset.wired) return;
    btn.dataset.wired = '1';

    // Keep the menu, the button state and aria-expanded in lock-step.
    const set = (open) => {
      menu.classList.toggle('open', open);
      btn.classList.toggle('open', open);
      btn.setAttribute('aria-expanded', String(open));
    };

    btn.addEventListener('click', (e) => {
      e.stopPropagation();
      set(!menu.classList.contains('open'));
    });
    // Click outside the menu/button closes it.
    document.addEventListener('click', (e) => {
      if (!menu.classList.contains('open')) return;
      if (!menu.contains(e.target) && !btn.contains(e.target)) set(false);
    });
    document.addEventListener('keydown', (e) => {
      if (e.key === 'Escape') set(false);
    });
    // Following a link inside the menu closes it.
    menu.querySelectorAll('a').forEach((a) => {
      a.addEventListener('click', () => set(false));
    });
  }

  // Run after the document is parsed so that a nav script which defers its
  // own setup to DOMContentLoaded (registered earlier in the document, hence
  // fired earlier) gets to claim the button first.
  if (document.readyState === 'loading') {
    document.addEventListener('DOMContentLoaded', wire, { once: true });
  } else {
    wire();
  }
})();
333
+ </script>
334
  <main>
335
  <header>
336
  <div>
 
382
  <div class="body">
383
  <p><strong>photon-route</strong> is a research artifact exploring whether semantic retrieval can run in the continuous-variable (CV) photonic regime — the regime that real photonic hardware (Xanadu Borealis, fiber-loop reservoirs, coherent Ising machines) actually operates in.</p>
384
  <p>Each document is encoded as a <em>Gaussian state</em> over N bosonic modes via a <a href="https://strawberryfields.ai/" target="_blank" rel="noopener">Strawberry Fields</a> program: words contribute squeezing and displacement operations, then a beam-splitter network mixes the modes. Query and document fidelity is computed in closed form using the <a href="https://the-walrus.readthedocs.io/" target="_blank" rel="noopener">thewalrus</a> implementation of the Banchi-Braunstein-Pirandola formula.</p>
385
+ <p>Three swappable encoders share the same fidelity scoring. <strong>v1</strong> uses Strawberry Fields to run an N-mode Gaussian program with SHA-256-derived parameters. <strong>sha_init</strong> is a pure-numpy port of the same gates (no SF at deploy time). <strong>trained</strong> swaps the SHA-256 lookup for a learned table fit by InfoNCE + Bhattacharyya-coefficient surrogate fidelity on a small arXiv quant-ph eval set. Toggle <em>compare</em> to see the three side by side.</p>
386
  <p>Source · <a href="https://github.com/LuuOW/photon-route" target="_blank" rel="noopener">github.com/LuuOW/photon-route</a></p>
387
  </div>
388
  </details>
389
 
390
  <footer>
391
+ <span>CV photonic · gaussian backend · UI on GitHub Pages, retrieval on HF Space</span>
392
+ <span><a href="https://luuow-photon-route.hf.space/docs" rel="noopener">/docs</a> · <a href="https://luuow-photon-route.hf.space/health" rel="noopener">/health</a></span>
393
  </footer>
394
  </main>
395
 
396
  <script>
397
  (function(){
398
+ // Backend lives on Hugging Face Spaces. CORS is wide-open on the FastAPI app
399
+ // so the browser hits this directly. Override with a ?api=https://… query
400
+ // string for local Space testing.
401
+ var SPACE = (new URL(location.href)).searchParams.get('api') || 'https://luuow-photon-route.hf.space';
402
+
403
  function $(id){return document.getElementById(id)}
404
  var q=$('q'), k=$('k'), results=$('results'), status=$('status');
405
  var healthPill=$('health'), healthText=$('health-text');
406
  var abort=null, debounceT=0;
407
 
 
 
 
 
408
  var N_MODES = 2, MAX_SQUEEZE = 0.5, MAX_DISPLACE = 1.0;
409
  var _wcache = new Map();
410
 
 
416
  var hash = await crypto.subtle.digest('SHA-256', buf);
417
  bytes = new Uint8Array(hash);
418
  } catch(e){
 
419
  bytes = new Uint8Array(32);
420
  var h = 2166136261;
421
  for(var i=0;i<word.length;i++){ h ^= word.charCodeAt(i); h = (h*16777619)>>>0; }
 
423
  }
424
  var parts = [];
425
  for(var i2=0;i2<4;i2++){
 
426
  var big = 0n;
427
  for(var j2=0;j2<8;j2++) big = (big << 8n) + BigInt(bytes[i2*8 + j2]);
428
  parts.push(Number(big % 1000000000n) / 1e9);
 
454
  function tr4(A){var T=mat4(); for(var i=0;i<4;i++) for(var j=0;j<4;j++) T[i][j]=A[j][i]; return T;}
455
 
456
  function sgateMat(k, r, phi){
 
457
  var c=Math.cos(phi/2), s=Math.sin(phi/2), em=Math.exp(-r), ep=Math.exp(r);
458
  var a = c*c*em + s*s*ep;
459
  var b = c*s*(em - ep);
 
508
  };
509
  }
510
 
511
+ var GRID = 26;
 
 
 
 
512
  var DPR = Math.min(window.devicePixelRatio || 1, 2);
513
 
514
  function WignerView(canvas, coordEl){
 
517
  this.ctx = canvas.getContext('2d');
518
  this.mu = [0,0];
519
  this.sigma = [[1,0],[0,1]];
520
+ this.yaw = 0.55;
521
+ this.pitch = 0.85;
522
  this.userYaw = false;
523
  this.dragging = false;
524
  this.lastX = 0; this.lastY = 0;
 
567
  WignerView.prototype.draw = function(){
568
  var ctx = this.ctx, W = this.canvas.width, H = this.canvas.height;
569
  ctx.clearRect(0,0,W,H);
 
570
  var mu = this.mu, sg = this.sigma;
571
  var det = sg[0][0]*sg[1][1] - sg[0][1]*sg[1][0];
572
  if(det < 1e-12){ return; }
573
  var iv00 = sg[1][1]/det, iv01 = -sg[0][1]/det, iv10 = -sg[1][0]/det, iv11 = sg[0][0]/det;
574
  var norm = 1/(2*Math.PI*Math.sqrt(det));
 
 
575
  var sQ = Math.sqrt(Math.abs(sg[0][0])), sP = Math.sqrt(Math.abs(sg[1][1]));
576
  var extent = Math.max(2.6, Math.abs(mu[0]) + 3*sQ, Math.abs(mu[1]) + 3*sP);
577
  var scaleXY = (Math.min(W, H) * 0.32) / extent;
578
+ var scaleZ = (Math.min(W, H) * 0.55) * Math.sqrt(det);
579
  var ox = W*0.5, oy = H*0.62;
 
580
  var cy = Math.cos(this.yaw), sy = Math.sin(this.yaw);
581
  var cp = Math.cos(this.pitch), sp = Math.sin(this.pitch);
 
582
  function project(q, p, w){
583
  var xr = q*cy - p*sy;
584
  var yr = q*sy + p*cy;
585
  var sx = ox + xr * scaleXY;
586
  var sy_ = oy - yr * scaleXY * cp - w * scaleZ * sp;
587
+ var depth = yr * sp - w * cp;
588
  return [sx, sy_, depth];
589
  }
590
  function projectFlat(q, p){
 
592
  var yr = q*sy + p*cy;
593
  return [ox + xr*scaleXY, oy - yr*scaleXY*cp];
594
  }
 
 
595
  ctx.strokeStyle = 'rgba(28,39,66,0.7)';
596
  ctx.lineWidth = 1;
597
  var gN = 6;
 
604
  var dFlat = projectFlat( extent, t);
605
  ctx.beginPath(); ctx.moveTo(cFlat[0], cFlat[1]); ctx.lineTo(dFlat[0], dFlat[1]); ctx.stroke();
606
  }
 
607
  var oFlat = projectFlat(0,0);
608
  var qAxis = projectFlat(extent, 0);
609
  var pAxis = projectFlat(0, extent);
 
611
  ctx.beginPath(); ctx.moveTo(oFlat[0], oFlat[1]); ctx.lineTo(qAxis[0], qAxis[1]); ctx.stroke();
612
  ctx.strokeStyle = 'rgba(129,140,248,0.45)';
613
  ctx.beginPath(); ctx.moveTo(oFlat[0], oFlat[1]); ctx.lineTo(pAxis[0], pAxis[1]); ctx.stroke();
 
 
614
  var N = GRID;
615
  var step = (2*extent)/N;
616
  var pts = new Array(N+1);
 
629
  pts[i][j] = {w:w, sx:pr[0], sy:pr[1], depth:pr[2]};
630
  }
631
  }
 
 
632
  var quads = [];
633
  for(var i2=0;i2<N;i2++){
634
  for(var j2=0;j2<N;j2++){
 
638
  quads.push({a:a, b:b, c:c, d:d, depth:depth, w:wAvg});
639
  }
640
  }
641
+ quads.sort(function(x, y){ return y.depth - x.depth; });
 
 
642
  for(var qi2=0; qi2<quads.length; qi2++){
643
  var qd = quads[qi2];
644
  var t2 = wmax > 1e-12 ? Math.max(0, Math.min(1, qd.w / wmax)) : 0;
 
645
  var rC = Math.round(0x81*(1-t2) + 0x22*t2);
646
  var gC = Math.round(0x8c*(1-t2) + 0xd3*t2);
647
  var bC = Math.round(0xf8*(1-t2) + 0xee*t2);
 
668
  v1 = new WignerView(c1, document.getElementById('c1'));
669
  var resize = function(){ v0.resize(); v1.resize(); };
670
  window.addEventListener('resize', resize);
 
671
  v0.setState([0,0],[[1,0],[0,1]]);
672
  v1.setState([0,0],[[1,0],[0,1]]);
673
  var lastT = performance.now();
 
688
  if(text === lastQuery) return;
689
  lastQuery = text;
690
  var st = await encodeState(text);
691
+ if(text !== lastQuery) return;
692
  currentState = st;
693
  var m0 = modeMarginal(st, 0), m1 = modeMarginal(st, 1);
694
  v0.setState(m0.mu, m0.sigma);
695
  v1.setState(m1.mu, m1.sigma);
696
  }
 
697
  initViz();
698
 
699
  var backendSel = document.getElementById('backend');
700
  var compareBox = document.getElementById('compare');
701
+ fetch(SPACE+'/health',{cache:'no-store'}).then(function(r){return r.json()}).then(function(j){
702
+ var ok = j && j.ok;
703
  healthPill.classList.add(ok?'ok':'err');
704
  var backends = (j && j.backends_available) || [];
705
  healthText.textContent = ok ? (j.default_backend || 'ok') : 'offline';
 
706
  Array.prototype.forEach.call(backendSel.options, function(opt){
707
  opt.disabled = backends.length>0 && backends.indexOf(opt.value) < 0;
708
  if (opt.disabled) opt.text = opt.value + ' (n/a)';
 
783
  }
784
 
785
  async function fetchRank(text, topk, backend, sig){
786
+ var url=SPACE+'/rank?q='+encodeURIComponent(text)+'&top_k='+topk+'&backend='+encodeURIComponent(backend);
787
  var r=await fetch(url,{signal:sig});
788
  if(!r.ok) throw new Error('http '+r.status);
789
  var j=await r.json();
790
+ return {backend:j.backend||backend, items:j.results||[], cache:''};
 
791
  }
792
 
793
  async function run(){
 
812
  var j = await fetchRank(text, topk, backendSel.value, abort.signal);
813
  var ms=(performance.now()-t0).toFixed(0);
814
  status.textContent = j.items.length+' result'+(j.items.length===1?'':'s')+
815
+ ' · '+ms+' ms · backend '+j.backend;
816
  render(j.items);
817
  return;
818
  }
 
853
  })();
854
  </script>
855
  </body>
856
+ </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
space/analyze_sweep.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Read sweep_results.csv, decide whether SBERT-photon's gain is real
2
+ and whether photon-number-distribution metric (A3-Simple) outperforms
3
+ Gaussian-state-overlap (BBP fidelity) on the same trained encoder.
4
+
5
+ Q1. Does SBERT-photon (full, gaussian metric) robustly beat raw SBERT?
6
+ Q2. Does the squeezing layer specifically pay (full vs no-squeeze)?
7
+ Q3. A3-Simple: does photon-prob metric > gaussian metric on same encoder?
8
+ Q4. Generalization tax (train − test nDCG@10).
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import argparse
13
+ import csv
14
+ import statistics
15
+ from pathlib import Path
16
+
17
+
18
+ SBERT_ALONE_NDCG10 = 0.385
19
+
20
+
21
def stat(rs, key):
    """Return ``(mean, stdev, count)`` of ``key`` over ``rs``, ignoring NaN cells.

    A value is treated as NaN when it does not compare equal to itself.
    With fewer than two usable values the spread is reported as ``0.0`` and
    the mean is the single value, or NaN when nothing is left.
    """
    usable = []
    for row in rs:
        # NaN is the only float that is not equal to itself.
        if row[key] == row[key]:
            usable.append(row[key])

    count = len(usable)
    if count >= 2:
        return statistics.mean(usable), statistics.stdev(usable), count
    if count == 1:
        return (usable[0], 0.0, count)
    return (float("nan"), 0.0, count)
26
+
27
+
28
def main():
    """Print the Q1–Q4 verdicts and a per-run table for a sweep CSV.

    Reads the CSV named by ``--csv`` (default: ``sweep_results.csv`` two
    directories above this file), converts every cell that parses as a float,
    splits the rows on the ``no_squeeze`` column, and writes all comparisons
    to stdout. Nothing is returned.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--csv", type=Path, default=Path(__file__).resolve().parent.parent / "sweep_results.csv")
    args = ap.parse_args()

    # The with-block closes the handle; newline="" is what the csv module
    # expects so that it can handle embedded newlines itself.
    with args.csv.open(newline="") as fh:
        rows = list(csv.DictReader(fh))
    for r in rows:
        for k, v in list(r.items()):
            try:
                r[k] = float(v)
            except (ValueError, TypeError):
                # Non-numeric cells stay as they were read.
                pass

    full = sorted([r for r in rows if int(r["no_squeeze"]) == 0], key=lambda r: r["seed"])
    nosq = sorted([r for r in rows if int(r["no_squeeze"]) == 1], key=lambda r: r["seed"])

    print(f"loaded {len(rows)} runs from {args.csv}")
    print(f" full (squeezing on): n={len(full)}")
    print(f" no-squeeze: n={len(nosq)}")

    print("\n=== Q1. SBERT-photon (full, gaussian metric) vs raw SBERT (0.385) ===")
    m, s, n = stat(full, "test_gaussian_n10")
    delta = m - SBERT_ALONE_NDCG10
    z = delta / s if s > 0 else float("inf")
    verdict = "✓ YES" if delta > s and m > SBERT_ALONE_NDCG10 else "✗ noisy or no gain"
    print(f" full mean test nDCG@10 (gaussian) = {m:.3f} ± {s:.3f} (n={n})")
    print(f" Δ vs raw SBERT = {delta:+.3f} (Δ/σ ≈ {z:+.2f}) → {verdict}")

    print("\n=== Q2. Squeezing pays? (paired full − no_squeeze, gaussian metric) ===")
    # Index the no-squeeze runs by seed once; setdefault keeps the first row
    # per seed, matching a first-match linear scan.
    nosq_by_seed = {}
    for x in nosq:
        nosq_by_seed.setdefault(x["seed"], x)
    paired = []
    for r in full:
        n_row = nosq_by_seed.get(r["seed"])
        if n_row:
            paired.append((int(r["seed"]), r["test_gaussian_n10"], n_row["test_gaussian_n10"]))
    diffs = [a - b for _, a, b in paired]
    m_d = statistics.mean(diffs) if diffs else float("nan")
    s_d = statistics.stdev(diffs) if len(diffs) > 1 else 0.0
    for sid, a, b in paired:
        print(f" seed {sid}: full={a:.3f} no_sq={b:.3f} Δ={a-b:+.3f}")
    verdict = ("✓ YES" if m_d > s_d and m_d > 0.01 else
               "✗ NO" if m_d <= 0 else "≈ within noise")
    print(f" mean Δ = {m_d:+.3f} ± {s_d:.3f} → {verdict}")

    print("\n=== Q3. A3-Simple: photon-prob > gaussian metric on same encoder? ===")
    for label, rs in [("full", full), ("no-squeeze", nosq)]:
        if not rs:
            continue
        # Paired per-seed: same encoder, two metrics on test set
        diffs = [r["test_photon_prob_n10"] - r["test_gaussian_n10"] for r in rs]
        m_d = statistics.mean(diffs)
        s_d = statistics.stdev(diffs) if len(diffs) > 1 else 0.0
        m_g, _, _ = stat(rs, "test_gaussian_n10")
        m_p, _, _ = stat(rs, "test_photon_prob_n10")
        verdict = ("✓ photon-prob wins" if m_d > s_d and m_d > 0.01 else
                   "✗ photon-prob loses" if m_d < -0.01 else
                   "≈ tie within noise")
        print(f" {label:>10}: gaussian={m_g:.3f} photon_prob={m_p:.3f} Δ={m_d:+.3f} ± {s_d:.3f} → {verdict}")

    print("\n=== Q4. Generalization tax (train − test, gaussian metric) ===")
    for label, rs in [("full", full), ("no-squeeze", nosq)]:
        if not rs:
            continue
        gaps = [r["train_gaussian_n10"] - r["test_gaussian_n10"] for r in rs]
        m_g = statistics.mean(gaps)
        s_g = statistics.stdev(gaps) if len(gaps) > 1 else 0.0
        print(f" {label:>10}: gap = {m_g:.3f} ± {s_g:.3f}")

    print("\nFull table:")
    headers = ("seed", "mode", "train_g_n10", "test_g_n10", "train_p_n10", "test_p_n10")
    print(" " + " ".join(f"{h:>11}" for h in headers))
    for r in full + nosq:
        mode = "no_squeeze" if int(r["no_squeeze"]) else "full"
        cells = (
            int(r["seed"]), mode,
            r["train_gaussian_n10"], r["test_gaussian_n10"],
            r["train_photon_prob_n10"], r["test_photon_prob_n10"],
        )
        # Strings and ints are right-aligned as-is; floats get three decimals.
        print(" " + " ".join(
            f"{c:>11}" if isinstance(c, str) else f"{c:>11.3f}" if isinstance(c, float) else f"{c:>11}"
            for c in cells
        ))
109
+
110
+
111
+ if __name__ == "__main__":
112
+ main()
space/run_sweep.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """5-split × 2-mode sweep over the SBERT-backed photon-route trainer.
2
+
3
+ Designed to run on cloud CI (e.g. GitHub Actions ubuntu-latest, free tier),
4
+ NOT locally. Output is a CSV of (split_seed, no_squeeze, train_ndcg10, test_ndcg10, ...)
5
+ that gets uploaded as a workflow artifact for the user to read.
6
+
7
+ For each random split seed:
8
+ 1. Pick `--n-test` of the eval queries (default 4) as held-out test, rest as train.
9
+ 2. Train SBERTPhoton (full vs --no-squeeze).
10
+ 3. Evaluate on train and test.
11
+ 4. Append result row.
12
+
13
+ The point: we want error bars, not a point estimate. If the +30% nDCG@10
14
+ the SBERT-backed run got on one specific 4/2 split is real, it should hold
15
+ across multiple random splits. If it varies wildly, the headline was a
16
+ 2-query coincidence.
17
+ """
18
+ from __future__ import annotations
19
+
20
+ import argparse
21
+ import csv
22
+ import json
23
+ import os
24
+ import random
25
+ import sys
26
+ import tempfile
27
+ from pathlib import Path
28
+
29
+ ROOT = Path(__file__).resolve().parent.parent
30
+ SRC = ROOT / "src"
31
+ if str(SRC) not in sys.path:
32
+ sys.path.insert(0, str(SRC))
33
+
34
+
35
def make_split(rel_payload: dict, n_test: int, seed: int) -> tuple[dict, dict]:
    """Split the payload's queries into ``(train_relevance, test_relevance)``.

    The query list is copied and shuffled with a ``random.Random(seed)``
    generator; the first ``n_test`` shuffled queries form the test payload and
    the remainder the train payload. Every other key of ``rel_payload`` is
    carried over unchanged into both results, and the input is not mutated.
    """
    shuffled = list(rel_payload["queries"])
    random.Random(seed).shuffle(shuffled)

    train_payload = dict(rel_payload)
    train_payload["queries"] = shuffled[n_test:]

    test_payload = dict(rel_payload)
    test_payload["queries"] = shuffled[:n_test]

    return train_payload, test_payload
46
+
47
+
48
def run_one(seed: int, no_squeeze: bool, steps: int, relevance_path: Path,
            n_test: int, log_dir: Path) -> dict:
    """Train + eval one configuration; return summary dict for the CSV row.

    The relevance file is split with ``make_split`` and both halves are written
    to a temporary directory. ``space.train_sbert.train`` is then called
    in-process with its stdout captured. The captured output is always written
    to ``log_dir / "seed{seed}_nosqz{0|1}.log"``, including when training
    raises, so a failed run still leaves its log behind.

    The returned row holds ``seed``, ``no_squeeze`` (as 0/1) and one column per
    split × metric × score (e.g. ``test_gaussian_n10``). Scores absent from the
    parsed ``SUMMARY_JSON=`` line become NaN; if no such line was printed,
    every score column is NaN.
    """
    import contextlib
    import io

    import space.train_sbert as ts

    rel_payload = json.loads(relevance_path.read_text("utf-8"))
    train_rel, test_rel = make_split(rel_payload, n_test=n_test, seed=seed)

    log_path = log_dir / f"seed{seed}_nosqz{int(no_squeeze)}.log"
    buf = io.StringIO()

    with tempfile.TemporaryDirectory() as tmp:
        train_p = Path(tmp) / "rel_train.json"
        test_p = Path(tmp) / "rel_test.json"
        # Write with the same encoding the input was read with instead of the
        # platform default.
        train_p.write_text(json.dumps(train_rel, indent=2), encoding="utf-8")
        test_p.write_text(json.dumps(test_rel, indent=2), encoding="utf-8")

        # Build a Namespace mimicking train_sbert's CLI args so we can call
        # train(args) directly without subprocess. Faster + captures Python
        # exceptions cleanly.
        args = argparse.Namespace(
            steps=steps, lr=1e-2, weight_decay=1e-3, temperature=2.0,
            negatives=8, clip=1.0, seed=seed, log_every=50,
            relevance=str(train_p),
            eval_train_rel=str(train_p), eval_test_rel=str(test_p),
            no_squeeze=no_squeeze,
        )

        # Capture stdout to recover the SUMMARY_JSON line. The log is persisted
        # in `finally` so the output of a crashing run is not thrown away.
        try:
            with contextlib.redirect_stdout(buf):
                ts.train(args)
        finally:
            log_path.write_text(buf.getvalue(), encoding="utf-8")

    out = buf.getvalue()
    summary_line = next(
        (line for line in out.splitlines() if line.startswith("SUMMARY_JSON=")), ""
    )
    summary = json.loads(summary_line.split("=", 1)[1]) if summary_line else {}

    row = {"seed": seed, "no_squeeze": int(no_squeeze)}
    for split in ("train", "test"):
        for metric in ("gaussian", "photon_prob"):
            # Summary entries are looked up under "<split>/<metric>".
            agg = summary.get(f"{split}/{metric}", {})
            for m in ("ndcg@10", "recall@10", "recall@1"):
                # "ndcg@10" -> "n10", "recall@10" -> "r10", "recall@1" -> "r1"
                short = m.replace("@", "").replace("recall", "r").replace("ndcg", "n")
                row[f"{split}_{metric}_{short}"] = agg.get(m, float("nan"))
    return row
101
+
102
+
103
def main():
    """Run the seeds × {full, no-squeeze} sweep and write the results CSV.

    For every seed in ``--seeds`` and for both ``no_squeeze`` settings, calls
    ``run_one`` and collects its row. All rows are then written to
    ``--out-csv`` and the mean ± stdev of test nDCG@10 is printed per mode and
    metric. The log and CSV output directories are created up front, so a
    missing directory cannot abort the sweep after the training runs are done.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--relevance", type=Path, default=ROOT / "eval" / "relevance_expanded.json",
                    help="Default to the title-expanded set so each split has more train signal.")
    ap.add_argument("--seeds", type=int, nargs="+", default=[1, 2, 3, 4, 5])
    ap.add_argument("--n-test", type=int, default=4,
                    help="Held-out test queries per split. With expanded relevance (26 q), 4 test is ~15%.")
    ap.add_argument("--steps", type=int, default=200)
    ap.add_argument("--out-csv", type=Path, default=ROOT / "sweep_results.csv")
    ap.add_argument("--log-dir", type=Path, default=ROOT / "sweep_logs")
    args = ap.parse_args()

    args.log_dir.mkdir(parents=True, exist_ok=True)
    # Fail early (or create the directory) rather than after all runs finish.
    args.out_csv.parent.mkdir(parents=True, exist_ok=True)

    results = []
    for seed in args.seeds:
        for no_squeeze in [False, True]:
            print(f"\n{'='*72}\nseed={seed} no_squeeze={no_squeeze}\n{'='*72}")
            row = run_one(
                seed=seed, no_squeeze=no_squeeze, steps=args.steps,
                relevance_path=args.relevance, n_test=args.n_test, log_dir=args.log_dir,
            )
            print(f" → gaussian: train n10={row['train_gaussian_n10']:.3f} test n10={row['test_gaussian_n10']:.3f}")
            print(f" photon_prob: train n10={row['train_photon_prob_n10']:.3f} test n10={row['test_photon_prob_n10']:.3f}")
            results.append(row)

    # Write CSV
    fieldnames = list(results[0].keys())
    with args.out_csv.open("w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(results)
    print(f"\nwrote {len(results)} rows → {args.out_csv}")

    # Aggregate stats
    import statistics

    def stat(rows, key):
        # (mean, stdev) of one column; stdev is 0.0 for a single row.
        vals = [r[key] for r in rows]
        return statistics.mean(vals), statistics.stdev(vals) if len(vals) > 1 else 0.0

    print("\nAggregates over seeds:")
    for ns in [False, True]:
        rows = [r for r in results if r["no_squeeze"] == int(ns)]
        if not rows:
            continue
        label = "no-squeeze" if ns else "full"
        for metric in ("gaussian", "photon_prob"):
            m, s = stat(rows, f"test_{metric}_n10")
            print(f" {label:>10}/{metric:>11}: test nDCG@10 = {m:.3f} ± {s:.3f} (n={len(rows)})")
152
+
153
+
154
+ if __name__ == "__main__":
155
+ main()
space/run_sweep_fock.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """5-split sweep for the A3-Real Fock-basis trainer.
2
+
3
+ Mirrors space/run_sweep.py but for train_sbert_fock.py (non-Gaussian
4
+ heralded encoder). Outputs one CSV row per (seed, herald_n) pair — no squeeze ablation
5
+ since the Fock encoder structure already includes a learnable TMS gate
6
+ and learnable squeezing; the equivalent ablation is herald_n=0 (heralding
7
+ on vacuum keeps the state Gaussian).
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ import csv
13
+ import io
14
+ import contextlib
15
+ import json
16
+ import random
17
+ import sys
18
+ import tempfile
19
+ from pathlib import Path
20
+
21
+ ROOT = Path(__file__).resolve().parent.parent
22
+ SRC = ROOT / "src"
23
+ if str(SRC) not in sys.path:
24
+ sys.path.insert(0, str(SRC))
25
+
26
+
27
def make_split(rel_payload, n_test, seed):
    """Deterministically split a relevance payload into train and test parts.

    The queries are copied, shuffled with a ``random.Random`` seeded by
    *seed*, and the first *n_test* shuffled queries become the test set while
    the remainder becomes the training set. All other keys of *rel_payload*
    are carried over unchanged into both returned payloads; the input payload
    itself is not modified.

    Returns:
        ``(train_payload, test_payload)``.
    """
    shuffled = list(rel_payload["queries"])
    random.Random(seed).shuffle(shuffled)
    held_out = shuffled[:n_test]
    remaining = shuffled[n_test:]
    train_payload = dict(rel_payload, queries=remaining)
    test_payload = dict(rel_payload, queries=held_out)
    return train_payload, test_payload
35
+
36
+
37
def run_one(seed, herald_n, steps, cutoff, relevance_path, n_test, log_dir):
    """Train and evaluate the Fock-basis encoder once; return one CSV row.

    The relevance file is split into train/test with ``make_split``, both
    halves are written to a temporary directory, and
    ``space.train_sbert_fock.train`` is invoked with its stdout captured. The
    captured output is always written to
    ``log_dir / f"fock_seed{seed}_n{herald_n}.log"`` — including when the
    trainer raises or exits — so a failed run still leaves a log behind.

    Metrics are read from the first captured line that starts with
    ``SUMMARY_JSON=``; any metric missing from that summary (or the whole
    summary being absent) is reported as NaN.

    Args:
        seed: Seed for both the split and the trainer.
        herald_n: Ancilla outcome passed through to the trainer.
        steps: Number of training steps.
        cutoff: Cutoff value passed through to the trainer.
        relevance_path: ``Path`` to the relevance JSON (must have "queries").
        n_test: Number of held-out test queries.
        log_dir: Existing directory that receives the per-run log file.

    Returns:
        Dict with seed, herald_n, cutoff and train/test nDCG@10, recall@10
        and recall@1.
    """
    import space.train_sbert_fock as ts

    rel_payload = json.loads(relevance_path.read_text("utf-8"))
    train_rel, test_rel = make_split(rel_payload, n_test=n_test, seed=seed)

    buf = io.StringIO()
    log_path = log_dir / f"fock_seed{seed}_n{herald_n}.log"
    with tempfile.TemporaryDirectory() as tmp:
        train_p = Path(tmp) / "rel_train.json"
        test_p = Path(tmp) / "rel_test.json"
        train_p.write_text(json.dumps(train_rel, indent=2), encoding="utf-8")
        test_p.write_text(json.dumps(test_rel, indent=2), encoding="utf-8")
        args = argparse.Namespace(
            cutoff=cutoff, herald_n=herald_n,
            steps=steps, lr=1e-2, weight_decay=1e-3, temperature=0.5,
            negatives=8, clip=1.0, seed=seed, log_every=50,
            relevance=str(train_p),
            eval_train_rel=str(train_p), eval_test_rel=str(test_p),
        )
        try:
            with contextlib.redirect_stdout(buf):
                ts.train(args)
        finally:
            # Persist whatever was printed even if training failed; otherwise
            # the captured output (and the reason for the failure) is lost.
            log_path.write_text(buf.getvalue(), encoding="utf-8")

    out = buf.getvalue()
    summary_line = next(
        (line for line in out.splitlines() if line.startswith("SUMMARY_JSON=")), ""
    )
    summary = json.loads(summary_line.split("=", 1)[1]) if summary_line else {}
    train_agg = summary.get("train/fock", {})
    test_agg = summary.get("test/fock", {})
    nan = float("nan")
    return {
        "seed": seed,
        "herald_n": herald_n,
        "cutoff": cutoff,
        "train_ndcg10": train_agg.get("ndcg@10", nan),
        "test_ndcg10": test_agg.get("ndcg@10", nan),
        "train_recall10": train_agg.get("recall@10", nan),
        "test_recall10": test_agg.get("recall@10", nan),
        "train_recall1": train_agg.get("recall@1", nan),
        "test_recall1": test_agg.get("recall@1", nan),
    }
78
+
79
+
80
def main():
    """CLI entry point: sweep every (seed, herald_n) pair and summarise.

    Runs ``run_one`` for each combination, writes all rows to ``--out-csv``,
    and prints the mean ± sample standard deviation of test nDCG@10 per
    ``herald_n``. Runs whose metric is NaN (no summary was produced) are
    excluded from the aggregate; if every run of a group is NaN the aggregate
    is reported as NaN with n=0 instead of crashing.
    """
    import math
    import statistics

    ap = argparse.ArgumentParser()
    ap.add_argument("--relevance", type=Path, default=ROOT / "eval" / "relevance_expanded.json")
    ap.add_argument("--seeds", type=int, nargs="+", default=[1, 2, 3, 4, 5])
    ap.add_argument("--n-test", type=int, default=4)
    ap.add_argument("--steps", type=int, default=200)
    ap.add_argument("--cutoff", type=int, default=6)
    ap.add_argument("--herald-ns", type=int, nargs="+", default=[1, 0],
                    help="Ancilla outcomes to test. herald_n=0 keeps state Gaussian; herald_n=1 makes it non-Gaussian.")
    ap.add_argument("--out-csv", type=Path, default=ROOT / "sweep_fock_results.csv")
    ap.add_argument("--log-dir", type=Path, default=ROOT / "sweep_fock_logs")
    args = ap.parse_args()
    args.log_dir.mkdir(parents=True, exist_ok=True)

    results = []
    for seed in args.seeds:
        for hn in args.herald_ns:
            print(f"\n{'='*72}\nseed={seed} herald_n={hn}\n{'='*72}")
            row = run_one(seed=seed, herald_n=hn, steps=args.steps,
                          cutoff=args.cutoff, relevance_path=args.relevance,
                          n_test=args.n_test, log_dir=args.log_dir)
            print(f" → train n10={row['train_ndcg10']:.3f} test n10={row['test_ndcg10']:.3f}")
            results.append(row)

    fieldnames = list(results[0].keys())
    with args.out_csv.open("w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(results)
    print(f"\nwrote {len(results)} rows → {args.out_csv}")

    def stat(rs, key):
        """Return (mean, stdev, count) over the non-NaN values of *key*."""
        vals = [r[key] for r in rs if not math.isnan(r[key])]
        if not vals:
            # statistics.mean raises on an empty sequence; report NaN instead.
            return float("nan"), float("nan"), 0
        spread = statistics.stdev(vals) if len(vals) > 1 else 0.0
        return statistics.mean(vals), spread, len(vals)

    print("\nAggregates:")
    for hn in args.herald_ns:
        rs = [r for r in results if r["herald_n"] == hn]
        m, s, n = stat(rs, "test_ndcg10")
        label = "non-Gaussian (herald=1)" if hn == 1 else "Gaussian (herald=0)" if hn == 0 else f"herald={hn}"
        print(f" {label:>26}: test nDCG@10 = {m:.3f} ± {s:.3f} (n={n})")
122
+
123
+
124
+ if __name__ == "__main__":
125
+ main()
space/sim_b1_g1_coherence.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """B1 sim — does g^(1)-style coherence time τ_c discriminate candidates
2
+ better than meridian's current 3-bin Shannon entropy `cross_domain`?
3
+
4
+ Loudon eq 3.1.3: g^(1)(τ) = ⟨E*(t) E(t+τ)⟩ / ⟨|E|²⟩.
5
+ For a chaotic source, |g^(1)(τ)| decays exponentially with characteristic
6
+ τ_c = (∫|g^(1)(τ)|² dτ).
7
+
8
+ Treat each candidate's keyword stream as a chaotic light source where
9
+ each token at position t is a "wavetrain at frequency ω_token". The
10
+ autocorrelation of one-hot token vectors gives an effective τ_c that
11
+ scales with vocabulary diversity.
12
+
13
+ This is a self-contained synthetic-data sim. No external corpus / no
14
+ cloud compute required. Runs in <1 s.
15
+ """
16
+ from __future__ import annotations
17
+
18
+ import math
19
+ import numpy as np
20
+ from collections import Counter
21
+
22
+
23
def cross_domain_proxy(tokens: list[str], systems: dict[str, set[str]]) -> float:
    """Return the normalised Shannon entropy of term-list hits across systems.

    Each token is counted once for every system whose term set contains it.
    The entropy of the resulting hit distribution is divided by
    ``log(len(systems))`` so the result lies in [0, 1] for any number of
    systems; with the three-system table used in this module that is the
    original ``log(3)`` normalisation.

    Args:
        tokens: Token stream of one candidate.
        systems: Mapping of system name to its set of terms.

    Returns:
        0.0 when there are no hits, when all hits fall in a single system, or
        when fewer than two systems are given; otherwise entropy / log(k).
    """
    n_systems = len(systems)
    if n_systems < 2:
        # Entropy over at most one class is zero, and log(1) == 0 could not
        # be used as a normaliser anyway.
        return 0.0
    hits = [sum(1 for tok in tokens if tok in terms) for terms in systems.values()]
    total = sum(hits) or 1
    probs = [count / total for count in hits if count > 0]
    entropy = -sum(p * math.log(p) for p in probs)
    return entropy / math.log(n_systems) if entropy else 0.0
35
+
36
+
37
def coherence_time(tokens: list[str], window: int = 8) -> float:
    """Empirical g^(1)-style coherence time of a token stream.

    For every lag τ from 1 up to (but not including) ``min(window, len(tokens))``
    the number of positions i with ``tokens[i] == tokens[i + τ]`` is divided by
    the stream length, and the squares of those ratios are summed.

    Args:
        tokens: Token stream.
        window: Exclusive upper bound on the lags considered.

    Returns:
        The summed squared autocorrelation; 0.0 for streams shorter than two
        tokens.
    """
    length = len(tokens)
    if length < 2:
        return 0.0
    total = 0.0
    for lag in range(1, min(window, length)):
        # Pair each token with the one `lag` positions later.
        hits = sum(1 for a, b in zip(tokens, tokens[lag:]) if a == b)
        ratio = hits / length
        total += ratio * ratio
    return total
56
+
57
+
58
+ # ─── Synthetic candidates with realistic body lengths ──────────────────────
59
+ SYSTEMS = {
60
+ "forge": {"build", "compile", "deploy", "ci", "container", "image", "binary",
61
+ "docker", "kubernetes", "package", "release", "monorepo"},
62
+ "signal": {"data", "stream", "ingest", "pipeline", "etl", "kafka", "queue",
63
+ "throughput", "latency", "broker", "subscriber", "publish"},
64
+ "mind": {"llm", "embed", "embedding", "model", "transformer", "agent",
65
+ "reasoning", "prompt", "context", "rag", "fine", "tune"},
66
+ }
67
+
68
+
69
def make_candidate(label: str, vocab_pool: list[str], length: int = 250,
                   alpha: float = 1.0, seed: int = 0) -> list[str]:
    """Sample ``length`` tokens from ``vocab_pool`` with Zipf-like weights.

    The i-th vocabulary entry (1-based) gets weight ``1 / i**alpha``,
    normalised to sum to one; larger ``alpha`` concentrates mass on the head
    of the pool. Sampling uses a NumPy generator seeded with ``seed``, so the
    output is reproducible. ``label`` is accepted for call-site readability
    and is not used.
    """
    ranks = np.arange(1, len(vocab_pool) + 1)
    raw = 1.0 / (ranks ** alpha)
    probabilities = raw / raw.sum()
    generator = np.random.default_rng(seed)
    draws = generator.choice(vocab_pool, size=length, p=probabilities)
    return list(draws)
77
+
78
+
79
def main():
    """Print cross_domain vs. τ_c for nine synthetic archetypes and a verdict.

    Each archetype is generated with ``make_candidate``, scored with
    ``cross_domain_proxy`` and ``coherence_time``, and printed as a table row.
    The two metrics are then compared by coefficient of variation across the
    archetypes: τ_c "stands" if its CV exceeds 1.2× that of cross_domain,
    "fails" if it is below 0.8×, and is reported as a wash otherwise.
    """
    # 9 archetypes spanning body lengths and topical vs scattered patterns.
    terms = {name: sorted(SYSTEMS[name]) for name in ("forge", "signal", "mind")}
    forge_terms, signal_terms, mind_terms = terms["forge"], terms["signal"], terms["mind"]
    cross_terms = forge_terms + signal_terms + mind_terms
    cases = [
        ("focused-forge", forge_terms, 300, 1.0, 1),
        ("focused-signal", signal_terms, 300, 1.0, 2),
        ("focused-mind", mind_terms, 300, 1.0, 3),
        ("cross-forge-signal", forge_terms + signal_terms, 300, 1.0, 4),
        ("cross-mind-signal", mind_terms + signal_terms, 300, 1.0, 5),
        ("cross-three-systems", cross_terms, 300, 1.0, 6),
        ("scattered-cross", cross_terms, 300, 0.5, 7),  # less Zipfian, more uniform
        ("very-narrow", forge_terms[:3], 300, 2.0, 8),  # 3 dominant words
        ("very-broad", cross_terms, 300, 0.3, 9),  # near-uniform
    ]

    print(f"{'archetype':>22} {'len':>5} {'cross_domain':>13} {'τ_c (g^(1))':>12}")
    print("-" * 64)
    cd_vals = []
    tc_vals = []
    for label, pool, length, alpha, seed in cases:
        stream = make_candidate(label, pool, length=length, alpha=alpha, seed=seed)
        cd = cross_domain_proxy(stream, SYSTEMS)
        tc = coherence_time(stream)
        print(f"{label:>22} {length:>5} {cd:>13.3f} {tc:>12.3f}")
        cd_vals.append(cd)
        tc_vals.append(tc)

    print("\nDiscrimination check (variance across archetypes, higher = better signal):")
    print(f" cross_domain: std = {np.std(cd_vals):.3f}, range = [{min(cd_vals):.3f}, {max(cd_vals):.3f}]")
    print(f" τ_c (g^(1)): std = {np.std(tc_vals):.3f}, range = [{min(tc_vals):.3f}, {max(tc_vals):.3f}]")

    # CV (coefficient of variation) — higher = more discriminative on its own scale
    cd_cv = np.std(cd_vals) / max(np.mean(cd_vals), 1e-9)
    tc_cv = np.std(tc_vals) / max(np.mean(tc_vals), 1e-9)
    print(f"\n CV (std/mean): cross_domain={cd_cv:.3f} τ_c={tc_cv:.3f}")
    if tc_cv > cd_cv * 1.2:
        verdict = " → τ_c is more discriminative than cross_domain — B1 stands."
    elif tc_cv < cd_cv * 0.8:
        verdict = " → τ_c is LESS discriminative than cross_domain — B1 fails."
    else:
        verdict = " → τ_c and cross_domain have similar discrimination — B1 is a wash."
    print(verdict)
123
+
124
+
125
+ if __name__ == "__main__":
126
+ main()
space/sim_b2_g2_classifier.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """B2 sim — does g^(2)(0) cleanly classify candidates into planet / comet
2
+ / asteroid on REAL meridian-shaped data (Llama-emitted bodies of typical
3
+ length 100–500 tokens), and where does it disagree with the existing
4
+ mass × scope × independence rule?
5
+
6
+ Sim 4b earlier (synthetic Zipfian) showed g^(2) > 1 only emerges at
7
+ N_distinct × token_total scales typical of real bodies, not the 8–12
8
+ token toy archetypes from Sim 4. This sim re-runs that check on
9
+ realistic body shapes and compares per-candidate the g^(2) class label
10
+ to the mass × scope × independence label that orbital.mjs assigns today.
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import math
15
+ import numpy as np
16
+ from collections import Counter
17
+
18
+
19
def g2_zero(tokens: list[str]) -> float:
    """Return g^(2)(0) = ⟨n(n-1)⟩ / ⟨n⟩² over the per-word counts {n_i}.

    The averages run over the distinct words of ``tokens``. An empty token
    list yields NaN.
    """
    if not tokens:
        return float("nan")
    freq = np.fromiter(Counter(tokens).values(), dtype=np.float64)
    mean_count = freq.mean()
    if not mean_count > 0:
        return float("nan")
    factorial_moment = (freq * (freq - 1)).mean()
    return factorial_moment / (mean_count ** 2)
28
+
29
+
30
def class_from_g2(g2: float) -> str:
    """Map a g^(2)(0) value to a class label by fixed thresholds.

    Below 0.7 → "asteroid"; from 0.7 up to (not including) 1.4 → "planet";
    everything else → "comet". A NaN input fails both comparisons and so
    also maps to "comet".
    """
    if g2 < 0.7:
        # antibunched / sparse / niche
        return "asteroid"
    if g2 < 1.4:
        # ≈ 1 = coherent / focused
        return "planet"
    # > 1 = thermal / scattered
    return "comet"
38
+
39
+
40
def class_from_meridian(mass: float, scope: float, indep: float,
                        cross_domain: float, drag: float, fragmentation: float,
                        dep_ratio: float, has_parent: bool) -> str:
    """Score six body classes from the physics scalars and return the best.

    Per the original note this mirrors orbital.mjs:139-167. On ties the class
    listed first (planet, moon, trojan, asteroid, comet, irregular) wins.
    """
    parent_full = 1.0 if has_parent else 0.4
    parent_half = 1.0 if has_parent else 0.5
    scores = {
        "planet": min(mass, scope, indep) ** 1.5,
        "moon": max(0, 0.5 - indep) * 2 * parent_full * (1 - 0.5 * mass),
        "trojan": dep_ratio * parent_half * (1 - fragmentation),
        "asteroid": max(0, 0.55 - mass) * 2.5 * scope * indep,
        "comet": drag * cross_domain * (1 - dep_ratio),
        "irregular": cross_domain * fragmentation * 0.85,
    }
    best = max(scores, key=scores.__getitem__)
    return best
54
+
55
+
56
def physics_from_tokens(tokens: list[str]) -> dict:
    """Approximate the physics scalars meridian computes from text.

    ``mass`` combines a log-scaled total character count with the token
    count and is clamped to [0, 1]; ``scope`` grows with the number of
    distinct tokens and is clamped to [0, 1]. The remaining scalars are fixed
    placeholder values, and ``has_parent`` is always False.
    """
    body_len = sum(len(tok) for tok in tokens)
    n_words = len(tokens)
    length_term = 0.6 * np.log10(max(50, body_len) / 200) / np.log10(3000 / 200)
    count_term = 0.4 * (n_words - 3) / 9
    mass = max(0, min(1, length_term + count_term))

    distinct = len(set(tokens))
    scope = min(0.7, distinct / 12) + 0.2  # rough proxy
    scope = max(0, min(1, scope))

    return {
        "mass": mass,
        "scope": scope,
        "indep": 0.7,  # synthetic candidates have no siblings; assume mid-high
        "drag": 0.3,
        "fragmentation": 0.4,
        "cross_domain": 0.5,  # placeholder
        "dep_ratio": 0.2,
        "has_parent": False,
    }
74
+
75
+
76
+ # ─── Realistic synthetic candidates ─────────────────────────────────────────
77
def zipfian_words(prefix: str, n_distinct: int, length: int, alpha: float, seed: int):
    """Sample ``length`` synthetic words with Zipf-like weights.

    The vocabulary is ``f"{prefix}-{i:02d}"`` for i in ``range(n_distinct)``;
    the word of rank i (1-based) has weight ``1 / i**alpha``, normalised to
    sum to one. Sampling uses a NumPy generator seeded with ``seed``.
    """
    generator = np.random.default_rng(seed)
    vocab = [f"{prefix}-{i:02d}" for i in range(n_distinct)]
    raw = 1.0 / (np.arange(1, n_distinct + 1) ** alpha)
    probabilities = raw / raw.sum()
    sample = generator.choice(vocab, size=length, p=probabilities)
    return list(sample)
83
+
84
+
85
def main():
    """Compare the g^(2) classifier with the meridian rule on nine archetypes.

    Each archetype is a synthetic Zipfian token stream. Its g^(2)(0) class
    and its mass × scope × independence class are compared to the expected
    label, a table is printed, and the classifier that recovers more
    archetypes is reported.

    The per-archetype RNG seed is derived from a CRC-32 of the label rather
    than the built-in ``hash``: string hashes are salted per process, which
    would make the table — and possibly the verdict — differ between runs.
    """
    import zlib

    # 9 archetypes covering the planet/comet/asteroid spectrum at realistic
    # body length (200–400 tokens) — the regime Sim 4b proved relevant.
    cases = [
        # label, n_distinct, length, alpha (Zipf head), expected
        ("planet-tight-vocab", 20, 300, 1.0, "planet"),  # coherent-shaped
        ("planet-medium", 15, 250, 0.8, "planet"),
        ("planet-broad-vocab", 50, 400, 1.2, "planet"),
        ("comet-thermal", 30, 300, 1.5, "comet"),  # heavier head
        ("comet-very-heavy", 25, 300, 2.0, "comet"),
        ("comet-multimodal", 40, 350, 1.8, "comet"),
        ("asteroid-narrow", 5, 300, 1.0, "asteroid"),  # too few distinct
        ("asteroid-fragments", 10, 100, 0.5, "asteroid"),  # short body
        ("asteroid-uniform", 50, 300, 0.3, "asteroid"),  # near-uniform
    ]

    print(f"{'archetype':>22} {'len':>5} {'g^(2)':>7} {'g2_class':>10} "
          f"{'mass×s×i_class':>16} {'expected':>10}")
    print("-" * 90)
    correct_g2 = 0
    correct_meridian = 0
    for label, n_distinct, length, alpha, expected in cases:
        # Stable across processes, unlike hash(label).
        seed = zlib.crc32(label.encode("utf-8")) & 0xFFFF
        toks = zipfian_words(label, n_distinct, length, alpha, seed=seed)
        g2 = g2_zero(toks)
        cls_g2 = class_from_g2(g2)
        phys = physics_from_tokens(toks)
        cls_m = class_from_meridian(**phys)
        ok_g2 = cls_g2 == expected
        ok_m = cls_m == expected
        correct_g2 += int(ok_g2)
        correct_meridian += int(ok_m)
        marker_g2 = "✓" if ok_g2 else "✗"
        marker_m = "✓" if ok_m else "✗"
        print(f"{label:>22} {length:>5} {g2:>7.3f} "
              f"{cls_g2:>9}{marker_g2} {cls_m:>15}{marker_m} {expected:>10}")

    print(f"\n g^(2)-only classifier: {correct_g2}/{len(cases)} archetypes correct")
    print(f" meridian's mass×scope×indep: {correct_meridian}/{len(cases)} archetypes correct")
    if correct_g2 > correct_meridian:
        print(" → B2 stands: g^(2) classifies more reliably on real-shape data")
    elif correct_g2 < correct_meridian:
        print(" → B2 fails: meridian's existing rule is better")
    else:
        print(" → B2 is a wash: both classifiers tied on archetype recovery")
129
+
130
+
131
+ if __name__ == "__main__":
132
+ main()
space/train.py CHANGED
@@ -245,7 +245,7 @@ def train(args: argparse.Namespace) -> None:
245
  torch.manual_seed(args.seed)
246
  np.random.seed(args.seed)
247
 
248
- rel_path = ROOT / "eval" / "relevance.json"
249
  cids_path = ROOT / "eval" / "corpus_ids.json"
250
  man_path = ROOT / "eval" / "manifest.json"
251
 
@@ -380,6 +380,8 @@ def train(args: argparse.Namespace) -> None:
380
  def main() -> None:
381
  ap = argparse.ArgumentParser()
382
  ap.add_argument("--out", type=Path, default=ROOT / "weights.npz")
 
 
383
  ap.add_argument("--steps", type=int, default=100)
384
  ap.add_argument("--lr", type=float, default=5e-3)
385
  # D-scale logits: with D in [0, 50], temp=0.1 made -D/temp logits up to
 
245
  torch.manual_seed(args.seed)
246
  np.random.seed(args.seed)
247
 
248
+ rel_path = Path(args.relevance) if args.relevance else ROOT / "eval" / "relevance.json"
249
  cids_path = ROOT / "eval" / "corpus_ids.json"
250
  man_path = ROOT / "eval" / "manifest.json"
251
 
 
380
  def main() -> None:
381
  ap = argparse.ArgumentParser()
382
  ap.add_argument("--out", type=Path, default=ROOT / "weights.npz")
383
+ ap.add_argument("--relevance", type=str, default=None,
384
+ help="path to alternate relevance.json (e.g. for held-out splits)")
385
  ap.add_argument("--steps", type=int, default=100)
386
  ap.add_argument("--lr", type=float, default=5e-3)
387
  # D-scale logits: with D in [0, 50], temp=0.1 made -D/temp logits up to
space/train_sbert.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SBERT-backed photon-route encoder. The language model does language;
2
+ the photonic gates do the structured projection.
3
+
4
+ Architecture:
5
+ text → frozen SentenceTransformer (all-MiniLM-L6-v2, 384-d, no grad)
6
+ → Linear(384 → 4N) [trainable]
7
+ → 2N displacement outputs (αq, αp per mode) + 2N squeezing outputs (r, φ per mode)
8
+ → photonic gates (Sgate + Dgate per mode)
9
+ → 2N-d Gaussian state (μ, σ) at hbar=2
10
+
11
+ Trainable surface:
12
+ Linear(384 → 4N) = 384·4N + 4N params (4 numbers per mode: αq, αp, r, φ_s)
13
+ For N=2: 384·8 + 8 = 3,080 params total
14
+ vs word-level photon-route: |V|·4 = 5,772 (and grows with vocab)
15
+
16
+ Loss is the same InfoNCE-on-Bhattacharyya as space/train.py so the comparison
17
+ is apples-to-apples on the encoder, not the loss.
18
+
19
+ Holdout discipline: load --relevance from a file. The eval driver
20
+ (eval/run.py) does NOT support sbert weights yet; this module ships its
21
+ own evaluator alongside the trainer for now.
22
+ """
23
+ from __future__ import annotations
24
+
25
+ import argparse
26
+ import json
27
+ import math
28
+ import sys
29
+ import time
30
+ from pathlib import Path
31
+
32
+ import numpy as np
33
+ import torch
34
+ import torch.nn as nn
35
+ import torch.nn.functional as F
36
+ from torch import Tensor
37
+
38
+ ROOT = Path(__file__).resolve().parent.parent
39
+ SRC = ROOT / "src"
40
+ if str(SRC) not in sys.path:
41
+ sys.path.insert(0, str(SRC))
42
+
43
+ from eval.fetch import fetch_all, verify_against_manifest # noqa: E402
44
+
45
+ N_MODES = 2
46
+ HBAR = 2.0
47
+ SBERT_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
48
+ SBERT_DIM = 384
49
+
50
+
51
+ # ─── differentiable single-mode squeezing in qqpp ─────────────────────────────
52
+ def _eye2N(n: int, ref: Tensor) -> Tensor:
53
+ return torch.eye(2 * n, dtype=ref.dtype, device=ref.device)
54
+
55
+
56
def squeezing_qqpp(n: int, k: int, r: Tensor, phi: Tensor) -> Tensor:
    """Build the (2n × 2n) matrix that squeezes mode *k* and leaves others alone.

    Starting from the identity, the four entries coupling row/column ``k``
    and ``k + n`` are overwritten with the squeezing block for magnitude *r*
    and angle *phi*. The result stays differentiable in *r* and *phi*.
    """
    cosh_r = torch.cosh(r)
    sinh_r = torch.sinh(r)
    cos_phi = torch.cos(phi)
    sin_phi = torch.sin(phi)
    q_idx, p_idx = k, k + n

    S = _eye2N(n, r).clone()
    off_diagonal = -sinh_r * sin_phi
    S[q_idx, q_idx] = cosh_r - sinh_r * cos_phi
    S[q_idx, p_idx] = off_diagonal
    S[p_idx, q_idx] = off_diagonal
    S[p_idx, p_idx] = cosh_r + sinh_r * cos_phi
    return S
65
+
66
+
67
class SBERTPhoton(nn.Module):
    """Frozen SBERT → Linear → Gaussian photonic state.

    The trainable ``Linear`` emits 4 numbers per mode: two displacement
    logits (αq, αp) and two squeezing logits (raw r, raw φ_s). They are
    bounded and applied as a per-mode squeeze followed by a displacement,
    giving a mean vector μ of length 2N and a covariance σ of shape 2N × 2N.

    Args:
        n_modes: Number of modes N.
        max_squeeze: Upper bound on the squeezing magnitude r.
        max_displace: Scale applied to the tanh-bounded displacement; the
            default 1.0 leaves the displacement bounded by ``sqrt(2·HBAR)``.
        no_squeeze: If True, skip squeezing entirely (ablation).
    """

    def __init__(self, n_modes: int = N_MODES, max_squeeze: float = 0.5,
                 max_displace: float = 1.0, no_squeeze: bool = False):
        super().__init__()
        from sentence_transformers import SentenceTransformer
        self.n = n_modes
        self.max_sq = max_squeeze
        self.max_disp = max_displace
        self.no_squeeze = no_squeeze
        self.dgate_prefactor = math.sqrt(2.0 * HBAR)
        self.sbert = SentenceTransformer(SBERT_MODEL_NAME)
        for p in self.sbert.parameters():
            p.requires_grad = False
        # float32 throughout — MPS (Apple-Silicon GPU) doesn't support float64;
        # cast to float64 at the eval-fidelity boundary (numpy + thewalrus).
        # Squeezing magnitudes are bounded ≤ 0.5 so the covariance stays
        # well-conditioned and float32 slogdet is numerically safe.
        self.proj = nn.Linear(SBERT_DIM, 4 * n_modes, dtype=torch.float32)
        # Small-random init (NOT zeros). Zero init puts every text at the
        # same vacuum state, so all pairwise distances equal zero, gradients
        # vanish, and loss stays at log(N+1) = saddle point forever.
        nn.init.normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)

    def encode_features(self, texts: list[str]) -> Tensor:
        """Run frozen SBERT, return (B, 384) float32 on CPU."""
        with torch.no_grad():
            emb = self.sbert.encode(
                texts, normalize_embeddings=True, convert_to_numpy=False,
                show_progress_bar=False,
            )
        # encode() may hand back a list of per-text tensors; stack them.
        if isinstance(emb, list):
            emb = torch.stack(emb)
        return emb.to(torch.float32).cpu()

    def state_from_features(self, feat: Tensor) -> tuple[Tensor, Tensor]:
        """Forward from a *precomputed* SBERT feature vector — used during
        training when frozen-SBERT features are cached at start to avoid
        re-running the transformer every step.

        ``feat`` is a single feature vector; the projection output is
        reshaped to (N, 4), so batched input is not supported here.
        """
        return self._gates_from_logits(self.proj(feat))

    def state_from_text(self, text: str) -> tuple[Tensor, Tensor]:
        """Encode one text with SBERT and return its state (μ, σ)."""
        feat = self.encode_features([text])[0]  # (384,)
        return self.state_from_features(feat)

    def _gates_from_logits(self, out: Tensor) -> tuple[Tensor, Tensor]:
        """Turn the 4N projection outputs into a Gaussian state (μ, σ)."""
        # Decompose: per-mode (αq, αp, raw_r, raw_phi).
        per_mode = out.view(self.n, 4)
        # tanh-bound the displacement; max_disp scales that bound (1.0 = no-op).
        disp_scale = self.dgate_prefactor * self.max_disp
        alpha_q = disp_scale * torch.tanh(per_mode[:, 0])
        alpha_p = disp_scale * torch.tanh(per_mode[:, 1])
        if self.no_squeeze:
            r = torch.zeros(self.n, dtype=out.dtype, device=out.device)
            phi_s = torch.zeros(self.n, dtype=out.dtype, device=out.device)
        else:
            # sigmoid-bound squeezing magnitude to (0, max_sq) and the
            # squeezing angle to (0, 2π).
            r = self.max_sq * torch.sigmoid(per_mode[:, 2])
            phi_s = (2 * math.pi) * torch.sigmoid(per_mode[:, 3])

        # Start from vacuum on the same device as the logits.
        mu = torch.zeros(2 * self.n, dtype=out.dtype, device=out.device)
        sigma = _eye2N(self.n, out)
        for k in range(self.n):
            if not self.no_squeeze:
                S = squeezing_qqpp(self.n, k, r[k], phi_s[k])
                mu = S @ mu
                sigma = S @ sigma @ S.T
            # Displace mode k: q-quadrature at index k, p-quadrature at k + N.
            shift = torch.zeros_like(mu)
            shift[k] = alpha_q[k]
            shift[k + self.n] = alpha_p[k]
            mu = mu + shift
        return mu, sigma
145
+
146
+
147
def bhattacharyya_distance(mu_a, sg_a, mu_b, sg_b, ridge: float = 1e-3) -> Tensor:
    """Differentiable Bhattacharyya distance between two Gaussians.

    Both covariances get ``ridge`` added on the diagonal before use. With
    V the mean of the two regularised covariances, the result is
    ``δᵀ V⁻¹ δ / 8 + ½·(log|V| − ½·(log|A| + log|B|))``, clamped to [0, 50].
    """
    dim = sg_a.shape[0]
    jitter = ridge * torch.eye(dim, dtype=sg_a.dtype, device=sg_a.device)
    cov_a = sg_a + jitter
    cov_b = sg_b + jitter
    cov_mid = 0.5 * (cov_a + cov_b)

    diff = mu_a - mu_b
    mahalanobis = (diff * torch.linalg.solve(cov_mid, diff)).sum()

    logdet_mid = torch.linalg.slogdet(cov_mid)[1]
    logdet_a = torch.linalg.slogdet(cov_a)[1]
    logdet_b = torch.linalg.slogdet(cov_b)[1]

    distance = 0.125 * mahalanobis + 0.5 * (logdet_mid - 0.5 * (logdet_a + logdet_b))
    return torch.clamp(distance, min=0.0, max=50.0)
160
+
161
+
162
def gaussian_fidelity_eval(mu_a: np.ndarray, sg_a: np.ndarray,
                           mu_b: np.ndarray, sg_b: np.ndarray) -> float:
    """Eval-time fidelity between two Gaussian states (not used in the loss).

    Delegates to ``thewalrus.quantum.fidelity`` at the module's ``HBAR``,
    keeps the real part of the result, and clamps it to [0, 1].
    """
    from thewalrus.quantum import fidelity as tw_fidelity

    raw = tw_fidelity(mu_a, sg_a, mu_b, sg_b, hbar=HBAR)
    if hasattr(raw, "real"):
        raw = raw.real
    value = float(raw)
    capped = min(1.0, value)
    return max(0.0, capped)
169
+
170
+
171
+ # ── photon-number-distribution Bhattacharyya coefficient (A3-Simple) ──────
172
+ # Loudon Ch 6.10 — direct detection projects a state onto Fock basis and
173
+ # measures the photon-number distribution P(n_0, n_1, ...). A retrieval
174
+ # metric grounded in what the detector actually sees, rather than the
175
+ # Gaussian-state inner product. Closed-form computable from (μ, σ) for
176
+ # Gaussian states via thewalrus.quantum.probabilities; non-differentiable
177
+ # (numpy under the hood), eval-only — A3-Real would do this differentiably.
178
+ _PHOTON_PROB_CACHE: dict[int, np.ndarray] = {}
179
+
180
+
181
def photon_prob_eval(mu_a: np.ndarray, sg_a: np.ndarray,
                     mu_b: np.ndarray, sg_b: np.ndarray, cutoff: int = 4) -> float:
    """Bhattacharyya coefficient between two photon-number distributions:
    BC(P, Q) = Σ √(p_i q_i). ∈ [0, 1]; 1 = identical distributions.

    Each distribution comes from ``thewalrus.quantum.probabilities`` truncated
    at ``cutoff``. Negative entries are clipped to zero and the clipped array
    is renormalised by its own sum, so each distribution sums to one before
    comparison.

    Nothing is cached here: both distributions are recomputed on every call,
    so a caller scoring many pairs should deduplicate per state itself.
    """
    from thewalrus.quantum import probabilities

    def _truncated_distribution(mu, sg):
        raw = np.real(np.asarray(probabilities(mu, sg, cutoff=cutoff, hbar=HBAR)))
        clipped = np.clip(raw.astype(np.float64), 0.0, None)
        # Truncation can leave a tail — renormalise by the *clipped* mass so
        # the result is a proper distribution even if entries were negative.
        return clipped / max(clipped.sum(), 1e-12)

    P_a = _truncated_distribution(mu_a, sg_a)
    P_b = _truncated_distribution(mu_b, sg_b)
    bc = float(np.sum(np.sqrt(P_a) * np.sqrt(P_b)))
    return max(0.0, min(1.0, bc))
198
+
199
+
200
def recall_at_k(ranked_ids, relevant, k):
    """Fraction of the *relevant* set found in the top-*k* of *ranked_ids*.

    Returns NaN when *relevant* is empty. *relevant* must be a set.
    """
    if relevant:
        top_k = set(ranked_ids[:k])
        found = top_k & relevant
        return len(found) / len(relevant)
    return float("nan")
204
+
205
+
206
def ndcg_at_k(ranked_ids, relevant, k):
    """Binary-relevance nDCG over the top-*k* of *ranked_ids*.

    A hit at 1-based rank i contributes ``1 / log2(i + 1)``; the ideal DCG
    assumes ``min(k, len(relevant))`` hits at the top ranks. Returns NaN when
    *relevant* is empty or the ideal DCG is zero.
    """
    if not relevant:
        return float("nan")

    def gain(rank):
        return 1.0 / math.log2(rank + 1)

    dcg = sum(
        gain(rank)
        for rank, doc_id in enumerate(ranked_ids[:k], start=1)
        if doc_id in relevant
    )
    n_ideal = min(k, len(relevant))
    ideal = sum(gain(rank) for rank in range(1, n_ideal + 1))
    if ideal > 0:
        return dcg / ideal
    return float("nan")
212
+
213
+
214
def evaluate(model: SBERTPhoton, abstracts, ids, queries, ks=(1, 3, 5, 10),
             metrics=("gaussian", "photon_prob"), photon_cutoff: int = 4) -> dict:
    """Rank every document for every query under each requested metric.

    All documents and queries are encoded once through
    ``model.state_from_text`` and converted to float64 NumPy arrays. For each
    (query, metric) pair the documents in ``ids`` are scored, sorted by
    descending score, and recall@k / nDCG@k are computed for every k in
    ``ks``.

    Returns:
        ``{metric: {"per_query": [...], "aggregate": {...}}}`` where each
        per-query row holds the query text, the top ``max(ks)`` ranked ids and
        the metrics, and the aggregate is the plain mean over queries.
    """
    model.eval()

    def _to_f64(tensor):
        return tensor.cpu().numpy().astype(np.float64)

    # Encode all docs + queries once; convert to numpy float64 for thewalrus.
    doc_np: dict[str, tuple[np.ndarray, np.ndarray]] = {}
    q_np: list[tuple[dict, np.ndarray, np.ndarray]] = []
    with torch.no_grad():
        for arxiv_id, doc_text in abstracts.items():
            mu_d, sg_d = model.state_from_text(doc_text)
            doc_np[arxiv_id] = (_to_f64(mu_d), _to_f64(sg_d))
        for q in queries:
            mu_q, sg_q = model.state_from_text(q["query"])
            q_np.append((q, _to_f64(mu_q), _to_f64(sg_q)))

    def _score_gaussian(mq, sq, md, sd):
        return gaussian_fidelity_eval(mq, sq, md, sd)

    def _score_photon(mq, sq, md, sd):
        return photon_prob_eval(mq, sq, md, sd, cutoff=photon_cutoff)

    score_fn = {"gaussian": _score_gaussian, "photon_prob": _score_photon}
    top_n = max(ks)

    metric_rows: dict[str, list] = {m: [] for m in metrics}
    for q, mu_q, sg_q in q_np:
        for metric in metrics:
            scorer = score_fn[metric]
            scored = [(scorer(mu_q, sg_q, *doc_np[doc_id]), doc_id) for doc_id in ids]
            scored.sort(key=lambda pair: -pair[0])
            ranked_ids = [doc_id for _, doc_id in scored]
            rel = set(q["relevant_ids"])
            row = {"query": q["query"], "ranked": ranked_ids[:top_n]}
            for k in ks:
                row[f"recall@{k}"] = recall_at_k(ranked_ids, rel, k)
                row[f"ndcg@{k}"] = ndcg_at_k(ranked_ids, rel, k)
            metric_rows[metric].append(row)

    out = {}
    for metric in metrics:
        rows = metric_rows[metric]
        agg = {}
        for k in ks:
            agg[f"recall@{k}"] = float(np.mean([r[f"recall@{k}"] for r in rows]))
        for k in ks:
            agg[f"ndcg@{k}"] = float(np.mean([r[f"ndcg@{k}"] for r in rows]))
        out[metric] = {"per_query": rows, "aggregate": agg}
    return out
273
+
274
+
275
def train(args):
    """Train the SBERT→photonic projection with an InfoNCE loss, then evaluate.

    Reads these fields from ``args``: ``seed``, ``relevance``, ``no_squeeze``,
    ``lr``, ``weight_decay``, ``steps``, ``negatives``, ``temperature``,
    ``clip``, ``log_every``, ``eval_train_rel`` and ``eval_test_rel``.

    Each step samples, for every training query, one positive document and up
    to ``args.negatives`` negatives, and minimises the cross-entropy of
    ``-bhattacharyya_distance / temperature`` with the positive at index 0.
    After training, the model is evaluated on the requested relevance files
    (falling back to the training file) and a ``SUMMARY_JSON=`` sentinel line
    is printed for downstream parsers.

    Exits via ``sys.exit`` when the fetched abstracts do not match the manifest
    or when no query has both a positive and a negative document.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    rel_path = Path(args.relevance) if args.relevance else ROOT / "eval" / "relevance.json"
    cids_path = ROOT / "eval" / "corpus_ids.json"
    man_path = ROOT / "eval" / "manifest.json"

    train_relevance = json.loads(rel_path.read_text("utf-8"))["queries"]
    ids = json.loads(cids_path.read_text("utf-8"))["ids"]
    print(f"[sbert] fetching {len(ids)} abstracts...", flush=True)
    abstracts = fetch_all(ids)
    bad = verify_against_manifest(abstracts, man_path)
    if bad:
        sys.exit(f"manifest mismatch: {list(bad)[:3]}")

    model = SBERTPhoton(n_modes=N_MODES, no_squeeze=args.no_squeeze)
    trainable = [p for p in model.parameters() if p.requires_grad]
    n_trainable = sum(p.numel() for p in trainable)
    print(f"[sbert] trainable params = {n_trainable} no_squeeze={args.no_squeeze}", flush=True)

    # ── Cache frozen-SBERT features for every doc and every query ONCE.
    # Without this we re-run the transformer for every (doc, step) pair —
    # 20 docs × 200 steps = 4000 forward passes per training run, ~5 min
    # of pure-inference waste on cloud CPU. Frozen features don't change.
    print(f"[sbert] caching SBERT features for {len(abstracts)} docs + "
          f"{len(train_relevance)} queries...", flush=True)
    doc_feats = {a: model.encode_features([t])[0] for a, t in abstracts.items()}
    query_feats = {q["query"]: model.encode_features([q["query"]])[0]
                   for q in train_relevance}

    optim = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=args.weight_decay)
    rng = np.random.default_rng(args.seed)

    # The sampling pools depend only on the relevance file and the corpus, so
    # build them once instead of once per (query, step). The negative sample
    # size is taken from the actual pool: `len(ids) - len(rel_set)` is wrong
    # (and can go negative) when a query lists ids that are not in the corpus.
    queries = []
    for q in train_relevance:
        rel_set = set(q["relevant_ids"])
        # Only documents that have a state can serve as the positive.
        positives = sorted(r for r in rel_set if r in abstracts)
        neg_pool = [i for i in ids if i not in rel_set]
        if not positives or not neg_pool:
            print(f"[sbert] skipping query with no usable positives/negatives: "
                  f"{q['query'][:48]!r}", flush=True)
            continue
        queries.append((q["query"], positives, neg_pool,
                        min(args.negatives, len(neg_pool))))
    if not queries:
        sys.exit("no trainable queries: each needs >=1 positive and >=1 negative document")

    # The positive always sits at index 0 of the logits.
    target = torch.zeros(1, dtype=torch.long)

    t0 = time.time()
    for step in range(1, args.steps + 1):
        optim.zero_grad()
        # Re-run the trainable projection each step over CACHED features
        # (instead of re-running SBERT). projection ∈ R^{384×4N}, cheap.
        doc_states = {a: model.state_from_features(doc_feats[a]) for a in abstracts}

        loss_sum = torch.zeros((), dtype=torch.float32)
        for query_text, positives, neg_pool, n_neg in queries:
            mu_q, sg_q = model.state_from_features(query_feats[query_text])
            pos_id = rng.choice(positives)
            mu_p, sg_p = doc_states[pos_id]
            negs = rng.choice(neg_pool, size=n_neg, replace=False)
            d_pos = bhattacharyya_distance(mu_q, sg_q, mu_p, sg_p)
            d_negs = torch.stack([bhattacharyya_distance(mu_q, sg_q, *doc_states[n]) for n in negs])
            # Smaller distance ⇒ larger logit.
            logits = -torch.cat([d_pos.unsqueeze(0), d_negs]) / args.temperature
            ce = F.cross_entropy(logits.unsqueeze(0), target)
            loss_sum = loss_sum + ce
        loss_sum = loss_sum / len(queries)
        loss_sum.backward()
        torch.nn.utils.clip_grad_norm_(trainable, max_norm=args.clip)
        optim.step()
        if step == 1 or step % args.log_every == 0 or step == args.steps:
            print(f"[sbert] step {step}/{args.steps} loss={loss_sum.item():.4f} "
                  f"elapsed={time.time()-t0:.1f}s", flush=True)

    # final eval against whichever relevance file the user asks for
    eval_paths = []
    if args.eval_train_rel:
        eval_paths.append(("train", Path(args.eval_train_rel)))
    if args.eval_test_rel:
        eval_paths.append(("test", Path(args.eval_test_rel)))
    if not eval_paths:
        eval_paths.append(("all", rel_path))
    summary = {}
    for label, p in eval_paths:
        rels = json.loads(p.read_text("utf-8"))["queries"]
        multi = evaluate(model, abstracts, ids, rels,
                         metrics=("gaussian", "photon_prob"))
        for metric, report in multi.items():
            print(f"\n=== {label.upper()} EVAL — metric={metric} ({len(rels)} queries) ===")
            for r in report["per_query"]:
                cells = " ".join(f"{m}={r[m]:.3f}" for m in r if m.startswith(("recall", "ndcg")))
                print(f" {r['query'][:48]:<48s} {cells}")
            print("aggregate: " + " ".join(
                f"{m}={report['aggregate'][m]:.3f}" for m in report["aggregate"]
            ))
            summary[f"{label}/{metric}"] = report["aggregate"]
    # Sentinel line for downstream parsers (run_sweep.py).
    print(f"\nSUMMARY_JSON={json.dumps(summary)}")
367
+
368
+
369
def main():
    """Parse the command-line hyper-parameters and hand them to ``train``."""
    # (flag, add_argument keyword arguments), in the order shown by --help.
    options = (
        ("--steps", {"type": int, "default": 200}),
        ("--lr", {"type": float, "default": 1e-2}),
        ("--weight-decay", {"type": float, "default": 1e-3}),
        ("--temperature", {"type": float, "default": 2.0}),
        ("--negatives", {"type": int, "default": 8}),
        ("--clip", {"type": float, "default": 1.0}),
        ("--seed", {"type": int, "default": 42}),
        ("--log-every", {"type": int, "default": 20}),
        ("--relevance", {
            "type": str, "default": None,
            "help": "training relevance.json (e.g. /tmp/rel_train.json)",
        }),
        ("--eval-train-rel", {
            "type": str, "default": None,
            "help": "optional separate train-eval set for in-sample numbers",
        }),
        ("--eval-test-rel", {
            "type": str, "default": None,
            "help": "optional held-out eval set for generalization numbers",
        }),
        ("--no-squeeze", {
            "action": "store_true",
            "help": "ablation: force r=0 (displacement-only). Tests whether the squeezing layer specifically pays.",
        }),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    train(parser.parse_args())
389
+
390
+
391
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
space/train_sbert_fock.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A3-Real — non-Gaussian SBERT-photon trainer.
2
+
3
+ Architecture:
4
+ text → frozen SBERT 384d → Linear(384, 6) → photonic params
5
+ → 2-mode (signal + ancilla) cutoff-D Fock-basis encoder
6
+ → unitary U = D_signal(α) S_signal(r,φ) S_2(τ,θ) applied to |0,0⟩
7
+ → project ancilla onto |1⟩ (single-photon herald)
8
+ → normalised pure state |ψ_sig⟩ ∈ ℂ^D, NON-GAUSSIAN
9
+ score(q, d) = |⟨ψ_q,sig | ψ_d,sig⟩|²
10
+
11
+ Why this is genuinely new vs space/train_sbert.py: with α small and r mild,
12
+ the heralded-on-|1⟩ signal mode contains a single-photon contribution. A
13
+ single-photon Fock state has Wigner-negative regions — non-Gaussian. The
14
+ similarity |⟨ψ_q|ψ_d⟩|² is *not* representable as a Gaussian-RBF kernel
15
+ on any finite-d projection of the inputs (Sim 1's negative result for the
16
+ Gaussian path).
17
+
18
+ Loss: InfoNCE with score / temperature as the affinity logit. Cached
19
+ SBERT features. Same eval/relevance as the Gaussian trainer for direct
20
+ head-to-head numbers.
21
+ """
22
+ from __future__ import annotations
23
+
24
+ import argparse
25
+ import json
26
+ import math
27
+ import sys
28
+ import time
29
+ from pathlib import Path
30
+
31
+ import numpy as np
32
+ import torch
33
+ import torch.nn as nn
34
+ import torch.nn.functional as F
35
+ from torch import Tensor
36
+
37
+ ROOT = Path(__file__).resolve().parent.parent
38
+ SRC = ROOT / "src"
39
+ if str(SRC) not in sys.path:
40
+ sys.path.insert(0, str(SRC))
41
+
42
+ from eval.fetch import fetch_all, verify_against_manifest # noqa: E402
43
+
44
# Checkpoint name handed to SentenceTransformer(...) for the frozen text encoder.
SBERT_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Width of the SBERT embedding; it is the input size of the trainable Linear layer.
SBERT_DIM = 384
# Default ancilla photon number to herald on; single-photon → genuinely non-Gaussian.
HERALD_N = 1
47
+
48
+
49
+ # ─── Truncated bosonic operators in Fock basis ────────────────────────────
50
+ def annihilation_op(D: int) -> Tensor:
51
+ """a |n⟩ = √n |n-1⟩ (D-dim truncation; lossy at the top)."""
52
+ a = torch.zeros(D, D, dtype=torch.complex128)
53
+ for n in range(1, D):
54
+ a[n - 1, n] = math.sqrt(n)
55
+ return a
56
+
57
+
58
def kron(A: Tensor, B: Tensor) -> Tensor:
    """Kronecker product A ⊗ B (thin alias for ``torch.kron``)."""
    product = torch.kron(A, B)
    return product
60
+
61
+
62
+ # ─── Generators of the unitaries ────────────────────────────────────────────
63
def displace_generator(a: Tensor, alpha: Tensor) -> Tensor:
    """Displacement generator G_D = α a† − α* a.

    ``exp(G_D)`` gives D(α). ``a`` is the annihilation matrix and ``alpha`` a
    complex scalar tensor.
    """
    creation_term = alpha * a.mH          # α a†
    annihilation_term = torch.conj(alpha) * a  # α* a
    return creation_term - annihilation_term
66
+
67
+
68
def squeeze_generator(a: Tensor, zeta: Tensor) -> Tensor:
    """Single-mode squeezing generator G_S = (1/2)(ζ* a² − ζ a†²).

    ``exp(G_S)`` gives S(ζ), with ζ = r e^{iφ} passed as a complex scalar tensor.
    """
    a_squared = a @ a
    lowering_part = torch.conj(zeta) * a_squared   # ζ* a²
    raising_part = zeta * a_squared.mH             # ζ a†²
    return 0.5 * (lowering_part - raising_part)
72
+
73
+
74
def two_mode_squeeze_generator(a: Tensor, b: Tensor, xi: Tensor) -> Tensor:
    """Two-mode squeezing generator G_{TMS} = ξ* a b − ξ a† b†.

    Acts on the joint signal⊗ancilla space, so ``a`` and ``b`` must already be
    full-dimension joint operators (e.g. a = a_signal ⊗ I_anc).
    """
    pair_annihilation = a @ b        # a b
    pair_creation = a.mH @ b.mH      # a† b†
    return torch.conj(xi) * pair_annihilation - xi * pair_creation
78
+
79
+
80
+ # ─── Encoder ────────────────────────────────────────────────────────────────
81
class SBERTPhotonFock(nn.Module):
    """SBERT → Linear(384, 6) → 2-mode Fock encoder → herald → 1-mode pure state.

    The 6 outputs decompose to (αq, αp, r, φ_s, τ, θ):
      α = αq + i·αp,  |αq|, |αp| < max_displace   (signal displacement)
      ζ = r e^{iφ_s}, r ∈ (0, max_squeeze)        (signal squeezing)
      ξ = τ e^{iθ},   τ ∈ (0, max_squeeze)        (two-mode squeezing toward ancilla)

    Only ``self.proj`` is trainable; the SBERT weights have ``requires_grad``
    switched off.

    NOTE(review): S_2(ξ)|0,0⟩ is a superposition of |n,n⟩ terms and the later
    D·S gates act on the signal mode only, so after heralding and
    normalisation the output looks independent of τ and θ (up to a global
    phase and truncation error). Confirm whether ξ is meant to influence the
    state.
    """

    def __init__(self, cutoff: int = 6, max_squeeze: float = 0.5,
                 max_displace: float = 1.5, herald_n: int = HERALD_N):
        """Build the encoder.

        Args:
            cutoff: Fock truncation D per mode; the joint space has D² entries.
            max_squeeze: upper bound applied to both r and τ.
            max_displace: bound applied to each quadrature of α.
            herald_n: ancilla photon number to project onto.

        Raises:
            ValueError: if ``herald_n`` is not in ``[0, cutoff)``.
        """
        super().__init__()
        # Validate before the slow SBERT load. A negative herald_n would not
        # fail later: it wraps around through negative indexing in
        # state_from_features and silently selects the wrong amplitudes.
        if herald_n < 0:
            raise ValueError(f"herald_n={herald_n} must be >= 0")
        if herald_n >= cutoff:
            raise ValueError(f"herald_n={herald_n} must be < cutoff={cutoff}")
        from sentence_transformers import SentenceTransformer
        self.D = cutoff
        self.max_sq = max_squeeze
        self.max_disp = max_displace
        self.herald_n = herald_n
        self.sbert = SentenceTransformer(SBERT_MODEL_NAME)
        for p in self.sbert.parameters():
            p.requires_grad = False
        # Trainable surface
        self.proj = nn.Linear(SBERT_DIM, 6, dtype=torch.float32)
        nn.init.normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)
        # Pre-compute truncated bosonic operators (constants — no grad needed)
        a = annihilation_op(cutoff)
        I = torch.eye(cutoff, dtype=torch.complex128)
        # Joint-space (signal ⊗ ancilla): a_s = a ⊗ I, b_a = I ⊗ a
        self.register_buffer("a_signal_full", kron(a, I))
        self.register_buffer("b_anc_full", kron(I, a))
        self.register_buffer("a_signal_local", a)  # for solo signal-side gates if ever needed
        # Initial vacuum |0,0⟩ in joint Fock basis (D² vector)
        psi0 = torch.zeros(cutoff * cutoff, dtype=torch.complex128)
        psi0[0] = 1.0  # index (0, 0) → flat 0
        self.register_buffer("vacuum", psi0)

    def encode_features(self, texts: list[str]) -> Tensor:
        """Embed ``texts`` with the frozen SBERT model.

        Returns a float32 CPU tensor of normalised embeddings, one row per
        input text.
        """
        with torch.no_grad():
            emb = self.sbert.encode(
                texts, normalize_embeddings=True, convert_to_numpy=False,
                show_progress_bar=False,
            )
        # With convert_to_numpy=False the encoder may hand back a list of
        # per-text tensors; stack them into one batch tensor.
        if isinstance(emb, list):
            emb = torch.stack(emb)
        return emb.to(torch.float32).cpu()

    def state_from_features(self, feat: Tensor) -> Tensor:
        """Returns the heralded signal-mode state |ψ_sig⟩ ∈ ℂ^D, normalized.

        ``feat`` is a single cached SBERT embedding (the six projected values
        are indexed as a 1-D vector). Shape of the result: (D,) complex128.
        """
        out = self.proj(feat)  # (6,) float32
        out = out.to(torch.float64)
        # decompose with bounded reparametrizations
        alpha_q = self.max_disp * torch.tanh(out[0])
        alpha_p = self.max_disp * torch.tanh(out[1])
        r = self.max_sq * torch.sigmoid(out[2])
        phi_s = (2 * math.pi) * torch.sigmoid(out[3])
        tau = self.max_sq * torch.sigmoid(out[4])
        theta = (2 * math.pi) * torch.sigmoid(out[5])
        # Build complex parameters
        alpha = torch.complex(alpha_q, alpha_p)
        zeta = torch.complex(r * torch.cos(phi_s), r * torch.sin(phi_s))
        xi = torch.complex(tau * torch.cos(theta), tau * torch.sin(theta))
        # Generators in joint space
        G_TMS = two_mode_squeeze_generator(
            self.a_signal_full, self.b_anc_full, xi,
        )
        G_S = squeeze_generator(self.a_signal_full, zeta)
        G_D = displace_generator(self.a_signal_full, alpha)
        # Apply unitaries: |ψ⟩ = D · S · S_2 · |0,0⟩
        U_TMS = torch.linalg.matrix_exp(G_TMS)
        U_S = torch.linalg.matrix_exp(G_S)
        U_D = torch.linalg.matrix_exp(G_D)
        psi = U_TMS @ self.vacuum
        psi = U_S @ psi
        psi = U_D @ psi
        # Project ancilla onto |herald_n⟩. Joint flat index = signal*D + ancilla.
        # Pick rows where ancilla == herald_n: rows = [signal*D + herald_n for signal in 0..D-1]
        D = self.D
        idx = torch.arange(D, device=psi.device, dtype=torch.long) * D + self.herald_n
        psi_sig = psi[idx]
        # Normalize (heralding probability is the squared norm; we drop it).
        # The clamp avoids a division by zero when the herald amplitude vanishes.
        norm = torch.linalg.vector_norm(psi_sig)
        psi_sig = psi_sig / torch.clamp(norm, min=1e-12)
        return psi_sig

    def state_from_text(self, text: str) -> Tensor:
        """Encode one raw text with SBERT and return its heralded state."""
        feat = self.encode_features([text])[0]
        return self.state_from_features(feat)
171
+
172
+
173
def overlap_squared(psi_a: Tensor, psi_b: Tensor) -> Tensor:
    """Fidelity |⟨ψ_a|ψ_b⟩|² between two pure states.

    ``torch.vdot`` conjugates its first argument, so the inner product is
    ⟨ψ_a|ψ_b⟩; the result is a real scalar tensor.
    """
    amplitude = torch.vdot(psi_a, psi_b)
    re_part = amplitude.real
    im_part = amplitude.imag
    return re_part * re_part + im_part * im_part
177
+
178
+
179
def recall_at_k(ranked, relevant, k):
    """Fraction of ``relevant`` (a set) found in the first ``k`` ranked ids.

    Returns NaN when ``relevant`` is empty, since recall is undefined there.
    """
    if relevant:
        retrieved = set(ranked[:k])
        hits = retrieved & relevant
        return len(hits) / len(relevant)
    return float("nan")
183
+
184
+
185
def ndcg_at_k(ranked, relevant, k):
    """Binary-relevance nDCG over the first ``k`` ranked ids.

    A hit at 1-based rank i contributes 1 / log2(i + 1). The ideal DCG places
    min(k, len(relevant)) hits at the top. Returns NaN when ``relevant`` is
    empty or the ideal DCG is zero.
    """
    if not relevant:
        return float("nan")
    # Enumerating from 2 yields (rank + 1) directly, i.e. the log2 argument.
    gains = [
        1.0 / math.log2(shifted)
        for shifted, doc_id in enumerate(ranked[:k], start=2)
        if doc_id in relevant
    ]
    n_ideal = min(k, len(relevant))
    ideal_gains = [1.0 / math.log2(rank + 1) for rank in range(1, n_ideal + 1)]
    dcg = sum(gains)
    ideal = sum(ideal_gains)
    if ideal > 0:
        return dcg / ideal
    return float("nan")
191
+
192
+
193
def evaluate(model, abstracts, ids, queries, ks=(1, 3, 5, 10)) -> dict:
    """Rank every corpus id for each query by fidelity and score the rankings.

    Puts ``model`` in eval mode, encodes every abstract and every query text
    through ``model.state_from_text`` without gradients, and sorts ``ids`` by
    descending ``overlap_squared`` (ties keep their order in ``ids``).

    Returns ``{"per_query": [...], "aggregate": {...}}`` where each per-query
    row holds the query, the top ``max(ks)`` ranked ids and recall@k / ndcg@k
    for every k, and the aggregate is the plain mean of each metric.
    """
    model.eval()
    doc_states = {}
    with torch.no_grad():
        for doc_id, text in abstracts.items():
            doc_states[doc_id] = model.state_from_text(text)

    per_query = []
    for entry in queries:
        with torch.no_grad():
            psi_q = model.state_from_text(entry["query"])
        pairs = [
            (float(overlap_squared(psi_q, doc_states[doc_id]).item()), doc_id)
            for doc_id in ids
        ]
        # Stable sort on the negated score: highest fidelity first.
        pairs.sort(key=lambda pair: -pair[0])
        ranking = [doc_id for _, doc_id in pairs]
        relevant = set(entry["relevant_ids"])
        record = {"query": entry["query"], "ranked": ranking[: max(ks)]}
        for k in ks:
            record[f"recall@{k}"] = recall_at_k(ranking, relevant, k)
            record[f"ndcg@{k}"] = ndcg_at_k(ranking, relevant, k)
        per_query.append(record)

    # All recall@k keys first, then all ndcg@k keys.
    aggregate = {}
    for family in ("recall", "ndcg"):
        for k in ks:
            key = f"{family}@{k}"
            aggregate[key] = float(np.mean([record[key] for record in per_query]))
    return {"per_query": per_query, "aggregate": aggregate}
216
+
217
+
218
def train(args):
    """Train the heralded Fock-basis encoder with an InfoNCE loss, then evaluate.

    Reads these fields from ``args``: ``seed``, ``relevance``, ``cutoff``,
    ``herald_n``, ``lr``, ``weight_decay``, ``steps``, ``negatives``,
    ``temperature``, ``clip``, ``log_every``, ``eval_train_rel`` and
    ``eval_test_rel``.

    Each step samples, for every training query, one positive document and up
    to ``args.negatives`` negatives, and minimises the cross-entropy of the
    fidelities divided by ``temperature`` with the positive at index 0. After
    training, the model is evaluated on the requested relevance files (falling
    back to the training file) and a ``SUMMARY_JSON=`` sentinel line is printed.

    Exits via ``sys.exit`` when the fetched abstracts do not match the manifest
    or when no query has both a positive and a negative document.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    rel_path = Path(args.relevance) if args.relevance else ROOT / "eval" / "relevance.json"
    cids_path = ROOT / "eval" / "corpus_ids.json"
    man_path = ROOT / "eval" / "manifest.json"
    train_relevance = json.loads(rel_path.read_text("utf-8"))["queries"]
    ids = json.loads(cids_path.read_text("utf-8"))["ids"]
    print(f"[fock] fetching {len(ids)} abstracts...", flush=True)
    abstracts = fetch_all(ids)
    bad = verify_against_manifest(abstracts, man_path)
    if bad:
        sys.exit(f"manifest mismatch: {list(bad)[:3]}")

    model = SBERTPhotonFock(cutoff=args.cutoff, herald_n=args.herald_n)
    trainable = [p for p in model.parameters() if p.requires_grad]
    n_trainable = sum(p.numel() for p in trainable)
    print(f"[fock] cutoff={args.cutoff} herald_n={args.herald_n} trainable={n_trainable}", flush=True)

    # Cache features
    print(f"[fock] caching SBERT features ({len(abstracts)} docs + "
          f"{len(train_relevance)} queries)...", flush=True)
    doc_feats = {a: model.encode_features([t])[0] for a, t in abstracts.items()}
    query_feats = {q["query"]: model.encode_features([q["query"]])[0]
                   for q in train_relevance}

    optim = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=args.weight_decay)
    rng = np.random.default_rng(args.seed)

    # The sampling pools depend only on the relevance file and the corpus, so
    # build them once instead of once per (query, step). The negative sample
    # size is taken from the actual pool: `len(ids) - len(rel_set)` is wrong
    # (and can go negative) when a query lists ids that are not in the corpus.
    queries = []
    for q in train_relevance:
        rel_set = set(q["relevant_ids"])
        # Only documents that have a state can serve as the positive.
        positives = sorted(r for r in rel_set if r in abstracts)
        neg_pool = [i for i in ids if i not in rel_set]
        if not positives or not neg_pool:
            print(f"[fock] skipping query with no usable positives/negatives: "
                  f"{q['query'][:48]!r}", flush=True)
            continue
        queries.append((q["query"], positives, neg_pool,
                        min(args.negatives, len(neg_pool))))
    if not queries:
        sys.exit("no trainable queries: each needs >=1 positive and >=1 negative document")

    # The positive always sits at index 0 of the logits.
    target = torch.zeros(1, dtype=torch.long)

    t0 = time.time()
    for step in range(1, args.steps + 1):
        optim.zero_grad()
        # Recompute every state (proj weights change every step). Cached
        # SBERT features feed straight into state_from_features.
        doc_psi = {a: model.state_from_features(doc_feats[a]) for a in abstracts}

        loss_sum = torch.zeros((), dtype=torch.float64)
        for query_text, positives, neg_pool, n_neg in queries:
            psi_q = model.state_from_features(query_feats[query_text])
            pos_id = rng.choice(positives)
            psi_pos = doc_psi[pos_id]
            negs = rng.choice(neg_pool, size=n_neg, replace=False)
            f_pos = overlap_squared(psi_q, psi_pos)
            f_negs = torch.stack([overlap_squared(psi_q, doc_psi[n]) for n in negs])
            # InfoNCE via softmax: P(pos) = exp(f_pos/T) / Σ_j exp(f_j/T).
            # Fidelities lie in [0, 1], so f/T is used directly as the logit.
            logits = torch.cat([f_pos.unsqueeze(0), f_negs]) / args.temperature
            ce = F.cross_entropy(logits.unsqueeze(0), target)
            loss_sum = loss_sum + ce
        loss_sum = loss_sum / len(queries)
        loss_sum.backward()
        torch.nn.utils.clip_grad_norm_(trainable, max_norm=args.clip)
        optim.step()
        if step == 1 or step % args.log_every == 0 or step == args.steps:
            print(f"[fock] step {step}/{args.steps} loss={loss_sum.item():.4f} "
                  f"elapsed={time.time()-t0:.1f}s", flush=True)

    # Final eval: explicit train/test relevance files, else the training file.
    eval_paths = []
    if args.eval_train_rel:
        eval_paths.append(("train", Path(args.eval_train_rel)))
    if args.eval_test_rel:
        eval_paths.append(("test", Path(args.eval_test_rel)))
    if not eval_paths:
        eval_paths.append(("all", rel_path))
    summary = {}
    for label, p in eval_paths:
        rels = json.loads(p.read_text("utf-8"))["queries"]
        report = evaluate(model, abstracts, ids, rels)
        print(f"\n=== {label.upper()} EVAL ({len(rels)} queries) ===")
        for r in report["per_query"]:
            cells = " ".join(f"{m}={r[m]:.3f}" for m in r if m.startswith(("recall", "ndcg")))
            print(f" {r['query'][:48]:<48s} {cells}")
        print("aggregate: " + " ".join(f"{m}={report['aggregate'][m]:.3f}" for m in report["aggregate"]))
        summary[f"{label}/fock"] = report["aggregate"]
    print(f"\nSUMMARY_JSON={json.dumps(summary)}")
301
+
302
+
303
def main():
    """Parse the command-line hyper-parameters and hand them to ``train``."""
    # (flag, add_argument keyword arguments), in the order shown by --help.
    options = (
        ("--cutoff", {
            "type": int, "default": 6,
            "help": "Fock-basis truncation per mode. Joint dim = cutoff².",
        }),
        ("--herald-n", {
            "type": int, "default": HERALD_N,
            "help": "Ancilla photon-number outcome to project onto.",
        }),
        ("--steps", {"type": int, "default": 200}),
        ("--lr", {"type": float, "default": 1e-2}),
        ("--weight-decay", {"type": float, "default": 1e-3}),
        ("--temperature", {"type": float, "default": 0.5}),
        ("--negatives", {"type": int, "default": 8}),
        ("--clip", {"type": float, "default": 1.0}),
        ("--seed", {"type": int, "default": 42}),
        ("--log-every", {"type": int, "default": 20}),
        ("--relevance", {"type": str, "default": None}),
        ("--eval-train-rel", {"type": str, "default": None}),
        ("--eval-test-rel", {"type": str, "default": None}),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    train(parser.parse_args())
322
+
323
+
324
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()