Spaces:
Running
Running
meridian-mcp monorepo sync commited on
Commit ·
e84fa2c
1
Parent(s): 066683d
Sync from meridian-mcp@bf1547266b87d49e4fef560bf851bd07c585f5ac
Browse files- .github/workflows/finish-line.yml +85 -0
- .github/workflows/pages.yml +33 -0
- .github/workflows/photon-sweep.yml +61 -0
- eval/expand_titles.py +103 -0
- eval/run_bm25.py +113 -0
- eval/run_sbert.py +87 -0
- pages/CNAME +1 -0
- worker/proxy.js → pages/index.html +209 -211
- space/analyze_sweep.py +112 -0
- space/run_sweep.py +155 -0
- space/run_sweep_fock.py +125 -0
- space/sim_b1_g1_coherence.py +126 -0
- space/sim_b2_g2_classifier.py +132 -0
- space/train.py +3 -1
- space/train_sbert.py +392 -0
- space/train_sbert_fock.py +325 -0
.github/workflows/finish-line.yml
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: photon-route finish-line (B1 + B2 + A3-Real)
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
workflow_dispatch:
|
| 5 |
+
inputs:
|
| 6 |
+
seeds:
|
| 7 |
+
description: "Sweep seeds for A3-Real (default: 1 2 3 4 5)"
|
| 8 |
+
required: false
|
| 9 |
+
default: "1 2 3 4 5"
|
| 10 |
+
steps:
|
| 11 |
+
description: "Training steps for A3-Real (default: 200)"
|
| 12 |
+
required: false
|
| 13 |
+
default: "200"
|
| 14 |
+
|
| 15 |
+
permissions:
|
| 16 |
+
contents: read
|
| 17 |
+
|
| 18 |
+
jobs:
|
| 19 |
+
finish:
|
| 20 |
+
runs-on: ubuntu-latest
|
| 21 |
+
timeout-minutes: 90
|
| 22 |
+
steps:
|
| 23 |
+
- uses: actions/checkout@v4
|
| 24 |
+
- uses: actions/setup-python@v5
|
| 25 |
+
with:
|
| 26 |
+
python-version: "3.11"
|
| 27 |
+
|
| 28 |
+
- name: Install dependencies
|
| 29 |
+
run: |
|
| 30 |
+
python -m pip install --upgrade pip
|
| 31 |
+
pip install numpy scipy thewalrus torch sentence-transformers
|
| 32 |
+
|
| 33 |
+
- name: B1 — g^(1) coherence-time sim
|
| 34 |
+
run: |
|
| 35 |
+
mkdir -p reports
|
| 36 |
+
PYTHONPATH=src:. python -m space.sim_b1_g1_coherence | tee reports/b1.txt
|
| 37 |
+
|
| 38 |
+
- name: B2 — g^(2) classifier sim
|
| 39 |
+
run: |
|
| 40 |
+
PYTHONPATH=src:. python -m space.sim_b2_g2_classifier | tee reports/b2.txt
|
| 41 |
+
|
| 42 |
+
- name: Expand relevance (title queries)
|
| 43 |
+
run: |
|
| 44 |
+
PYTHONPATH=src:. python -m eval.expand_titles
|
| 45 |
+
|
| 46 |
+
- name: A3-Real — Fock-basis non-Gaussian sweep
|
| 47 |
+
run: |
|
| 48 |
+
PYTHONPATH=src:. python -m space.run_sweep_fock \
|
| 49 |
+
--seeds ${{ github.event.inputs.seeds }} \
|
| 50 |
+
--steps ${{ github.event.inputs.steps }} \
|
| 51 |
+
--herald-ns 1 0 \
|
| 52 |
+
--out-csv sweep_fock_results.csv \
|
| 53 |
+
--log-dir sweep_fock_logs \
|
| 54 |
+
| tee reports/a3.txt
|
| 55 |
+
|
| 56 |
+
- name: Upload artifacts
|
| 57 |
+
if: always()
|
| 58 |
+
uses: actions/upload-artifact@v4
|
| 59 |
+
with:
|
| 60 |
+
name: photon-finish-line
|
| 61 |
+
path: |
|
| 62 |
+
reports/
|
| 63 |
+
sweep_fock_results.csv
|
| 64 |
+
sweep_fock_logs/
|
| 65 |
+
eval/relevance_expanded.json
|
| 66 |
+
if-no-files-found: warn
|
| 67 |
+
retention-days: 30
|
| 68 |
+
|
| 69 |
+
- name: Print summaries
|
| 70 |
+
if: always()
|
| 71 |
+
run: |
|
| 72 |
+
echo "=========================================="
|
| 73 |
+
echo "B1 — g^(1) coherence sim"
|
| 74 |
+
echo "=========================================="
|
| 75 |
+
tail -25 reports/b1.txt || true
|
| 76 |
+
echo
|
| 77 |
+
echo "=========================================="
|
| 78 |
+
echo "B2 — g^(2) classifier sim"
|
| 79 |
+
echo "=========================================="
|
| 80 |
+
tail -25 reports/b2.txt || true
|
| 81 |
+
echo
|
| 82 |
+
echo "=========================================="
|
| 83 |
+
echo "A3-Real — Fock sweep CSV"
|
| 84 |
+
echo "=========================================="
|
| 85 |
+
cat sweep_fock_results.csv || true
|
.github/workflows/pages.yml
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Deploy photon-route UI to GitHub Pages
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches: [main]
|
| 6 |
+
paths:
|
| 7 |
+
- 'pages/**'
|
| 8 |
+
- '.github/workflows/pages.yml'
|
| 9 |
+
workflow_dispatch: {}
|
| 10 |
+
|
| 11 |
+
permissions:
|
| 12 |
+
contents: read
|
| 13 |
+
pages: write
|
| 14 |
+
id-token: write
|
| 15 |
+
|
| 16 |
+
concurrency:
|
| 17 |
+
group: pages
|
| 18 |
+
cancel-in-progress: false
|
| 19 |
+
|
| 20 |
+
jobs:
|
| 21 |
+
deploy:
|
| 22 |
+
environment:
|
| 23 |
+
name: github-pages
|
| 24 |
+
url: ${{ steps.deployment.outputs.page_url }}
|
| 25 |
+
runs-on: ubuntu-latest
|
| 26 |
+
steps:
|
| 27 |
+
- uses: actions/checkout@v4
|
| 28 |
+
- uses: actions/configure-pages@v5
|
| 29 |
+
- uses: actions/upload-pages-artifact@v3
|
| 30 |
+
with:
|
| 31 |
+
path: ./pages
|
| 32 |
+
- id: deployment
|
| 33 |
+
uses: actions/deploy-pages@v4
|
.github/workflows/photon-sweep.yml
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: photon-route sweep (5 splits × squeeze ablation)
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
workflow_dispatch:
|
| 5 |
+
inputs:
|
| 6 |
+
seeds:
|
| 7 |
+
description: "Space-separated seeds (default: 1 2 3 4 5)"
|
| 8 |
+
required: false
|
| 9 |
+
default: "1 2 3 4 5"
|
| 10 |
+
steps:
|
| 11 |
+
description: "Training steps per run (default: 200)"
|
| 12 |
+
required: false
|
| 13 |
+
default: "200"
|
| 14 |
+
|
| 15 |
+
permissions:
|
| 16 |
+
contents: read
|
| 17 |
+
|
| 18 |
+
jobs:
|
| 19 |
+
sweep:
|
| 20 |
+
runs-on: ubuntu-latest
|
| 21 |
+
timeout-minutes: 60
|
| 22 |
+
steps:
|
| 23 |
+
- uses: actions/checkout@v4
|
| 24 |
+
|
| 25 |
+
- uses: actions/setup-python@v5
|
| 26 |
+
with:
|
| 27 |
+
python-version: "3.11"
|
| 28 |
+
|
| 29 |
+
- name: Install dependencies
|
| 30 |
+
run: |
|
| 31 |
+
python -m pip install --upgrade pip
|
| 32 |
+
pip install numpy scipy thewalrus torch sentence-transformers
|
| 33 |
+
|
| 34 |
+
- name: Expand relevance (title queries)
|
| 35 |
+
run: |
|
| 36 |
+
PYTHONPATH=src:. python -m eval.expand_titles
|
| 37 |
+
|
| 38 |
+
- name: Run sweep
|
| 39 |
+
run: |
|
| 40 |
+
PYTHONPATH=src:. python -m space.run_sweep \
|
| 41 |
+
--seeds ${{ github.event.inputs.seeds }} \
|
| 42 |
+
--steps ${{ github.event.inputs.steps }} \
|
| 43 |
+
--out-csv sweep_results.csv \
|
| 44 |
+
--log-dir sweep_logs
|
| 45 |
+
|
| 46 |
+
- name: Upload sweep results
|
| 47 |
+
uses: actions/upload-artifact@v4
|
| 48 |
+
with:
|
| 49 |
+
name: photon-sweep-results
|
| 50 |
+
path: |
|
| 51 |
+
sweep_results.csv
|
| 52 |
+
sweep_logs/
|
| 53 |
+
eval/relevance_expanded.json
|
| 54 |
+
if-no-files-found: error
|
| 55 |
+
retention-days: 30
|
| 56 |
+
|
| 57 |
+
- name: Print CSV summary
|
| 58 |
+
if: always()
|
| 59 |
+
run: |
|
| 60 |
+
echo "=== sweep_results.csv ==="
|
| 61 |
+
cat sweep_results.csv || true
|
eval/expand_titles.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Expand the eval relevance set with title-as-query pairs.
|
| 2 |
+
|
| 3 |
+
For each arXiv ID in corpus_ids.json, fetch og:title from the abstract
|
| 4 |
+
page (same scrape pattern as eval.fetch but a different meta tag), and
|
| 5 |
+
emit one query whose only relevant document is that paper.
|
| 6 |
+
|
| 7 |
+
Output: eval/relevance_expanded.json — original 6 multi-positive queries
|
| 8 |
+
plus 20 single-positive title queries = 26 total. Increases trainer
|
| 9 |
+
signal 4× without any human labeling.
|
| 10 |
+
|
| 11 |
+
This script does NOT touch the existing relevance.json. It writes a
|
| 12 |
+
sibling file the trainer / eval harness opt into via --relevance.
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import html
|
| 18 |
+
import json
|
| 19 |
+
import re
|
| 20 |
+
import time
|
| 21 |
+
import urllib.request
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 25 |
+
ARXIV_ABS = "https://arxiv.org/abs/"
|
| 26 |
+
_OG_TITLE = re.compile(
|
| 27 |
+
r'<meta\s+(?:property|name)="og:title"\s+content="([^"]*)"',
|
| 28 |
+
re.IGNORECASE,
|
| 29 |
+
)
|
| 30 |
+
_BROWSER_HEADERS = {
|
| 31 |
+
"User-Agent": (
|
| 32 |
+
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 "
|
| 33 |
+
"(KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
|
| 34 |
+
),
|
| 35 |
+
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
|
| 36 |
+
"Accept-Language": "en-US,en;q=0.9",
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _normalize(text: str) -> str:
|
| 41 |
+
return re.sub(r"\s+", " ", text).strip()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _strip_arxiv_prefix(title: str) -> str:
|
| 45 |
+
"""og:title comes back as '[2304.12717] Quantum natural language ...';
|
| 46 |
+
strip the '[id]' prefix so the query is just the paper title."""
|
| 47 |
+
return re.sub(r"^\s*\[[^\]]+\]\s*", "", title).strip()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def fetch_title(arxiv_id: str, timeout: float = 30.0) -> str:
|
| 51 |
+
url = ARXIV_ABS + arxiv_id
|
| 52 |
+
req = urllib.request.Request(url, headers=_BROWSER_HEADERS)
|
| 53 |
+
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
| 54 |
+
body = resp.read().decode("utf-8", errors="replace")
|
| 55 |
+
m = _OG_TITLE.search(body)
|
| 56 |
+
if not m:
|
| 57 |
+
raise RuntimeError(f"og:title not found for {arxiv_id}")
|
| 58 |
+
return _strip_arxiv_prefix(_normalize(html.unescape(m.group(1))))
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def main():
|
| 62 |
+
ap = argparse.ArgumentParser()
|
| 63 |
+
ap.add_argument("--corpus", type=Path, default=ROOT / "eval" / "corpus_ids.json")
|
| 64 |
+
ap.add_argument("--in-relevance", type=Path, default=ROOT / "eval" / "relevance.json")
|
| 65 |
+
ap.add_argument("--out", type=Path, default=ROOT / "eval" / "relevance_expanded.json")
|
| 66 |
+
ap.add_argument("--cache", type=Path, default=Path.home() / ".cache" / "photon-route" / "titles")
|
| 67 |
+
ap.add_argument("--sleep", type=float, default=0.5)
|
| 68 |
+
args = ap.parse_args()
|
| 69 |
+
|
| 70 |
+
args.cache.mkdir(parents=True, exist_ok=True)
|
| 71 |
+
ids = json.loads(args.corpus.read_text("utf-8"))["ids"]
|
| 72 |
+
base = json.loads(args.in_relevance.read_text("utf-8"))
|
| 73 |
+
|
| 74 |
+
titles = {}
|
| 75 |
+
for j, i in enumerate(ids):
|
| 76 |
+
cache_path = args.cache / f"{i}.title"
|
| 77 |
+
if cache_path.exists():
|
| 78 |
+
titles[i] = cache_path.read_text("utf-8").strip()
|
| 79 |
+
continue
|
| 80 |
+
t = fetch_title(i)
|
| 81 |
+
cache_path.write_text(t, encoding="utf-8")
|
| 82 |
+
titles[i] = t
|
| 83 |
+
print(f"[{j+1:2d}/{len(ids)}] {i}: {t[:60]}")
|
| 84 |
+
if j + 1 < len(ids):
|
| 85 |
+
time.sleep(args.sleep)
|
| 86 |
+
|
| 87 |
+
title_queries = [
|
| 88 |
+
{"query": titles[i], "relevant_ids": [i], "kind": "title"}
|
| 89 |
+
for i in ids
|
| 90 |
+
]
|
| 91 |
+
out_payload = {
|
| 92 |
+
**base,
|
| 93 |
+
"queries": [
|
| 94 |
+
*[{**q, "kind": "topical"} for q in base["queries"]],
|
| 95 |
+
*title_queries,
|
| 96 |
+
],
|
| 97 |
+
}
|
| 98 |
+
args.out.write_text(json.dumps(out_payload, indent=2) + "\n", encoding="utf-8")
|
| 99 |
+
print(f"\nwrote {len(out_payload['queries'])} queries → {args.out}")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if __name__ == "__main__":
|
| 103 |
+
main()
|
eval/run_bm25.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""BM25 baseline against the same eval set photon-route uses.
|
| 2 |
+
|
| 3 |
+
Drops in next to eval.run so apples-to-apples on Recall@k / nDCG@k.
|
| 4 |
+
Pure-stdlib BM25 — no external IR library — to keep the dependency
|
| 5 |
+
surface identical to the rest of eval/.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import json
|
| 11 |
+
import math
|
| 12 |
+
from collections import Counter
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
from eval.fetch import fetch_all, verify_against_manifest
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class BM25:
|
| 21 |
+
def __init__(self, docs: list[str], k1: float = 1.5, b: float = 0.75):
|
| 22 |
+
self.k1, self.b = k1, b
|
| 23 |
+
self.toks = [d.lower().split() for d in docs]
|
| 24 |
+
self.N = len(docs)
|
| 25 |
+
self.avgdl = sum(len(t) for t in self.toks) / self.N
|
| 26 |
+
df: Counter = Counter()
|
| 27 |
+
for t in self.toks:
|
| 28 |
+
for w in set(t):
|
| 29 |
+
df[w] += 1
|
| 30 |
+
self.idf = {
|
| 31 |
+
w: math.log(1 + (self.N - n + 0.5) / (n + 0.5)) for w, n in df.items()
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
def score(self, query: str, doc_index: int) -> float:
|
| 35 |
+
d = self.toks[doc_index]
|
| 36 |
+
tf = Counter(d)
|
| 37 |
+
s = 0.0
|
| 38 |
+
for w in query.lower().split():
|
| 39 |
+
if w not in self.idf:
|
| 40 |
+
continue
|
| 41 |
+
f = tf[w]
|
| 42 |
+
denom = f + self.k1 * (1 - self.b + self.b * len(d) / self.avgdl)
|
| 43 |
+
s += self.idf[w] * f * (self.k1 + 1) / max(denom, 1e-9)
|
| 44 |
+
return s
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def recall_at_k(ranked_ids: list[str], relevant: set[str], k: int) -> float:
|
| 48 |
+
if not relevant:
|
| 49 |
+
return float("nan")
|
| 50 |
+
return len(set(ranked_ids[:k]) & relevant) / len(relevant)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def ndcg_at_k(ranked_ids: list[str], relevant: set[str], k: int) -> float:
|
| 54 |
+
if not relevant:
|
| 55 |
+
return float("nan")
|
| 56 |
+
dcg = sum(
|
| 57 |
+
1.0 / math.log2(i + 1)
|
| 58 |
+
for i, a in enumerate(ranked_ids[:k], start=1)
|
| 59 |
+
if a in relevant
|
| 60 |
+
)
|
| 61 |
+
ideal = sum(1.0 / math.log2(i + 1) for i in range(1, min(k, len(relevant)) + 1))
|
| 62 |
+
return dcg / ideal if ideal > 0 else float("nan")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def main():
|
| 66 |
+
ap = argparse.ArgumentParser()
|
| 67 |
+
ap.add_argument("--corpus", type=Path, default=Path(__file__).parent / "corpus_ids.json")
|
| 68 |
+
ap.add_argument("--relevance", type=Path, default=Path(__file__).parent / "relevance.json")
|
| 69 |
+
ap.add_argument("--manifest", type=Path, default=Path(__file__).parent / "manifest.json")
|
| 70 |
+
ap.add_argument("--ks", type=int, nargs="+", default=[1, 3, 5, 10])
|
| 71 |
+
args = ap.parse_args()
|
| 72 |
+
|
| 73 |
+
ids = json.loads(args.corpus.read_text("utf-8"))["ids"]
|
| 74 |
+
queries = json.loads(args.relevance.read_text("utf-8"))["queries"]
|
| 75 |
+
abstracts = fetch_all(ids)
|
| 76 |
+
bad = verify_against_manifest(abstracts, args.manifest)
|
| 77 |
+
if bad:
|
| 78 |
+
raise SystemExit(f"manifest mismatch: {list(bad)[:3]}")
|
| 79 |
+
|
| 80 |
+
docs_in_order = [abstracts[i] for i in ids]
|
| 81 |
+
bm25 = BM25(docs_in_order)
|
| 82 |
+
|
| 83 |
+
per_query = []
|
| 84 |
+
for q in queries:
|
| 85 |
+
scored = sorted(
|
| 86 |
+
((bm25.score(q["query"], i), ids[i]) for i in range(len(ids))),
|
| 87 |
+
key=lambda x: -x[0],
|
| 88 |
+
)
|
| 89 |
+
ranked_ids = [doc_id for _, doc_id in scored]
|
| 90 |
+
rel = set(q["relevant_ids"])
|
| 91 |
+
row = {"query": q["query"], "ranked": ranked_ids[: max(args.ks)]}
|
| 92 |
+
for k in args.ks:
|
| 93 |
+
row[f"recall@{k}"] = recall_at_k(ranked_ids, rel, k)
|
| 94 |
+
row[f"ndcg@{k}"] = ndcg_at_k(ranked_ids, rel, k)
|
| 95 |
+
per_query.append(row)
|
| 96 |
+
|
| 97 |
+
aggregate = {
|
| 98 |
+
f"recall@{k}": float(np.mean([q[f"recall@{k}"] for q in per_query])) for k in args.ks
|
| 99 |
+
}
|
| 100 |
+
aggregate.update(
|
| 101 |
+
{f"ndcg@{k}": float(np.mean([q[f"ndcg@{k}"] for q in per_query])) for k in args.ks}
|
| 102 |
+
)
|
| 103 |
+
print(f"backend=bm25 corpus={len(ids)} queries={len(queries)}")
|
| 104 |
+
for q in per_query:
|
| 105 |
+
cells = " ".join(
|
| 106 |
+
f"{m}={q[m]:.3f}" for m in q if m.startswith(("recall", "ndcg"))
|
| 107 |
+
)
|
| 108 |
+
print(f" {q['query'][:48]:<48s} {cells}")
|
| 109 |
+
print("aggregate: " + " ".join(f"{m}={aggregate[m]:.3f}" for m in aggregate))
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
if __name__ == "__main__":
|
| 113 |
+
main()
|
eval/run_sbert.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SBERT (all-MiniLM-L6-v2) baseline against the same eval set.
|
| 2 |
+
|
| 3 |
+
Mean-pooled 384-d sentence embedding, cosine similarity. Establishes the
|
| 4 |
+
modern dense-retrieval ceiling for the photon-route eval. Runs entirely
|
| 5 |
+
on CPU in a few seconds for this corpus size.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import json
|
| 11 |
+
import math
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
from eval.fetch import fetch_all, verify_against_manifest
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def recall_at_k(ranked_ids: list[str], relevant: set[str], k: int) -> float:
|
| 20 |
+
if not relevant:
|
| 21 |
+
return float("nan")
|
| 22 |
+
return len(set(ranked_ids[:k]) & relevant) / len(relevant)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def ndcg_at_k(ranked_ids: list[str], relevant: set[str], k: int) -> float:
|
| 26 |
+
if not relevant:
|
| 27 |
+
return float("nan")
|
| 28 |
+
dcg = sum(
|
| 29 |
+
1.0 / math.log2(i + 1)
|
| 30 |
+
for i, a in enumerate(ranked_ids[:k], start=1)
|
| 31 |
+
if a in relevant
|
| 32 |
+
)
|
| 33 |
+
ideal = sum(1.0 / math.log2(i + 1) for i in range(1, min(k, len(relevant)) + 1))
|
| 34 |
+
return dcg / ideal if ideal > 0 else float("nan")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def main():
|
| 38 |
+
ap = argparse.ArgumentParser()
|
| 39 |
+
ap.add_argument("--model", default="sentence-transformers/all-MiniLM-L6-v2")
|
| 40 |
+
ap.add_argument("--corpus", type=Path, default=Path(__file__).parent / "corpus_ids.json")
|
| 41 |
+
ap.add_argument("--relevance", type=Path, default=Path(__file__).parent / "relevance.json")
|
| 42 |
+
ap.add_argument("--manifest", type=Path, default=Path(__file__).parent / "manifest.json")
|
| 43 |
+
ap.add_argument("--ks", type=int, nargs="+", default=[1, 3, 5, 10])
|
| 44 |
+
args = ap.parse_args()
|
| 45 |
+
|
| 46 |
+
from sentence_transformers import SentenceTransformer
|
| 47 |
+
|
| 48 |
+
ids = json.loads(args.corpus.read_text("utf-8"))["ids"]
|
| 49 |
+
queries = json.loads(args.relevance.read_text("utf-8"))["queries"]
|
| 50 |
+
abstracts = fetch_all(ids)
|
| 51 |
+
bad = verify_against_manifest(abstracts, args.manifest)
|
| 52 |
+
if bad:
|
| 53 |
+
raise SystemExit(f"manifest mismatch: {list(bad)[:3]}")
|
| 54 |
+
|
| 55 |
+
print(f"loading {args.model}...")
|
| 56 |
+
model = SentenceTransformer(args.model)
|
| 57 |
+
|
| 58 |
+
docs_in_order = [abstracts[i] for i in ids]
|
| 59 |
+
doc_emb = model.encode(docs_in_order, normalize_embeddings=True, show_progress_bar=False)
|
| 60 |
+
q_emb = model.encode([q["query"] for q in queries], normalize_embeddings=True, show_progress_bar=False)
|
| 61 |
+
|
| 62 |
+
per_query = []
|
| 63 |
+
for qi, q in enumerate(queries):
|
| 64 |
+
sims = doc_emb @ q_emb[qi] # cosine since both are normalized
|
| 65 |
+
order = np.argsort(-sims)
|
| 66 |
+
ranked_ids = [ids[i] for i in order]
|
| 67 |
+
rel = set(q["relevant_ids"])
|
| 68 |
+
row = {"query": q["query"], "ranked": ranked_ids[: max(args.ks)]}
|
| 69 |
+
for k in args.ks:
|
| 70 |
+
row[f"recall@{k}"] = recall_at_k(ranked_ids, rel, k)
|
| 71 |
+
row[f"ndcg@{k}"] = ndcg_at_k(ranked_ids, rel, k)
|
| 72 |
+
per_query.append(row)
|
| 73 |
+
|
| 74 |
+
aggregate = {f"recall@{k}": float(np.mean([q[f"recall@{k}"] for q in per_query])) for k in args.ks}
|
| 75 |
+
aggregate.update(
|
| 76 |
+
{f"ndcg@{k}": float(np.mean([q[f"ndcg@{k}"] for q in per_query])) for k in args.ks}
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
print(f"backend=sbert/{args.model.split('/')[-1]} corpus={len(ids)} queries={len(queries)}")
|
| 80 |
+
for q in per_query:
|
| 81 |
+
cells = " ".join(f"{m}={q[m]:.3f}" for m in q if m.startswith(("recall", "ndcg")))
|
| 82 |
+
print(f" {q['query'][:48]:<48s} {cells}")
|
| 83 |
+
print("aggregate: " + " ".join(f"{m}={aggregate[m]:.3f}" for m in aggregate))
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
if __name__ == "__main__":
|
| 87 |
+
main()
|
pages/CNAME
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
photon.ask-meridian.uk
|
worker/proxy.js → pages/index.html
RENAMED
|
@@ -1,57 +1,4 @@
|
|
| 1 |
-
|
| 2 |
-
// Routes photon.ask-meridian.uk → HF Space (luuow-photon-route),
|
| 3 |
-
// edge-caches /rank, and serves an inline interactive UI at /.
|
| 4 |
-
//
|
| 5 |
-
// Endpoints (worker-local):
|
| 6 |
-
// / interactive HTML page
|
| 7 |
-
// /health liveness JSON
|
| 8 |
-
// /api service banner JSON
|
| 9 |
-
// Endpoints (proxied to HF Space):
|
| 10 |
-
// /rank?q=&top_k= edge-cached 24 h
|
| 11 |
-
// /version /docs /openapi.json
|
| 12 |
-
//
|
| 13 |
-
// Cache key includes CACHE_VERSION; bump to invalidate after UI changes
|
| 14 |
-
// or fixture updates upstream.
|
| 15 |
-
|
| 16 |
-
const HF = 'https://luuow-photon-route.hf.space';
|
| 17 |
-
const PROXY = new Set(['/rank', '/version', '/docs', '/openapi.json']);
|
| 18 |
-
// CACHE_VERSION bump: /rank responses now include `backend` and the
|
| 19 |
-
// payload differs per backend; the cache key already includes the
|
| 20 |
-
// query string so backends are cached separately, but old (v3) entries
|
| 21 |
-
// don't have the backend field — invalidate.
|
| 22 |
-
const CACHE_VERSION = 'v4';
|
| 23 |
-
|
| 24 |
-
const BANNER = {
|
| 25 |
-
service: 'photon-route',
|
| 26 |
-
proxy: 'cloudflare-worker',
|
| 27 |
-
backend: 'huggingface-space',
|
| 28 |
-
upstream: HF,
|
| 29 |
-
repo: 'https://github.com/LuuOW/photon-route',
|
| 30 |
-
sister: 'https://qrouter.ask-meridian.uk (DV / qubit-gate sister project)',
|
| 31 |
-
endpoints: {
|
| 32 |
-
ui: '/ (interactive HTML)',
|
| 33 |
-
api: '/api (this banner)',
|
| 34 |
-
health: '/health (worker-local)',
|
| 35 |
-
rank: '/rank?q=<query>&top_k=N&backend=v1|sha_init|trained (proxied, edge-cached 24 h)',
|
| 36 |
-
version: '/version (proxied)',
|
| 37 |
-
docs: '/docs (proxied — FastAPI swagger)',
|
| 38 |
-
},
|
| 39 |
-
backends: ['v1 (SF)', 'sha_init (numpy)', 'trained (numpy + learned)'],
|
| 40 |
-
note: 'CV photonic retrieval. Strawberry Fields Gaussian programs, thewalrus closed-form fidelity.',
|
| 41 |
-
};
|
| 42 |
-
|
| 43 |
-
const CSP = [
|
| 44 |
-
"default-src 'self'",
|
| 45 |
-
"style-src 'self' 'unsafe-inline'",
|
| 46 |
-
"script-src 'self' 'unsafe-inline'",
|
| 47 |
-
"connect-src 'self'",
|
| 48 |
-
"img-src 'self' data:",
|
| 49 |
-
"base-uri 'none'",
|
| 50 |
-
"form-action 'none'",
|
| 51 |
-
"frame-ancestors 'none'",
|
| 52 |
-
].join('; ');
|
| 53 |
-
|
| 54 |
-
const HTML = `<!doctype html>
|
| 55 |
<html lang="en">
|
| 56 |
<head>
|
| 57 |
<meta charset="utf-8">
|
|
@@ -161,7 +108,6 @@ const HTML = `<!doctype html>
|
|
| 161 |
border-bottom:1px dotted var(--line)}
|
| 162 |
footer a:hover{color:var(--fg);border-bottom-color:var(--cyan)}
|
| 163 |
.empty{color:var(--muted);text-align:center;padding:32px 12px;font-size:12px}
|
| 164 |
-
/* phase-space visualization */
|
| 165 |
.viz{margin:14px 0 18px;display:grid;grid-template-columns:1fr 1fr;gap:10px}
|
| 166 |
.modepanel{position:relative;background:var(--panel);border:1px solid var(--line);
|
| 167 |
border-radius:var(--radius);overflow:hidden;aspect-ratio:5/4;min-height:200px}
|
|
@@ -198,6 +144,193 @@ const HTML = `<!doctype html>
|
|
| 198 |
</style>
|
| 199 |
</head>
|
| 200 |
<body>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
<main>
|
| 202 |
<header>
|
| 203 |
<div>
|
|
@@ -249,28 +382,29 @@ const HTML = `<!doctype html>
|
|
| 249 |
<div class="body">
|
| 250 |
<p><strong>photon-route</strong> is a research artifact exploring whether semantic retrieval can run in the continuous-variable (CV) photonic regime — the regime that real photonic hardware (Xanadu Borealis, fiber-loop reservoirs, coherent Ising machines) actually operates in.</p>
|
| 251 |
<p>Each document is encoded as a <em>Gaussian state</em> over N bosonic modes via a <a href="https://strawberryfields.ai/" target="_blank" rel="noopener">Strawberry Fields</a> program: words contribute squeezing and displacement operations, then a beam-splitter network mixes the modes. Query and document fidelity is computed in closed form using the <a href="https://the-walrus.readthedocs.io/" target="_blank" rel="noopener">thewalrus</a> implementation of the Banchi-Braunstein-Pirandola formula.</p>
|
| 252 |
-
<p>Three swappable encoders share the same fidelity scoring. <strong>v1</strong> uses Strawberry Fields to run an N-mode Gaussian program with SHA-256-derived parameters. <strong>sha_init</strong> is a pure-numpy port of the same gates (no SF at deploy time). <strong>trained</strong> swaps the SHA-256 lookup for a learned table fit by InfoNCE + Bhattacharyya-coefficient surrogate fidelity on a small arXiv quant-ph eval set. Toggle <em>compare</em> to see the three side by side
|
| 253 |
<p>Source · <a href="https://github.com/LuuOW/photon-route" target="_blank" rel="noopener">github.com/LuuOW/photon-route</a></p>
|
| 254 |
</div>
|
| 255 |
</details>
|
| 256 |
|
| 257 |
<footer>
|
| 258 |
-
<span>CV photonic · gaussian backend ·
|
| 259 |
-
<span><a href="https://
|
| 260 |
</footer>
|
| 261 |
</main>
|
| 262 |
|
| 263 |
<script>
|
| 264 |
(function(){
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
function $(id){return document.getElementById(id)}
|
| 266 |
var q=$('q'), k=$('k'), results=$('results'), status=$('status');
|
| 267 |
var healthPill=$('health'), healthText=$('health-text');
|
| 268 |
var abort=null, debounceT=0;
|
| 269 |
|
| 270 |
-
// ============================================================
|
| 271 |
-
// Gaussian-state encoding (mirrors src/photon_route/encode.py)
|
| 272 |
-
// xpxp ordering: mu = [q0, p0, q1, p1]; sigma is 4x4
|
| 273 |
-
// ============================================================
|
| 274 |
var N_MODES = 2, MAX_SQUEEZE = 0.5, MAX_DISPLACE = 1.0;
|
| 275 |
var _wcache = new Map();
|
| 276 |
|
|
@@ -282,7 +416,6 @@ const HTML = `<!doctype html>
|
|
| 282 |
var hash = await crypto.subtle.digest('SHA-256', buf);
|
| 283 |
bytes = new Uint8Array(hash);
|
| 284 |
} catch(e){
|
| 285 |
-
// fallback: not cryptographic but stable enough for viz when subtle unavailable
|
| 286 |
bytes = new Uint8Array(32);
|
| 287 |
var h = 2166136261;
|
| 288 |
for(var i=0;i<word.length;i++){ h ^= word.charCodeAt(i); h = (h*16777619)>>>0; }
|
|
@@ -290,7 +423,6 @@ const HTML = `<!doctype html>
|
|
| 290 |
}
|
| 291 |
var parts = [];
|
| 292 |
for(var i2=0;i2<4;i2++){
|
| 293 |
-
// 8 bytes -> BigInt -> mod 1e9 / 1e9
|
| 294 |
var big = 0n;
|
| 295 |
for(var j2=0;j2<8;j2++) big = (big << 8n) + BigInt(bytes[i2*8 + j2]);
|
| 296 |
parts.push(Number(big % 1000000000n) / 1e9);
|
|
@@ -322,7 +454,6 @@ const HTML = `<!doctype html>
|
|
| 322 |
function tr4(A){var T=mat4(); for(var i=0;i<4;i++) for(var j=0;j<4;j++) T[i][j]=A[j][i]; return T;}
|
| 323 |
|
| 324 |
function sgateMat(k, r, phi){
|
| 325 |
-
// S22 = R(phi/2) @ diag(e^{-r}, e^r) @ R(-phi/2)
|
| 326 |
var c=Math.cos(phi/2), s=Math.sin(phi/2), em=Math.exp(-r), ep=Math.exp(r);
|
| 327 |
var a = c*c*em + s*s*ep;
|
| 328 |
var b = c*s*(em - ep);
|
|
@@ -377,11 +508,7 @@ const HTML = `<!doctype html>
|
|
| 377 |
};
|
| 378 |
}
|
| 379 |
|
| 380 |
-
|
| 381 |
-
// 2.5D Wigner-function rendering on a 2D canvas (no WebGL).
|
| 382 |
-
// For Gaussian states: W(x) = (1 / (2π√det Σ)) exp(-½ (x-μ)ᵀΣ⁻¹(x-μ))
|
| 383 |
-
// ============================================================
|
| 384 |
-
var GRID = 26; // quads per side; 26² × 2 modes ~= 1352 quads/frame
|
| 385 |
var DPR = Math.min(window.devicePixelRatio || 1, 2);
|
| 386 |
|
| 387 |
function WignerView(canvas, coordEl){
|
|
@@ -390,8 +517,8 @@ const HTML = `<!doctype html>
|
|
| 390 |
this.ctx = canvas.getContext('2d');
|
| 391 |
this.mu = [0,0];
|
| 392 |
this.sigma = [[1,0],[0,1]];
|
| 393 |
-
this.yaw = 0.55;
|
| 394 |
-
this.pitch = 0.85;
|
| 395 |
this.userYaw = false;
|
| 396 |
this.dragging = false;
|
| 397 |
this.lastX = 0; this.lastY = 0;
|
|
@@ -440,29 +567,24 @@ const HTML = `<!doctype html>
|
|
| 440 |
WignerView.prototype.draw = function(){
|
| 441 |
var ctx = this.ctx, W = this.canvas.width, H = this.canvas.height;
|
| 442 |
ctx.clearRect(0,0,W,H);
|
| 443 |
-
|
| 444 |
var mu = this.mu, sg = this.sigma;
|
| 445 |
var det = sg[0][0]*sg[1][1] - sg[0][1]*sg[1][0];
|
| 446 |
if(det < 1e-12){ return; }
|
| 447 |
var iv00 = sg[1][1]/det, iv01 = -sg[0][1]/det, iv10 = -sg[1][0]/det, iv11 = sg[0][0]/det;
|
| 448 |
var norm = 1/(2*Math.PI*Math.sqrt(det));
|
| 449 |
-
|
| 450 |
-
// auto-scale to fit ±3σ ellipse around μ
|
| 451 |
var sQ = Math.sqrt(Math.abs(sg[0][0])), sP = Math.sqrt(Math.abs(sg[1][1]));
|
| 452 |
var extent = Math.max(2.6, Math.abs(mu[0]) + 3*sQ, Math.abs(mu[1]) + 3*sP);
|
| 453 |
var scaleXY = (Math.min(W, H) * 0.32) / extent;
|
| 454 |
-
var scaleZ = (Math.min(W, H) * 0.55) * Math.sqrt(det);
|
| 455 |
var ox = W*0.5, oy = H*0.62;
|
| 456 |
-
|
| 457 |
var cy = Math.cos(this.yaw), sy = Math.sin(this.yaw);
|
| 458 |
var cp = Math.cos(this.pitch), sp = Math.sin(this.pitch);
|
| 459 |
-
|
| 460 |
function project(q, p, w){
|
| 461 |
var xr = q*cy - p*sy;
|
| 462 |
var yr = q*sy + p*cy;
|
| 463 |
var sx = ox + xr * scaleXY;
|
| 464 |
var sy_ = oy - yr * scaleXY * cp - w * scaleZ * sp;
|
| 465 |
-
var depth = yr * sp - w * cp;
|
| 466 |
return [sx, sy_, depth];
|
| 467 |
}
|
| 468 |
function projectFlat(q, p){
|
|
@@ -470,8 +592,6 @@ const HTML = `<!doctype html>
|
|
| 470 |
var yr = q*sy + p*cy;
|
| 471 |
return [ox + xr*scaleXY, oy - yr*scaleXY*cp];
|
| 472 |
}
|
| 473 |
-
|
| 474 |
-
// floor grid on (q, p) plane
|
| 475 |
ctx.strokeStyle = 'rgba(28,39,66,0.7)';
|
| 476 |
ctx.lineWidth = 1;
|
| 477 |
var gN = 6;
|
|
@@ -484,7 +604,6 @@ const HTML = `<!doctype html>
|
|
| 484 |
var dFlat = projectFlat( extent, t);
|
| 485 |
ctx.beginPath(); ctx.moveTo(cFlat[0], cFlat[1]); ctx.lineTo(dFlat[0], dFlat[1]); ctx.stroke();
|
| 486 |
}
|
| 487 |
-
// axes: q (cyan), p (indigo) at origin
|
| 488 |
var oFlat = projectFlat(0,0);
|
| 489 |
var qAxis = projectFlat(extent, 0);
|
| 490 |
var pAxis = projectFlat(0, extent);
|
|
@@ -492,8 +611,6 @@ const HTML = `<!doctype html>
|
|
| 492 |
ctx.beginPath(); ctx.moveTo(oFlat[0], oFlat[1]); ctx.lineTo(qAxis[0], qAxis[1]); ctx.stroke();
|
| 493 |
ctx.strokeStyle = 'rgba(129,140,248,0.45)';
|
| 494 |
ctx.beginPath(); ctx.moveTo(oFlat[0], oFlat[1]); ctx.lineTo(pAxis[0], pAxis[1]); ctx.stroke();
|
| 495 |
-
|
| 496 |
-
// sample Wigner on grid + project
|
| 497 |
var N = GRID;
|
| 498 |
var step = (2*extent)/N;
|
| 499 |
var pts = new Array(N+1);
|
|
@@ -512,8 +629,6 @@ const HTML = `<!doctype html>
|
|
| 512 |
pts[i][j] = {w:w, sx:pr[0], sy:pr[1], depth:pr[2]};
|
| 513 |
}
|
| 514 |
}
|
| 515 |
-
|
| 516 |
-
// build quads + sort back-to-front
|
| 517 |
var quads = [];
|
| 518 |
for(var i2=0;i2<N;i2++){
|
| 519 |
for(var j2=0;j2<N;j2++){
|
|
@@ -523,13 +638,10 @@ const HTML = `<!doctype html>
|
|
| 523 |
quads.push({a:a, b:b, c:c, d:d, depth:depth, w:wAvg});
|
| 524 |
}
|
| 525 |
}
|
| 526 |
-
quads.sort(function(x, y){ return y.depth - x.depth; });
|
| 527 |
-
|
| 528 |
-
// draw
|
| 529 |
for(var qi2=0; qi2<quads.length; qi2++){
|
| 530 |
var qd = quads[qi2];
|
| 531 |
var t2 = wmax > 1e-12 ? Math.max(0, Math.min(1, qd.w / wmax)) : 0;
|
| 532 |
-
// indigo (low) -> cyan (high)
|
| 533 |
var rC = Math.round(0x81*(1-t2) + 0x22*t2);
|
| 534 |
var gC = Math.round(0x8c*(1-t2) + 0xd3*t2);
|
| 535 |
var bC = Math.round(0xf8*(1-t2) + 0xee*t2);
|
|
@@ -556,7 +668,6 @@ const HTML = `<!doctype html>
|
|
| 556 |
v1 = new WignerView(c1, document.getElementById('c1'));
|
| 557 |
var resize = function(){ v0.resize(); v1.resize(); };
|
| 558 |
window.addEventListener('resize', resize);
|
| 559 |
-
// vacuum initial
|
| 560 |
v0.setState([0,0],[[1,0],[0,1]]);
|
| 561 |
v1.setState([0,0],[[1,0],[0,1]]);
|
| 562 |
var lastT = performance.now();
|
|
@@ -577,23 +688,21 @@ const HTML = `<!doctype html>
|
|
| 577 |
if(text === lastQuery) return;
|
| 578 |
lastQuery = text;
|
| 579 |
var st = await encodeState(text);
|
| 580 |
-
if(text !== lastQuery) return;
|
| 581 |
currentState = st;
|
| 582 |
var m0 = modeMarginal(st, 0), m1 = modeMarginal(st, 1);
|
| 583 |
v0.setState(m0.mu, m0.sigma);
|
| 584 |
v1.setState(m1.mu, m1.sigma);
|
| 585 |
}
|
| 586 |
-
// init after DOM ready (we're at end of body so DOM exists)
|
| 587 |
initViz();
|
| 588 |
|
| 589 |
var backendSel = document.getElementById('backend');
|
| 590 |
var compareBox = document.getElementById('compare');
|
| 591 |
-
fetch('/health',{cache:'no-store'}).then(function(r){return r.json()}).then(function(j){
|
| 592 |
-
var ok = j && j.ok
|
| 593 |
healthPill.classList.add(ok?'ok':'err');
|
| 594 |
var backends = (j && j.backends_available) || [];
|
| 595 |
healthText.textContent = ok ? (j.default_backend || 'ok') : 'offline';
|
| 596 |
-
// hide options for backends that aren't actually live upstream
|
| 597 |
Array.prototype.forEach.call(backendSel.options, function(opt){
|
| 598 |
opt.disabled = backends.length>0 && backends.indexOf(opt.value) < 0;
|
| 599 |
if (opt.disabled) opt.text = opt.value + ' (n/a)';
|
|
@@ -674,12 +783,11 @@ const HTML = `<!doctype html>
|
|
| 674 |
}
|
| 675 |
|
| 676 |
async function fetchRank(text, topk, backend, sig){
|
| 677 |
-
var url='/rank?q='+encodeURIComponent(text)+'&top_k='+topk+'&backend='+encodeURIComponent(backend);
|
| 678 |
var r=await fetch(url,{signal:sig});
|
| 679 |
if(!r.ok) throw new Error('http '+r.status);
|
| 680 |
var j=await r.json();
|
| 681 |
-
return {backend:j.backend||backend, items:j.results||[],
|
| 682 |
-
cache:r.headers.get('x-photon-route-cache')||''};
|
| 683 |
}
|
| 684 |
|
| 685 |
async function run(){
|
|
@@ -704,7 +812,7 @@ const HTML = `<!doctype html>
|
|
| 704 |
var j = await fetchRank(text, topk, backendSel.value, abort.signal);
|
| 705 |
var ms=(performance.now()-t0).toFixed(0);
|
| 706 |
status.textContent = j.items.length+' result'+(j.items.length===1?'':'s')+
|
| 707 |
-
' · '+ms+' ms
|
| 708 |
render(j.items);
|
| 709 |
return;
|
| 710 |
}
|
|
@@ -745,114 +853,4 @@ const HTML = `<!doctype html>
|
|
| 745 |
})();
|
| 746 |
</script>
|
| 747 |
</body>
|
| 748 |
-
</html>
|
| 749 |
-
|
| 750 |
-
addEventListener('fetch', (e) => e.respondWith(handle(e.request)));
|
| 751 |
-
|
| 752 |
-
async function handle(req) {
|
| 753 |
-
const url = new URL(req.url);
|
| 754 |
-
const path = url.pathname;
|
| 755 |
-
|
| 756 |
-
if (req.method === 'OPTIONS') {
|
| 757 |
-
return cors(new Response(null, { status: 204 }));
|
| 758 |
-
}
|
| 759 |
-
|
| 760 |
-
if (path === '/' && (req.method === 'GET' || req.method === 'HEAD')) {
|
| 761 |
-
return new Response(req.method === 'HEAD' ? null : HTML, {
|
| 762 |
-
headers: {
|
| 763 |
-
'content-type': 'text/html; charset=utf-8',
|
| 764 |
-
'content-security-policy': CSP,
|
| 765 |
-
'referrer-policy': 'strict-origin-when-cross-origin',
|
| 766 |
-
'x-content-type-options': 'nosniff',
|
| 767 |
-
'cache-control': 'public, max-age=300',
|
| 768 |
-
},
|
| 769 |
-
});
|
| 770 |
-
}
|
| 771 |
-
|
| 772 |
-
if (path === '/api' || path === '/info') {
|
| 773 |
-
return jsonResp(BANNER);
|
| 774 |
-
}
|
| 775 |
-
|
| 776 |
-
if (path === '/health') {
|
| 777 |
-
// Merge upstream /health so the UI knows which backends are live.
|
| 778 |
-
let upstream = null;
|
| 779 |
-
try {
|
| 780 |
-
const r = await fetch(HF + '/health', { cf: { cacheTtl: 30 } });
|
| 781 |
-
if (r.ok) upstream = await r.json();
|
| 782 |
-
} catch (_) {}
|
| 783 |
-
return jsonResp({
|
| 784 |
-
ok: true,
|
| 785 |
-
proxy: 'cloudflare-worker',
|
| 786 |
-
upstream_ok: upstream ? !!upstream.ok : false,
|
| 787 |
-
backends_available: upstream && upstream.backends_available
|
| 788 |
-
? upstream.backends_available : ['stub'],
|
| 789 |
-
default_backend: upstream && upstream.default_backend
|
| 790 |
-
? upstream.default_backend : 'stub',
|
| 791 |
-
weights_loaded: upstream ? !!upstream.weights_loaded : false,
|
| 792 |
-
});
|
| 793 |
-
}
|
| 794 |
-
|
| 795 |
-
if (PROXY.has(path) && (req.method === 'GET' || req.method === 'HEAD')) {
|
| 796 |
-
return await proxied(req, url, path);
|
| 797 |
-
}
|
| 798 |
-
|
| 799 |
-
return jsonResp({ error: 'not found' }, 404);
|
| 800 |
-
}
|
| 801 |
-
|
| 802 |
-
function jsonResp(obj, status = 200) {
|
| 803 |
-
return new Response(JSON.stringify(obj, null, 2), {
|
| 804 |
-
status,
|
| 805 |
-
headers: {
|
| 806 |
-
'content-type': 'application/json; charset=utf-8',
|
| 807 |
-
'access-control-allow-origin': '*',
|
| 808 |
-
'cache-control': 'no-store',
|
| 809 |
-
},
|
| 810 |
-
});
|
| 811 |
-
}
|
| 812 |
-
|
| 813 |
-
function cors(resp) {
|
| 814 |
-
resp.headers.set('access-control-allow-origin', '*');
|
| 815 |
-
resp.headers.set('access-control-allow-methods', 'GET, HEAD, OPTIONS');
|
| 816 |
-
resp.headers.set('access-control-allow-headers', 'content-type');
|
| 817 |
-
return resp;
|
| 818 |
-
}
|
| 819 |
-
|
| 820 |
-
async function proxied(req, url, path) {
|
| 821 |
-
const upstream = new URL(HF);
|
| 822 |
-
upstream.pathname = path;
|
| 823 |
-
upstream.search = url.search;
|
| 824 |
-
|
| 825 |
-
const isRank = path === '/rank';
|
| 826 |
-
const cacheKey = new Request(upstream.toString() + '#' + CACHE_VERSION, req);
|
| 827 |
-
const cache = caches.default;
|
| 828 |
-
|
| 829 |
-
if (isRank) {
|
| 830 |
-
const hit = await cache.match(cacheKey);
|
| 831 |
-
if (hit) {
|
| 832 |
-
const r = new Response(hit.body, hit);
|
| 833 |
-
r.headers.set('x-photon-route-cache', 'hit');
|
| 834 |
-
return cors(r);
|
| 835 |
-
}
|
| 836 |
-
}
|
| 837 |
-
|
| 838 |
-
const fetched = await fetch(upstream.toString(), {
|
| 839 |
-
method: req.method,
|
| 840 |
-
headers: { accept: req.headers.get('accept') || '*/*' },
|
| 841 |
-
cf: { cacheTtl: 0 },
|
| 842 |
-
});
|
| 843 |
-
|
| 844 |
-
if (!fetched.ok || fetched.status >= 500) {
|
| 845 |
-
const r = new Response(fetched.body, fetched);
|
| 846 |
-
r.headers.set('x-photon-route-cache', 'bypass');
|
| 847 |
-
return cors(r);
|
| 848 |
-
}
|
| 849 |
-
|
| 850 |
-
const body = await fetched.arrayBuffer();
|
| 851 |
-
const headers = new Headers(fetched.headers);
|
| 852 |
-
headers.delete('set-cookie');
|
| 853 |
-
if (isRank) headers.set('cache-control', 'public, s-maxage=86400, max-age=300');
|
| 854 |
-
const ok = new Response(body, { status: fetched.status, headers });
|
| 855 |
-
ok.headers.set('x-photon-route-cache', 'miss');
|
| 856 |
-
if (isRank) await cache.put(cacheKey, ok.clone());
|
| 857 |
-
return cors(ok);
|
| 858 |
-
}
|
|
|
|
| 1 |
+
<!doctype html>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
<html lang="en">
|
| 3 |
<head>
|
| 4 |
<meta charset="utf-8">
|
|
|
|
| 108 |
border-bottom:1px dotted var(--line)}
|
| 109 |
footer a:hover{color:var(--fg);border-bottom-color:var(--cyan)}
|
| 110 |
.empty{color:var(--muted);text-align:center;padding:32px 12px;font-size:12px}
|
|
|
|
| 111 |
.viz{margin:14px 0 18px;display:grid;grid-template-columns:1fr 1fr;gap:10px}
|
| 112 |
.modepanel{position:relative;background:var(--panel);border:1px solid var(--line);
|
| 113 |
border-radius:var(--radius);overflow:hidden;aspect-ratio:5/4;min-height:200px}
|
|
|
|
| 144 |
</style>
|
| 145 |
</head>
|
| 146 |
<body>
|
| 147 |
+
<style>
|
| 148 |
+
.nav { position: fixed; inset: 0 0 auto 0;
|
| 149 |
+
display: flex; align-items: center; justify-content: space-between;
|
| 150 |
+
padding: 14px clamp(12px, 3vw, 24px);
|
| 151 |
+
z-index: 100; pointer-events: none;
|
| 152 |
+
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", system-ui, sans-serif; }
|
| 153 |
+
.nav > * { pointer-events: auto; }
|
| 154 |
+
/* Nav-scoped: the page body has its own <div class="brand"> gradient
|
| 155 |
+
for the page title — don't let nav .brand rules override it. */
|
| 156 |
+
.nav .brand { color: #e9eef7; font-weight: 600; letter-spacing: 0.04em;
|
| 157 |
+
text-decoration: none; font-size: 15px;
|
| 158 |
+
background: none; -webkit-text-fill-color: currentColor; }
|
| 159 |
+
.nav .brand .brand-glyph { color: #ffa276; margin-right: 6px; }
|
| 160 |
+
.burger { width: 40px; height: 40px; padding: 0;
|
| 161 |
+
background: rgba(167,139,250,0.05);
|
| 162 |
+
border: 1px solid rgba(167,139,250,0.18);
|
| 163 |
+
border-radius: 10px; cursor: pointer; position: relative;
|
| 164 |
+
transition: background 0.15s, border-color 0.15s, box-shadow 0.15s; }
|
| 165 |
+
.burger:hover { background: rgba(167,139,250,0.12); border-color: #a78bfa; box-shadow: 0 0 14px rgba(167,139,250,0.25); }
|
| 166 |
+
.burger span { position: absolute; left: 11px; right: 11px; height: 2px;
|
| 167 |
+
background: #c9d4ec; border-radius: 1px;
|
| 168 |
+
transition: top 0.18s, transform 0.18s, opacity 0.18s; }
|
| 169 |
+
.burger span:nth-child(1) { top: 13px; }
|
| 170 |
+
.burger span:nth-child(2) { top: 19px; }
|
| 171 |
+
.burger span:nth-child(3) { top: 25px; }
|
| 172 |
+
.burger.open { background: rgba(167,139,250,0.18); border-color: #a78bfa; }
|
| 173 |
+
.burger.open span:nth-child(1) { top: 19px; transform: rotate(45deg); }
|
| 174 |
+
.burger.open span:nth-child(2) { opacity: 0; }
|
| 175 |
+
.burger.open span:nth-child(3) { top: 19px; transform: rotate(-45deg); }
|
| 176 |
+
.nav-menu { position: fixed; top: 64px; right: clamp(12px, 3vw, 24px);
|
| 177 |
+
width: min(330px, calc(100vw - 24px));
|
| 178 |
+
display: flex; flex-direction: column; gap: 2px; padding: 14px;
|
| 179 |
+
background: rgba(10, 13, 20, 0.96); backdrop-filter: blur(16px);
|
| 180 |
+
border: 1px solid rgba(167, 139, 250, 0.3); border-radius: 14px;
|
| 181 |
+
box-shadow: 0 20px 50px rgba(0,0,0,0.5); transform-origin: top right;
|
| 182 |
+
transform: translateY(-12px) scale(0.97); opacity: 0; pointer-events: none;
|
| 183 |
+
transition: transform 0.22s cubic-bezier(0.16,1,0.3,1), opacity 0.2s ease;
|
| 184 |
+
z-index: 110; max-height: calc(100vh - 80px); overflow-y: auto; }
|
| 185 |
+
.nav-menu.open { transform: translateY(0) scale(1); opacity: 1; pointer-events: auto; }
|
| 186 |
+
.nav-menu a { display: block; padding: 10px 14px; border-radius: 8px;
|
| 187 |
+
font-size: 14px; color: #9bb6ea; text-decoration: none;
|
| 188 |
+
transition: background 0.15s, color 0.15s; }
|
| 189 |
+
.nav-menu a:hover, .nav-menu .current { background: rgba(167,139,250,0.10); color: #fff; }
|
| 190 |
+
.nav-group { border-top: 1px solid rgba(167, 139, 250, 0.10); margin-top: 6px; padding-top: 4px; }
|
| 191 |
+
.nav-group:first-of-type { border-top: 0; margin-top: 4px; padding-top: 0; }
|
| 192 |
+
.nav-group > summary { cursor: pointer; list-style: none; user-select: none; }
|
| 193 |
+
.nav-group > summary::-webkit-details-marker { display: none; }
|
| 194 |
+
.nav-section { font-size: 11px; letter-spacing: 0.12em; text-transform: uppercase;
|
| 195 |
+
color: rgba(167, 139, 250, 0.75); padding: 10px 14px 6px; font-weight: 600;
|
| 196 |
+
display: flex; align-items: center; justify-content: space-between; }
|
| 197 |
+
.nav-section::after { content: '+'; font-size: 14px;
|
| 198 |
+
color: rgba(167, 139, 250, 0.55);
|
| 199 |
+
transition: transform 0.18s, color 0.18s; }
|
| 200 |
+
.nav-group[open] > summary .nav-section::after { transform: rotate(45deg); color: #a78bfa; }
|
| 201 |
+
.nav-group > summary:hover .nav-section { color: #fff; }
|
| 202 |
+
.nav-app { background: rgba(167,139,250,0.04);
|
| 203 |
+
border: 1px solid rgba(167,139,250,0.10);
|
| 204 |
+
margin: 4px 0; line-height: 1.3;
|
| 205 |
+
padding: 10px 14px !important; border-radius: 10px !important; }
|
| 206 |
+
.nav-app:hover { background: rgba(167,139,250,0.12); border-color: rgba(167,139,250,0.32); }
|
| 207 |
+
.nav-app .nav-app-name { display: block; font-size: 14px; color: #e9eef7; font-weight: 500; }
|
| 208 |
+
.nav-app .nav-app-tag { display: block; font-size: 12px; color: #9bb6ea; margin-top: 2px; }
|
| 209 |
+
.nav-app .nav-app-emoji { margin-right: 6px; }
|
| 210 |
+
</style>
|
| 211 |
+
|
| 212 |
+
<nav class="nav" aria-label="Primary">
|
| 213 |
+
<a href="https://ask-meridian.uk/" class="brand" style="view-transition-name: brand">◎ Meridian</a>
|
| 214 |
+
|
| 215 |
+
<button id="burgerBtn" class="burger" type="button" aria-label="Toggle navigation menu" aria-expanded="false" aria-controls="navMenu">
|
| 216 |
+
<span aria-hidden="true"></span><span aria-hidden="true"></span><span aria-hidden="true"></span>
|
| 217 |
+
</button>
|
| 218 |
+
|
| 219 |
+
<div id="navMenu" class="nav-menu" role="menu">
|
| 220 |
+
<a href="https://ask-meridian.uk/">Home</a>
|
| 221 |
+
|
| 222 |
+
<details class="nav-group" open>
|
| 223 |
+
<summary class="nav-section">Showcase</summary>
|
| 224 |
+
<a href="https://meridian.ask-meridian.uk/helix/" class="nav-app" data-status="live">
|
| 225 |
+
<span class="nav-app-name"><span class="nav-app-emoji">🧬</span>helix · proteins</span>
|
| 226 |
+
<span class="nav-app-tag">Injury → top therapeutic protein candidates, each rendered as its own star system</span>
|
| 227 |
+
</a>
|
| 228 |
+
<a href="https://ask-meridian.uk/miniapp/" class="nav-app" data-status="live">
|
| 229 |
+
<span class="nav-app-name"><span class="nav-app-emoji">🛰️</span>Try it · Task orbit</span>
|
| 230 |
+
<span class="nav-app-tag">Browser miniapp · routes any task to candidates</span>
|
| 231 |
+
</a>
|
| 232 |
+
<a href="https://ask-meridian.uk/miniapp/vision-lab/" class="nav-app" data-status="live">
|
| 233 |
+
<span class="nav-app-name"><span class="nav-app-emoji">🔭</span>Vision Lab</span>
|
| 234 |
+
<span class="nav-app-tag">SmolVLM / Moondream in browser via WebGPU</span>
|
| 235 |
+
</a>
|
| 236 |
+
<a href="https://photon.ask-meridian.uk" class="nav-app" data-status="live">
|
| 237 |
+
<span class="nav-app-name"><span class="nav-app-emoji">⚛︎</span>Photon Router</span>
|
| 238 |
+
<span class="nav-app-tag">CV photonic retrieval · trained on HF Space</span>
|
| 239 |
+
</a>
|
| 240 |
+
<a href="https://lens.ask-meridian.uk" class="nav-app" data-status="live">
|
| 241 |
+
<span class="nav-app-name"><span class="nav-app-emoji">◎</span>Lens · WebXR</span>
|
| 242 |
+
<span class="nav-app-tag">Vision Lab in VR · controllers, raycaster, orbit</span>
|
| 243 |
+
</a>
|
| 244 |
+
</details>
|
| 245 |
+
|
| 246 |
+
<details class="nav-group">
|
| 247 |
+
<summary class="nav-section">Resources</summary>
|
| 248 |
+
<a href="https://ask-meridian.uk/blog/">Blog</a>
|
| 249 |
+
<a href="https://ask-meridian.uk/docs/">Docs</a>
|
| 250 |
+
<a href="https://ask-meridian.uk/#pricing">Pricing</a>
|
| 251 |
+
</details>
|
| 252 |
+
|
| 253 |
+
<details class="nav-group">
|
| 254 |
+
<summary class="nav-section">Source</summary>
|
| 255 |
+
<a href="https://github.com/LuuOW/meridian-mcp">GitHub · meridian-mcp</a>
|
| 256 |
+
<a href="https://github.com/LuuOW/photon-route">GitHub · photon-route</a>
|
| 257 |
+
</details>
|
| 258 |
+
</div>
|
| 259 |
+
|
| 260 |
+
<!-- Self-contained burger + current-link behaviour. Lives inside <nav>
|
| 261 |
+
so sync-nav.py treats it as part of the synced block — every page
|
| 262 |
+
gets the same wiring without needing a separate <script> import.
|
| 263 |
+
Re-runs on DOMContentLoaded too so it survives module-script
|
| 264 |
+
races and weird parse orders. -->
|
| 265 |
+
<script>
|
| 266 |
+
(function () {
|
| 267 |
+
function setup() {
|
| 268 |
+
var btn = document.getElementById('burgerBtn')
|
| 269 |
+
var menu = document.getElementById('navMenu')
|
| 270 |
+
if (!btn || !menu) return
|
| 271 |
+
if (btn.dataset.wired) return
|
| 272 |
+
btn.dataset.wired = '1'
|
| 273 |
+
|
| 274 |
+
function toggle(open) {
|
| 275 |
+
var isOpen = open === undefined ? !menu.classList.contains('open') : open
|
| 276 |
+
menu.classList.toggle('open', isOpen)
|
| 277 |
+
btn.classList.toggle('open', isOpen)
|
| 278 |
+
btn.setAttribute('aria-expanded', String(isOpen))
|
| 279 |
+
}
|
| 280 |
+
btn.addEventListener('click', function () { toggle() })
|
| 281 |
+
menu.querySelectorAll('a').forEach(function (a) {
|
| 282 |
+
a.addEventListener('click', function () { toggle(false) })
|
| 283 |
+
})
|
| 284 |
+
document.addEventListener('click', function (e) {
|
| 285 |
+
if (!menu.classList.contains('open')) return
|
| 286 |
+
if (!menu.contains(e.target) && !btn.contains(e.target)) toggle(false)
|
| 287 |
+
})
|
| 288 |
+
document.addEventListener('keydown', function (e) {
|
| 289 |
+
if (e.key === 'Escape') toggle(false)
|
| 290 |
+
})
|
| 291 |
+
|
| 292 |
+
// Cross-host aware: same protocol+host+path → current. Lets the
|
| 293 |
+
// highlight follow you whether you're on ask-meridian.uk,
|
| 294 |
+
// meridian.ask-meridian.uk, or photon.ask-meridian.uk.
|
| 295 |
+
var here = location.host + location.pathname.replace(/\/index\.html$/, '/')
|
| 296 |
+
menu.querySelectorAll('a').forEach(function (a) {
|
| 297 |
+
var href = a.getAttribute('href')
|
| 298 |
+
if (!href) return
|
| 299 |
+
try {
|
| 300 |
+
var u = new URL(href, location.href)
|
| 301 |
+
var target = u.host + u.pathname.replace(/\/index\.html$/, '/')
|
| 302 |
+
if (target === here) a.classList.add('current')
|
| 303 |
+
} catch (_) {}
|
| 304 |
+
})
|
| 305 |
+
}
|
| 306 |
+
if (document.readyState === 'loading') {
|
| 307 |
+
document.addEventListener('DOMContentLoaded', setup, { once: true })
|
| 308 |
+
} else {
|
| 309 |
+
setup()
|
| 310 |
+
}
|
| 311 |
+
})()
|
| 312 |
+
</script>
|
| 313 |
+
</nav>
|
| 314 |
+
|
| 315 |
+
<script>
|
| 316 |
+
(function () {
|
| 317 |
+
const btn = document.querySelector('.burger');
|
| 318 |
+
const menu = document.getElementById('navMenu');
|
| 319 |
+
if (!btn || !menu) return;
|
| 320 |
+
const set = (open) => {
|
| 321 |
+
menu.classList.toggle('open', open);
|
| 322 |
+
btn.classList.toggle('open', open);
|
| 323 |
+
btn.setAttribute('aria-expanded', String(open));
|
| 324 |
+
};
|
| 325 |
+
btn.addEventListener('click', e => { e.stopPropagation(); set(!menu.classList.contains('open')); });
|
| 326 |
+
document.addEventListener('click', e => {
|
| 327 |
+
if (!menu.classList.contains('open')) return;
|
| 328 |
+
if (!menu.contains(e.target) && !btn.contains(e.target)) set(false);
|
| 329 |
+
});
|
| 330 |
+
document.addEventListener('keydown', e => { if (e.key === 'Escape') set(false); });
|
| 331 |
+
menu.querySelectorAll('a').forEach(a => a.addEventListener('click', () => set(false)));
|
| 332 |
+
})();
|
| 333 |
+
</script>
|
| 334 |
<main>
|
| 335 |
<header>
|
| 336 |
<div>
|
|
|
|
| 382 |
<div class="body">
|
| 383 |
<p><strong>photon-route</strong> is a research artifact exploring whether semantic retrieval can run in the continuous-variable (CV) photonic regime — the regime that real photonic hardware (Xanadu Borealis, fiber-loop reservoirs, coherent Ising machines) actually operates in.</p>
|
| 384 |
<p>Each document is encoded as a <em>Gaussian state</em> over N bosonic modes via a <a href="https://strawberryfields.ai/" target="_blank" rel="noopener">Strawberry Fields</a> program: words contribute squeezing and displacement operations, then a beam-splitter network mixes the modes. Query and document fidelity is computed in closed form using the <a href="https://the-walrus.readthedocs.io/" target="_blank" rel="noopener">thewalrus</a> implementation of the Banchi-Braunstein-Pirandola formula.</p>
|
| 385 |
+
<p>Three swappable encoders share the same fidelity scoring. <strong>v1</strong> uses Strawberry Fields to run an N-mode Gaussian program with SHA-256-derived parameters. <strong>sha_init</strong> is a pure-numpy port of the same gates (no SF at deploy time). <strong>trained</strong> swaps the SHA-256 lookup for a learned table fit by InfoNCE + Bhattacharyya-coefficient surrogate fidelity on a small arXiv quant-ph eval set. Toggle <em>compare</em> to see the three side by side.</p>
|
| 386 |
<p>Source · <a href="https://github.com/LuuOW/photon-route" target="_blank" rel="noopener">github.com/LuuOW/photon-route</a></p>
|
| 387 |
</div>
|
| 388 |
</details>
|
| 389 |
|
| 390 |
<footer>
|
| 391 |
+
<span>CV photonic · gaussian backend · UI on GitHub Pages, retrieval on HF Space</span>
|
| 392 |
+
<span><a href="https://luuow-photon-route.hf.space/docs" rel="noopener">/docs</a> · <a href="https://luuow-photon-route.hf.space/health" rel="noopener">/health</a></span>
|
| 393 |
</footer>
|
| 394 |
</main>
|
| 395 |
|
| 396 |
<script>
|
| 397 |
(function(){
|
| 398 |
+
// Backend lives on Hugging Face Spaces. CORS is wide-open on the FastAPI app
|
| 399 |
+
// so the browser hits this directly. Override with a ?api=https://… query
|
| 400 |
+
// string for local Space testing.
|
| 401 |
+
var SPACE = (new URL(location.href)).searchParams.get('api') || 'https://luuow-photon-route.hf.space';
|
| 402 |
+
|
| 403 |
function $(id){return document.getElementById(id)}
|
| 404 |
var q=$('q'), k=$('k'), results=$('results'), status=$('status');
|
| 405 |
var healthPill=$('health'), healthText=$('health-text');
|
| 406 |
var abort=null, debounceT=0;
|
| 407 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
var N_MODES = 2, MAX_SQUEEZE = 0.5, MAX_DISPLACE = 1.0;
|
| 409 |
var _wcache = new Map();
|
| 410 |
|
|
|
|
| 416 |
var hash = await crypto.subtle.digest('SHA-256', buf);
|
| 417 |
bytes = new Uint8Array(hash);
|
| 418 |
} catch(e){
|
|
|
|
| 419 |
bytes = new Uint8Array(32);
|
| 420 |
var h = 2166136261;
|
| 421 |
for(var i=0;i<word.length;i++){ h ^= word.charCodeAt(i); h = (h*16777619)>>>0; }
|
|
|
|
| 423 |
}
|
| 424 |
var parts = [];
|
| 425 |
for(var i2=0;i2<4;i2++){
|
|
|
|
| 426 |
var big = 0n;
|
| 427 |
for(var j2=0;j2<8;j2++) big = (big << 8n) + BigInt(bytes[i2*8 + j2]);
|
| 428 |
parts.push(Number(big % 1000000000n) / 1e9);
|
|
|
|
| 454 |
function tr4(A){var T=mat4(); for(var i=0;i<4;i++) for(var j=0;j<4;j++) T[i][j]=A[j][i]; return T;}
|
| 455 |
|
| 456 |
function sgateMat(k, r, phi){
|
|
|
|
| 457 |
var c=Math.cos(phi/2), s=Math.sin(phi/2), em=Math.exp(-r), ep=Math.exp(r);
|
| 458 |
var a = c*c*em + s*s*ep;
|
| 459 |
var b = c*s*(em - ep);
|
|
|
|
| 508 |
};
|
| 509 |
}
|
| 510 |
|
| 511 |
+
var GRID = 26;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 512 |
var DPR = Math.min(window.devicePixelRatio || 1, 2);
|
| 513 |
|
| 514 |
function WignerView(canvas, coordEl){
|
|
|
|
| 517 |
this.ctx = canvas.getContext('2d');
|
| 518 |
this.mu = [0,0];
|
| 519 |
this.sigma = [[1,0],[0,1]];
|
| 520 |
+
this.yaw = 0.55;
|
| 521 |
+
this.pitch = 0.85;
|
| 522 |
this.userYaw = false;
|
| 523 |
this.dragging = false;
|
| 524 |
this.lastX = 0; this.lastY = 0;
|
|
|
|
| 567 |
WignerView.prototype.draw = function(){
|
| 568 |
var ctx = this.ctx, W = this.canvas.width, H = this.canvas.height;
|
| 569 |
ctx.clearRect(0,0,W,H);
|
|
|
|
| 570 |
var mu = this.mu, sg = this.sigma;
|
| 571 |
var det = sg[0][0]*sg[1][1] - sg[0][1]*sg[1][0];
|
| 572 |
if(det < 1e-12){ return; }
|
| 573 |
var iv00 = sg[1][1]/det, iv01 = -sg[0][1]/det, iv10 = -sg[1][0]/det, iv11 = sg[0][0]/det;
|
| 574 |
var norm = 1/(2*Math.PI*Math.sqrt(det));
|
|
|
|
|
|
|
| 575 |
var sQ = Math.sqrt(Math.abs(sg[0][0])), sP = Math.sqrt(Math.abs(sg[1][1]));
|
| 576 |
var extent = Math.max(2.6, Math.abs(mu[0]) + 3*sQ, Math.abs(mu[1]) + 3*sP);
|
| 577 |
var scaleXY = (Math.min(W, H) * 0.32) / extent;
|
| 578 |
+
var scaleZ = (Math.min(W, H) * 0.55) * Math.sqrt(det);
|
| 579 |
var ox = W*0.5, oy = H*0.62;
|
|
|
|
| 580 |
var cy = Math.cos(this.yaw), sy = Math.sin(this.yaw);
|
| 581 |
var cp = Math.cos(this.pitch), sp = Math.sin(this.pitch);
|
|
|
|
| 582 |
function project(q, p, w){
|
| 583 |
var xr = q*cy - p*sy;
|
| 584 |
var yr = q*sy + p*cy;
|
| 585 |
var sx = ox + xr * scaleXY;
|
| 586 |
var sy_ = oy - yr * scaleXY * cp - w * scaleZ * sp;
|
| 587 |
+
var depth = yr * sp - w * cp;
|
| 588 |
return [sx, sy_, depth];
|
| 589 |
}
|
| 590 |
function projectFlat(q, p){
|
|
|
|
| 592 |
var yr = q*sy + p*cy;
|
| 593 |
return [ox + xr*scaleXY, oy - yr*scaleXY*cp];
|
| 594 |
}
|
|
|
|
|
|
|
| 595 |
ctx.strokeStyle = 'rgba(28,39,66,0.7)';
|
| 596 |
ctx.lineWidth = 1;
|
| 597 |
var gN = 6;
|
|
|
|
| 604 |
var dFlat = projectFlat( extent, t);
|
| 605 |
ctx.beginPath(); ctx.moveTo(cFlat[0], cFlat[1]); ctx.lineTo(dFlat[0], dFlat[1]); ctx.stroke();
|
| 606 |
}
|
|
|
|
| 607 |
var oFlat = projectFlat(0,0);
|
| 608 |
var qAxis = projectFlat(extent, 0);
|
| 609 |
var pAxis = projectFlat(0, extent);
|
|
|
|
| 611 |
ctx.beginPath(); ctx.moveTo(oFlat[0], oFlat[1]); ctx.lineTo(qAxis[0], qAxis[1]); ctx.stroke();
|
| 612 |
ctx.strokeStyle = 'rgba(129,140,248,0.45)';
|
| 613 |
ctx.beginPath(); ctx.moveTo(oFlat[0], oFlat[1]); ctx.lineTo(pAxis[0], pAxis[1]); ctx.stroke();
|
|
|
|
|
|
|
| 614 |
var N = GRID;
|
| 615 |
var step = (2*extent)/N;
|
| 616 |
var pts = new Array(N+1);
|
|
|
|
| 629 |
pts[i][j] = {w:w, sx:pr[0], sy:pr[1], depth:pr[2]};
|
| 630 |
}
|
| 631 |
}
|
|
|
|
|
|
|
| 632 |
var quads = [];
|
| 633 |
for(var i2=0;i2<N;i2++){
|
| 634 |
for(var j2=0;j2<N;j2++){
|
|
|
|
| 638 |
quads.push({a:a, b:b, c:c, d:d, depth:depth, w:wAvg});
|
| 639 |
}
|
| 640 |
}
|
| 641 |
+
quads.sort(function(x, y){ return y.depth - x.depth; });
|
|
|
|
|
|
|
| 642 |
for(var qi2=0; qi2<quads.length; qi2++){
|
| 643 |
var qd = quads[qi2];
|
| 644 |
var t2 = wmax > 1e-12 ? Math.max(0, Math.min(1, qd.w / wmax)) : 0;
|
|
|
|
| 645 |
var rC = Math.round(0x81*(1-t2) + 0x22*t2);
|
| 646 |
var gC = Math.round(0x8c*(1-t2) + 0xd3*t2);
|
| 647 |
var bC = Math.round(0xf8*(1-t2) + 0xee*t2);
|
|
|
|
| 668 |
v1 = new WignerView(c1, document.getElementById('c1'));
|
| 669 |
var resize = function(){ v0.resize(); v1.resize(); };
|
| 670 |
window.addEventListener('resize', resize);
|
|
|
|
| 671 |
v0.setState([0,0],[[1,0],[0,1]]);
|
| 672 |
v1.setState([0,0],[[1,0],[0,1]]);
|
| 673 |
var lastT = performance.now();
|
|
|
|
| 688 |
if(text === lastQuery) return;
|
| 689 |
lastQuery = text;
|
| 690 |
var st = await encodeState(text);
|
| 691 |
+
if(text !== lastQuery) return;
|
| 692 |
currentState = st;
|
| 693 |
var m0 = modeMarginal(st, 0), m1 = modeMarginal(st, 1);
|
| 694 |
v0.setState(m0.mu, m0.sigma);
|
| 695 |
v1.setState(m1.mu, m1.sigma);
|
| 696 |
}
|
|
|
|
| 697 |
initViz();
|
| 698 |
|
| 699 |
var backendSel = document.getElementById('backend');
|
| 700 |
var compareBox = document.getElementById('compare');
|
| 701 |
+
fetch(SPACE+'/health',{cache:'no-store'}).then(function(r){return r.json()}).then(function(j){
|
| 702 |
+
var ok = j && j.ok;
|
| 703 |
healthPill.classList.add(ok?'ok':'err');
|
| 704 |
var backends = (j && j.backends_available) || [];
|
| 705 |
healthText.textContent = ok ? (j.default_backend || 'ok') : 'offline';
|
|
|
|
| 706 |
Array.prototype.forEach.call(backendSel.options, function(opt){
|
| 707 |
opt.disabled = backends.length>0 && backends.indexOf(opt.value) < 0;
|
| 708 |
if (opt.disabled) opt.text = opt.value + ' (n/a)';
|
|
|
|
| 783 |
}
|
| 784 |
|
| 785 |
async function fetchRank(text, topk, backend, sig){
|
| 786 |
+
var url=SPACE+'/rank?q='+encodeURIComponent(text)+'&top_k='+topk+'&backend='+encodeURIComponent(backend);
|
| 787 |
var r=await fetch(url,{signal:sig});
|
| 788 |
if(!r.ok) throw new Error('http '+r.status);
|
| 789 |
var j=await r.json();
|
| 790 |
+
return {backend:j.backend||backend, items:j.results||[], cache:''};
|
|
|
|
| 791 |
}
|
| 792 |
|
| 793 |
async function run(){
|
|
|
|
| 812 |
var j = await fetchRank(text, topk, backendSel.value, abort.signal);
|
| 813 |
var ms=(performance.now()-t0).toFixed(0);
|
| 814 |
status.textContent = j.items.length+' result'+(j.items.length===1?'':'s')+
|
| 815 |
+
' · '+ms+' ms · backend '+j.backend;
|
| 816 |
render(j.items);
|
| 817 |
return;
|
| 818 |
}
|
|
|
|
| 853 |
})();
|
| 854 |
</script>
|
| 855 |
</body>
|
| 856 |
+
</html>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
space/analyze_sweep.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Read sweep_results.csv, decide whether SBERT-photon's gain is real
|
| 2 |
+
and whether photon-number-distribution metric (A3-Simple) outperforms
|
| 3 |
+
Gaussian-state-overlap (BBP fidelity) on the same trained encoder.
|
| 4 |
+
|
| 5 |
+
Q1. Does SBERT-photon (full, gaussian metric) robustly beat raw SBERT?
|
| 6 |
+
Q2. Does the squeezing layer specifically pay (full vs no-squeeze)?
|
| 7 |
+
Q3. A3-Simple: does photon-prob metric > gaussian metric on same encoder?
|
| 8 |
+
Q4. Generalization tax (train − test nDCG@10).
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import csv
|
| 14 |
+
import statistics
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
SBERT_ALONE_NDCG10 = 0.385
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def stat(rs, key):
|
| 22 |
+
vals = [r[key] for r in rs if r[key] == r[key]]
|
| 23 |
+
if len(vals) < 2:
|
| 24 |
+
return (vals[0] if vals else float("nan"), 0.0, len(vals))
|
| 25 |
+
return statistics.mean(vals), statistics.stdev(vals), len(vals)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def main():
|
| 29 |
+
ap = argparse.ArgumentParser()
|
| 30 |
+
ap.add_argument("--csv", type=Path, default=Path(__file__).resolve().parent.parent / "sweep_results.csv")
|
| 31 |
+
args = ap.parse_args()
|
| 32 |
+
|
| 33 |
+
rows = list(csv.DictReader(args.csv.open()))
|
| 34 |
+
for r in rows:
|
| 35 |
+
for k, v in list(r.items()):
|
| 36 |
+
try:
|
| 37 |
+
r[k] = float(v)
|
| 38 |
+
except (ValueError, TypeError):
|
| 39 |
+
pass
|
| 40 |
+
|
| 41 |
+
full = sorted([r for r in rows if int(r["no_squeeze"]) == 0], key=lambda r: r["seed"])
|
| 42 |
+
nosq = sorted([r for r in rows if int(r["no_squeeze"]) == 1], key=lambda r: r["seed"])
|
| 43 |
+
|
| 44 |
+
print(f"loaded {len(rows)} runs from {args.csv}")
|
| 45 |
+
print(f" full (squeezing on): n={len(full)}")
|
| 46 |
+
print(f" no-squeeze: n={len(nosq)}")
|
| 47 |
+
|
| 48 |
+
print("\n=== Q1. SBERT-photon (full, gaussian metric) vs raw SBERT (0.385) ===")
|
| 49 |
+
m, s, n = stat(full, "test_gaussian_n10")
|
| 50 |
+
delta = m - SBERT_ALONE_NDCG10
|
| 51 |
+
z = delta / s if s > 0 else float("inf")
|
| 52 |
+
verdict = "✓ YES" if delta > s and m > SBERT_ALONE_NDCG10 else "✗ noisy or no gain"
|
| 53 |
+
print(f" full mean test nDCG@10 (gaussian) = {m:.3f} ± {s:.3f} (n={n})")
|
| 54 |
+
print(f" Δ vs raw SBERT = {delta:+.3f} (Δ/σ ≈ {z:+.2f}) → {verdict}")
|
| 55 |
+
|
| 56 |
+
print("\n=== Q2. Squeezing pays? (paired full − no_squeeze, gaussian metric) ===")
|
| 57 |
+
paired = []
|
| 58 |
+
for r in full:
|
| 59 |
+
n_row = next((x for x in nosq if x["seed"] == r["seed"]), None)
|
| 60 |
+
if n_row:
|
| 61 |
+
paired.append((int(r["seed"]), r["test_gaussian_n10"], n_row["test_gaussian_n10"]))
|
| 62 |
+
diffs = [a - b for _, a, b in paired]
|
| 63 |
+
m_d = statistics.mean(diffs) if diffs else float("nan")
|
| 64 |
+
s_d = statistics.stdev(diffs) if len(diffs) > 1 else 0.0
|
| 65 |
+
for sid, a, b in paired:
|
| 66 |
+
print(f" seed {sid}: full={a:.3f} no_sq={b:.3f} Δ={a-b:+.3f}")
|
| 67 |
+
verdict = ("✓ YES" if m_d > s_d and m_d > 0.01 else
|
| 68 |
+
"✗ NO" if m_d <= 0 else "≈ within noise")
|
| 69 |
+
print(f" mean Δ = {m_d:+.3f} ± {s_d:.3f} → {verdict}")
|
| 70 |
+
|
| 71 |
+
print("\n=== Q3. A3-Simple: photon-prob > gaussian metric on same encoder? ===")
|
| 72 |
+
for label, rs in [("full", full), ("no-squeeze", nosq)]:
|
| 73 |
+
if not rs:
|
| 74 |
+
continue
|
| 75 |
+
# Paired per-seed: same encoder, two metrics on test set
|
| 76 |
+
diffs = [r["test_photon_prob_n10"] - r["test_gaussian_n10"] for r in rs]
|
| 77 |
+
m_d = statistics.mean(diffs)
|
| 78 |
+
s_d = statistics.stdev(diffs) if len(diffs) > 1 else 0.0
|
| 79 |
+
m_g, _, _ = stat(rs, "test_gaussian_n10")
|
| 80 |
+
m_p, _, _ = stat(rs, "test_photon_prob_n10")
|
| 81 |
+
verdict = ("✓ photon-prob wins" if m_d > s_d and m_d > 0.01 else
|
| 82 |
+
"✗ photon-prob loses" if m_d < -0.01 else
|
| 83 |
+
"≈ tie within noise")
|
| 84 |
+
print(f" {label:>10}: gaussian={m_g:.3f} photon_prob={m_p:.3f} Δ={m_d:+.3f} ± {s_d:.3f} → {verdict}")
|
| 85 |
+
|
| 86 |
+
print("\n=== Q4. Generalization tax (train − test, gaussian metric) ===")
|
| 87 |
+
for label, rs in [("full", full), ("no-squeeze", nosq)]:
|
| 88 |
+
if not rs:
|
| 89 |
+
continue
|
| 90 |
+
gaps = [r["train_gaussian_n10"] - r["test_gaussian_n10"] for r in rs]
|
| 91 |
+
m_g = statistics.mean(gaps)
|
| 92 |
+
s_g = statistics.stdev(gaps) if len(gaps) > 1 else 0.0
|
| 93 |
+
print(f" {label:>10}: gap = {m_g:.3f} ± {s_g:.3f}")
|
| 94 |
+
|
| 95 |
+
print("\nFull table:")
|
| 96 |
+
headers = ("seed", "mode", "train_g_n10", "test_g_n10", "train_p_n10", "test_p_n10")
|
| 97 |
+
print(" " + " ".join(f"{h:>11}" for h in headers))
|
| 98 |
+
for r in full + nosq:
|
| 99 |
+
mode = "no_squeeze" if int(r["no_squeeze"]) else "full"
|
| 100 |
+
cells = (
|
| 101 |
+
int(r["seed"]), mode,
|
| 102 |
+
r["train_gaussian_n10"], r["test_gaussian_n10"],
|
| 103 |
+
r["train_photon_prob_n10"], r["test_photon_prob_n10"],
|
| 104 |
+
)
|
| 105 |
+
print(" " + " ".join(
|
| 106 |
+
f"{c:>11}" if isinstance(c, str) else f"{c:>11.3f}" if isinstance(c, float) else f"{c:>11}"
|
| 107 |
+
for c in cells
|
| 108 |
+
))
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
if __name__ == "__main__":
|
| 112 |
+
main()
|
space/run_sweep.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""5-split × 2-mode sweep over the SBERT-backed photon-route trainer.
|
| 2 |
+
|
| 3 |
+
Designed to run on cloud CI (e.g. GitHub Actions ubuntu-latest, free tier),
|
| 4 |
+
NOT locally. Output is a CSV of (split_seed, no_squeeze, train_ndcg10, test_ndcg10, ...)
|
| 5 |
+
that gets uploaded as a workflow artifact for the user to read.
|
| 6 |
+
|
| 7 |
+
For each random split seed:
|
| 8 |
+
1. Pick 2 of the eval queries as held-out test, rest as train.
|
| 9 |
+
2. Train SBERTPhoton (full vs --no-squeeze).
|
| 10 |
+
3. Evaluate on train and test.
|
| 11 |
+
4. Append result row.
|
| 12 |
+
|
| 13 |
+
The point: we want error bars, not a point estimate. If the +30% nDCG@10
|
| 14 |
+
the SBERT-backed run got on one specific 4/2 split is real, it should hold
|
| 15 |
+
across multiple random splits. If it varies wildly, the headline was a
|
| 16 |
+
2-query coincidence.
|
| 17 |
+
"""
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import argparse
|
| 21 |
+
import csv
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import random
|
| 25 |
+
import sys
|
| 26 |
+
import tempfile
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
|
| 29 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 30 |
+
SRC = ROOT / "src"
|
| 31 |
+
if str(SRC) not in sys.path:
|
| 32 |
+
sys.path.insert(0, str(SRC))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def make_split(rel_payload: dict, n_test: int, seed: int) -> tuple[dict, dict]:
|
| 36 |
+
"""Returns (train_relevance, test_relevance) as separate JSON-able dicts."""
|
| 37 |
+
rng = random.Random(seed)
|
| 38 |
+
queries = list(rel_payload["queries"])
|
| 39 |
+
rng.shuffle(queries)
|
| 40 |
+
test = queries[:n_test]
|
| 41 |
+
train = queries[n_test:]
|
| 42 |
+
return (
|
| 43 |
+
{**rel_payload, "queries": train},
|
| 44 |
+
{**rel_payload, "queries": test},
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def run_one(seed: int, no_squeeze: bool, steps: int, relevance_path: Path,
|
| 49 |
+
n_test: int, log_dir: Path) -> dict:
|
| 50 |
+
"""Train + eval one configuration; return summary dict for the CSV row."""
|
| 51 |
+
import space.train_sbert as ts
|
| 52 |
+
|
| 53 |
+
rel_payload = json.loads(relevance_path.read_text("utf-8"))
|
| 54 |
+
train_rel, test_rel = make_split(rel_payload, n_test=n_test, seed=seed)
|
| 55 |
+
|
| 56 |
+
with tempfile.TemporaryDirectory() as tmp:
|
| 57 |
+
train_p = Path(tmp) / "rel_train.json"
|
| 58 |
+
test_p = Path(tmp) / "rel_test.json"
|
| 59 |
+
train_p.write_text(json.dumps(train_rel, indent=2))
|
| 60 |
+
test_p.write_text(json.dumps(test_rel, indent=2))
|
| 61 |
+
|
| 62 |
+
# Build a Namespace mimicking train_sbert's CLI args so we can call
|
| 63 |
+
# train(args) directly without subprocess. Faster + captures Python
|
| 64 |
+
# exceptions cleanly.
|
| 65 |
+
import argparse as _ap
|
| 66 |
+
args = _ap.Namespace(
|
| 67 |
+
steps=steps, lr=1e-2, weight_decay=1e-3, temperature=2.0,
|
| 68 |
+
negatives=8, clip=1.0, seed=seed, log_every=50,
|
| 69 |
+
relevance=str(train_p),
|
| 70 |
+
eval_train_rel=str(train_p), eval_test_rel=str(test_p),
|
| 71 |
+
no_squeeze=no_squeeze,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# Capture stdout to recover SUMMARY_JSON line.
|
| 75 |
+
import io, contextlib
|
| 76 |
+
buf = io.StringIO()
|
| 77 |
+
with contextlib.redirect_stdout(buf):
|
| 78 |
+
ts.train(args)
|
| 79 |
+
out = buf.getvalue()
|
| 80 |
+
|
| 81 |
+
# Persist log
|
| 82 |
+
log_path = log_dir / f"seed{seed}_nosqz{int(no_squeeze)}.log"
|
| 83 |
+
log_path.write_text(out, encoding="utf-8")
|
| 84 |
+
|
| 85 |
+
summary_line = next(
|
| 86 |
+
(l for l in out.splitlines() if l.startswith("SUMMARY_JSON=")), ""
|
| 87 |
+
)
|
| 88 |
+
summary = json.loads(summary_line.split("=", 1)[1]) if summary_line else {}
|
| 89 |
+
|
| 90 |
+
def g(key, metric_key):
|
| 91 |
+
return summary.get(f"{key}/{metric_key}", {})
|
| 92 |
+
|
| 93 |
+
row = {"seed": seed, "no_squeeze": int(no_squeeze)}
|
| 94 |
+
for split in ("train", "test"):
|
| 95 |
+
for metric in ("gaussian", "photon_prob"):
|
| 96 |
+
agg = g(split, metric)
|
| 97 |
+
for m in ("ndcg@10", "recall@10", "recall@1"):
|
| 98 |
+
short = m.replace("@", "").replace("recall", "r").replace("ndcg", "n")
|
| 99 |
+
row[f"{split}_{metric}_{short}"] = agg.get(m, float("nan"))
|
| 100 |
+
return row
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def main():
|
| 104 |
+
ap = argparse.ArgumentParser()
|
| 105 |
+
ap.add_argument("--relevance", type=Path, default=ROOT / "eval" / "relevance_expanded.json",
|
| 106 |
+
help="Default to the title-expanded set so each split has more train signal.")
|
| 107 |
+
ap.add_argument("--seeds", type=int, nargs="+", default=[1, 2, 3, 4, 5])
|
| 108 |
+
ap.add_argument("--n-test", type=int, default=4,
|
| 109 |
+
help="Held-out test queries per split. With expanded relevance (26 q), 4 test is ~15%.")
|
| 110 |
+
ap.add_argument("--steps", type=int, default=200)
|
| 111 |
+
ap.add_argument("--out-csv", type=Path, default=ROOT / "sweep_results.csv")
|
| 112 |
+
ap.add_argument("--log-dir", type=Path, default=ROOT / "sweep_logs")
|
| 113 |
+
args = ap.parse_args()
|
| 114 |
+
|
| 115 |
+
args.log_dir.mkdir(parents=True, exist_ok=True)
|
| 116 |
+
|
| 117 |
+
results = []
|
| 118 |
+
for seed in args.seeds:
|
| 119 |
+
for no_squeeze in [False, True]:
|
| 120 |
+
print(f"\n{'='*72}\nseed={seed} no_squeeze={no_squeeze}\n{'='*72}")
|
| 121 |
+
row = run_one(
|
| 122 |
+
seed=seed, no_squeeze=no_squeeze, steps=args.steps,
|
| 123 |
+
relevance_path=args.relevance, n_test=args.n_test, log_dir=args.log_dir,
|
| 124 |
+
)
|
| 125 |
+
print(f" → gaussian: train n10={row['train_gaussian_n10']:.3f} test n10={row['test_gaussian_n10']:.3f}")
|
| 126 |
+
print(f" photon_prob: train n10={row['train_photon_prob_n10']:.3f} test n10={row['test_photon_prob_n10']:.3f}")
|
| 127 |
+
results.append(row)
|
| 128 |
+
|
| 129 |
+
# Write CSV
|
| 130 |
+
fieldnames = list(results[0].keys())
|
| 131 |
+
with args.out_csv.open("w", newline="") as f:
|
| 132 |
+
w = csv.DictWriter(f, fieldnames=fieldnames)
|
| 133 |
+
w.writeheader()
|
| 134 |
+
w.writerows(results)
|
| 135 |
+
print(f"\nwrote {len(results)} rows → {args.out_csv}")
|
| 136 |
+
|
| 137 |
+
# Aggregate stats
|
| 138 |
+
import statistics
|
| 139 |
+
def stat(rows, key):
|
| 140 |
+
vals = [r[key] for r in rows]
|
| 141 |
+
return statistics.mean(vals), statistics.stdev(vals) if len(vals) > 1 else 0.0
|
| 142 |
+
|
| 143 |
+
print("\nAggregates over seeds:")
|
| 144 |
+
for ns in [False, True]:
|
| 145 |
+
rows = [r for r in results if r["no_squeeze"] == int(ns)]
|
| 146 |
+
if not rows:
|
| 147 |
+
continue
|
| 148 |
+
label = "no-squeeze" if ns else "full"
|
| 149 |
+
for metric in ("gaussian", "photon_prob"):
|
| 150 |
+
m, s = stat(rows, f"test_{metric}_n10")
|
| 151 |
+
print(f" {label:>10}/{metric:>11}: test nDCG@10 = {m:.3f} ± {s:.3f} (n={len(rows)})")
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
if __name__ == "__main__":
|
| 155 |
+
main()
|
space/run_sweep_fock.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""5-split sweep for the A3-Real Fock-basis trainer.
|
| 2 |
+
|
| 3 |
+
Mirrors space/run_sweep.py but for train_sbert_fock.py (non-Gaussian
|
| 4 |
+
heralded encoder). Outputs one CSV row per (seed) — no squeeze ablation
|
| 5 |
+
since the Fock encoder structure already includes a learnable TMS gate
|
| 6 |
+
and learnable squeezing; the equivalent ablation is herald_n=0 (heralding
|
| 7 |
+
on vacuum keeps the state Gaussian).
|
| 8 |
+
"""
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import csv
|
| 13 |
+
import io
|
| 14 |
+
import contextlib
|
| 15 |
+
import json
|
| 16 |
+
import random
|
| 17 |
+
import sys
|
| 18 |
+
import tempfile
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 22 |
+
SRC = ROOT / "src"
|
| 23 |
+
if str(SRC) not in sys.path:
|
| 24 |
+
sys.path.insert(0, str(SRC))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def make_split(rel_payload, n_test, seed):
|
| 28 |
+
rng = random.Random(seed)
|
| 29 |
+
queries = list(rel_payload["queries"])
|
| 30 |
+
rng.shuffle(queries)
|
| 31 |
+
return (
|
| 32 |
+
{**rel_payload, "queries": queries[n_test:]},
|
| 33 |
+
{**rel_payload, "queries": queries[:n_test]},
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def run_one(seed, herald_n, steps, cutoff, relevance_path, n_test, log_dir):
|
| 38 |
+
import space.train_sbert_fock as ts
|
| 39 |
+
|
| 40 |
+
rel_payload = json.loads(relevance_path.read_text("utf-8"))
|
| 41 |
+
train_rel, test_rel = make_split(rel_payload, n_test=n_test, seed=seed)
|
| 42 |
+
with tempfile.TemporaryDirectory() as tmp:
|
| 43 |
+
train_p = Path(tmp) / "rel_train.json"
|
| 44 |
+
test_p = Path(tmp) / "rel_test.json"
|
| 45 |
+
train_p.write_text(json.dumps(train_rel, indent=2))
|
| 46 |
+
test_p.write_text(json.dumps(test_rel, indent=2))
|
| 47 |
+
import argparse as _ap
|
| 48 |
+
args = _ap.Namespace(
|
| 49 |
+
cutoff=cutoff, herald_n=herald_n,
|
| 50 |
+
steps=steps, lr=1e-2, weight_decay=1e-3, temperature=0.5,
|
| 51 |
+
negatives=8, clip=1.0, seed=seed, log_every=50,
|
| 52 |
+
relevance=str(train_p),
|
| 53 |
+
eval_train_rel=str(train_p), eval_test_rel=str(test_p),
|
| 54 |
+
)
|
| 55 |
+
buf = io.StringIO()
|
| 56 |
+
with contextlib.redirect_stdout(buf):
|
| 57 |
+
ts.train(args)
|
| 58 |
+
out = buf.getvalue()
|
| 59 |
+
log_path = log_dir / f"fock_seed{seed}_n{herald_n}.log"
|
| 60 |
+
log_path.write_text(out, encoding="utf-8")
|
| 61 |
+
summary_line = next(
|
| 62 |
+
(l for l in out.splitlines() if l.startswith("SUMMARY_JSON=")), ""
|
| 63 |
+
)
|
| 64 |
+
summary = json.loads(summary_line.split("=", 1)[1]) if summary_line else {}
|
| 65 |
+
train_agg = summary.get("train/fock", {})
|
| 66 |
+
test_agg = summary.get("test/fock", {})
|
| 67 |
+
return {
|
| 68 |
+
"seed": seed,
|
| 69 |
+
"herald_n": herald_n,
|
| 70 |
+
"cutoff": cutoff,
|
| 71 |
+
"train_ndcg10": train_agg.get("ndcg@10", float("nan")),
|
| 72 |
+
"test_ndcg10": test_agg.get("ndcg@10", float("nan")),
|
| 73 |
+
"train_recall10":train_agg.get("recall@10", float("nan")),
|
| 74 |
+
"test_recall10": test_agg.get("recall@10", float("nan")),
|
| 75 |
+
"train_recall1": train_agg.get("recall@1", float("nan")),
|
| 76 |
+
"test_recall1": test_agg.get("recall@1", float("nan")),
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def main():
|
| 81 |
+
ap = argparse.ArgumentParser()
|
| 82 |
+
ap.add_argument("--relevance", type=Path, default=ROOT / "eval" / "relevance_expanded.json")
|
| 83 |
+
ap.add_argument("--seeds", type=int, nargs="+", default=[1, 2, 3, 4, 5])
|
| 84 |
+
ap.add_argument("--n-test", type=int, default=4)
|
| 85 |
+
ap.add_argument("--steps", type=int, default=200)
|
| 86 |
+
ap.add_argument("--cutoff", type=int, default=6)
|
| 87 |
+
ap.add_argument("--herald-ns", type=int, nargs="+", default=[1, 0],
|
| 88 |
+
help="Ancilla outcomes to test. herald_n=0 keeps state Gaussian; herald_n=1 makes it non-Gaussian.")
|
| 89 |
+
ap.add_argument("--out-csv", type=Path, default=ROOT / "sweep_fock_results.csv")
|
| 90 |
+
ap.add_argument("--log-dir", type=Path, default=ROOT / "sweep_fock_logs")
|
| 91 |
+
args = ap.parse_args()
|
| 92 |
+
args.log_dir.mkdir(parents=True, exist_ok=True)
|
| 93 |
+
|
| 94 |
+
results = []
|
| 95 |
+
for seed in args.seeds:
|
| 96 |
+
for hn in args.herald_ns:
|
| 97 |
+
print(f"\n{'='*72}\nseed={seed} herald_n={hn}\n{'='*72}")
|
| 98 |
+
row = run_one(seed=seed, herald_n=hn, steps=args.steps,
|
| 99 |
+
cutoff=args.cutoff, relevance_path=args.relevance,
|
| 100 |
+
n_test=args.n_test, log_dir=args.log_dir)
|
| 101 |
+
print(f" → train n10={row['train_ndcg10']:.3f} test n10={row['test_ndcg10']:.3f}")
|
| 102 |
+
results.append(row)
|
| 103 |
+
|
| 104 |
+
fieldnames = list(results[0].keys())
|
| 105 |
+
with args.out_csv.open("w", newline="") as f:
|
| 106 |
+
w = csv.DictWriter(f, fieldnames=fieldnames)
|
| 107 |
+
w.writeheader()
|
| 108 |
+
w.writerows(results)
|
| 109 |
+
print(f"\nwrote {len(results)} rows → {args.out_csv}")
|
| 110 |
+
|
| 111 |
+
import statistics
|
| 112 |
+
def stat(rs, key):
|
| 113 |
+
vals = [r[key] for r in rs if r[key] == r[key]]
|
| 114 |
+
return (statistics.mean(vals), statistics.stdev(vals) if len(vals) > 1 else 0.0, len(vals))
|
| 115 |
+
|
| 116 |
+
print("\nAggregates:")
|
| 117 |
+
for hn in args.herald_ns:
|
| 118 |
+
rs = [r for r in results if r["herald_n"] == hn]
|
| 119 |
+
m, s, n = stat(rs, "test_ndcg10")
|
| 120 |
+
label = "non-Gaussian (herald=1)" if hn == 1 else "Gaussian (herald=0)" if hn == 0 else f"herald={hn}"
|
| 121 |
+
print(f" {label:>26}: test nDCG@10 = {m:.3f} ± {s:.3f} (n={n})")
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
if __name__ == "__main__":
|
| 125 |
+
main()
|
space/sim_b1_g1_coherence.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""B1 sim — does g^(1)-style coherence time τ_c discriminate candidates
|
| 2 |
+
better than meridian's current 3-bin Shannon entropy `cross_domain`?
|
| 3 |
+
|
| 4 |
+
Loudon eq 3.1.3: g^(1)(τ) = ⟨E*(t) E(t+τ)⟩ / ⟨|E|²⟩.
|
| 5 |
+
For a chaotic source, |g^(1)(τ)| decays exponentially with characteristic
|
| 6 |
+
τ_c = (∫|g^(1)(τ)|² dτ).
|
| 7 |
+
|
| 8 |
+
Treat each candidate's keyword stream as a chaotic light source where
|
| 9 |
+
each token at position t is a "wavetrain at frequency ω_token". The
|
| 10 |
+
autocorrelation of one-hot token vectors gives an effective τ_c that
|
| 11 |
+
scales with vocabulary diversity.
|
| 12 |
+
|
| 13 |
+
This is a self-contained synthetic-data sim. No external corpus / no
|
| 14 |
+
cloud compute required. Runs in <1 s.
|
| 15 |
+
"""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
import numpy as np
|
| 20 |
+
from collections import Counter
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def cross_domain_proxy(tokens: list[str], systems: dict[str, set[str]]) -> float:
|
| 24 |
+
"""Mirror meridian's existing computation: Shannon entropy / log(3)
|
| 25 |
+
over hits in {forge, signal, mind} term lists, normalized to [0, 1]."""
|
| 26 |
+
affinity = {sys: 0 for sys in systems}
|
| 27 |
+
for t in tokens:
|
| 28 |
+
for sys, terms in systems.items():
|
| 29 |
+
if t in terms:
|
| 30 |
+
affinity[sys] += 1
|
| 31 |
+
total = sum(affinity.values()) or 1
|
| 32 |
+
probs = [n / total for n in affinity.values() if n > 0]
|
| 33 |
+
H = -sum(p * math.log(p) for p in probs)
|
| 34 |
+
return H / math.log(3) if H else 0.0
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def coherence_time(tokens: list[str], window: int = 8) -> float:
|
| 38 |
+
"""Empirical g^(1)-style coherence time of a token stream.
|
| 39 |
+
|
| 40 |
+
Treat the sequence as a discrete-time signal where each token is a
|
| 41 |
+
distinct mode. g^(1)(τ) = (# matched-token pairs at offset τ) /
|
| 42 |
+
(# matched at τ=0). τ_c = sum_{τ≥1} |g^(1)(τ)|² up to a window.
|
| 43 |
+
|
| 44 |
+
Pure-stdlib, normalised so τ_c ∈ [0, window].
|
| 45 |
+
"""
|
| 46 |
+
n = len(tokens)
|
| 47 |
+
if n < 2:
|
| 48 |
+
return 0.0
|
| 49 |
+
g0 = sum(1 for t in tokens) or 1 # τ=0 normalisation = total length
|
| 50 |
+
tau_c = 0.0
|
| 51 |
+
for tau in range(1, min(window, n)):
|
| 52 |
+
matches = sum(1 for i in range(n - tau) if tokens[i] == tokens[i + tau])
|
| 53 |
+
gtau = matches / g0
|
| 54 |
+
tau_c += gtau * gtau
|
| 55 |
+
return tau_c
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ─── Synthetic candidates with realistic body lengths ──────────────────────
|
| 59 |
+
SYSTEMS = {
|
| 60 |
+
"forge": {"build", "compile", "deploy", "ci", "container", "image", "binary",
|
| 61 |
+
"docker", "kubernetes", "package", "release", "monorepo"},
|
| 62 |
+
"signal": {"data", "stream", "ingest", "pipeline", "etl", "kafka", "queue",
|
| 63 |
+
"throughput", "latency", "broker", "subscriber", "publish"},
|
| 64 |
+
"mind": {"llm", "embed", "embedding", "model", "transformer", "agent",
|
| 65 |
+
"reasoning", "prompt", "context", "rag", "fine", "tune"},
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def make_candidate(label: str, vocab_pool: list[str], length: int = 250,
|
| 70 |
+
alpha: float = 1.0, seed: int = 0) -> list[str]:
|
| 71 |
+
"""Generate length tokens drawn Zipfian from vocab_pool. alpha controls
|
| 72 |
+
head heaviness; alpha=1.0 ≈ thermal; alpha→∞ ≈ heavy concentrated."""
|
| 73 |
+
rng = np.random.default_rng(seed)
|
| 74 |
+
weights = 1.0 / (np.arange(1, len(vocab_pool) + 1) ** alpha)
|
| 75 |
+
weights /= weights.sum()
|
| 76 |
+
return list(rng.choice(vocab_pool, size=length, p=weights))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def main():
|
| 80 |
+
# 9 archetypes spanning body lengths and topical vs scattered patterns.
|
| 81 |
+
forge_terms = sorted(SYSTEMS["forge"])
|
| 82 |
+
signal_terms = sorted(SYSTEMS["signal"])
|
| 83 |
+
mind_terms = sorted(SYSTEMS["mind"])
|
| 84 |
+
cross_terms = forge_terms + signal_terms + mind_terms
|
| 85 |
+
cases = [
|
| 86 |
+
("focused-forge", forge_terms, 300, 1.0, 1),
|
| 87 |
+
("focused-signal", signal_terms, 300, 1.0, 2),
|
| 88 |
+
("focused-mind", mind_terms, 300, 1.0, 3),
|
| 89 |
+
("cross-forge-signal", forge_terms + signal_terms, 300, 1.0, 4),
|
| 90 |
+
("cross-mind-signal", mind_terms + signal_terms, 300, 1.0, 5),
|
| 91 |
+
("cross-three-systems", cross_terms, 300, 1.0, 6),
|
| 92 |
+
("scattered-cross", cross_terms, 300, 0.5, 7), # less Zipfian, more uniform
|
| 93 |
+
("very-narrow", forge_terms[:3], 300, 2.0, 8), # 3 dominant words
|
| 94 |
+
("very-broad", cross_terms, 300, 0.3, 9), # near-uniform
|
| 95 |
+
]
|
| 96 |
+
|
| 97 |
+
print(f"{'archetype':>22} {'len':>5} {'cross_domain':>13} {'τ_c (g^(1))':>12}")
|
| 98 |
+
print("-" * 64)
|
| 99 |
+
rows = []
|
| 100 |
+
for label, pool, length, alpha, seed in cases:
|
| 101 |
+
toks = make_candidate(label, pool, length=length, alpha=alpha, seed=seed)
|
| 102 |
+
cd = cross_domain_proxy(toks, SYSTEMS)
|
| 103 |
+
tc = coherence_time(toks)
|
| 104 |
+
print(f"{label:>22} {length:>5} {cd:>13.3f} {tc:>12.3f}")
|
| 105 |
+
rows.append((label, cd, tc))
|
| 106 |
+
|
| 107 |
+
print("\nDiscrimination check (variance across archetypes, higher = better signal):")
|
| 108 |
+
cd_vals = [r[1] for r in rows]
|
| 109 |
+
tc_vals = [r[2] for r in rows]
|
| 110 |
+
print(f" cross_domain: std = {np.std(cd_vals):.3f}, range = [{min(cd_vals):.3f}, {max(cd_vals):.3f}]")
|
| 111 |
+
print(f" τ_c (g^(1)): std = {np.std(tc_vals):.3f}, range = [{min(tc_vals):.3f}, {max(tc_vals):.3f}]")
|
| 112 |
+
|
| 113 |
+
# CV (coefficient of variation) — higher = more discriminative on its own scale
|
| 114 |
+
cd_cv = np.std(cd_vals) / max(np.mean(cd_vals), 1e-9)
|
| 115 |
+
tc_cv = np.std(tc_vals) / max(np.mean(tc_vals), 1e-9)
|
| 116 |
+
print(f"\n CV (std/mean): cross_domain={cd_cv:.3f} τ_c={tc_cv:.3f}")
|
| 117 |
+
if tc_cv > cd_cv * 1.2:
|
| 118 |
+
print(" → τ_c is more discriminative than cross_domain — B1 stands.")
|
| 119 |
+
elif tc_cv < cd_cv * 0.8:
|
| 120 |
+
print(" → τ_c is LESS discriminative than cross_domain — B1 fails.")
|
| 121 |
+
else:
|
| 122 |
+
print(" → τ_c and cross_domain have similar discrimination — B1 is a wash.")
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
if __name__ == "__main__":
|
| 126 |
+
main()
|
space/sim_b2_g2_classifier.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""B2 sim — does g^(2)(0) cleanly classify candidates into planet / comet
|
| 2 |
+
/ asteroid on REAL meridian-shaped data (Llama-emitted bodies of typical
|
| 3 |
+
length 100–500 tokens), and where does it disagree with the existing
|
| 4 |
+
mass × scope × independence rule?
|
| 5 |
+
|
| 6 |
+
Sim 4b earlier (synthetic Zipfian) showed g^(2) > 1 only emerges at
|
| 7 |
+
N_distinct × token_total scales typical of real bodies, not the 8–12
|
| 8 |
+
token toy archetypes from Sim 4. This sim re-runs that check on
|
| 9 |
+
realistic body shapes and compares per-candidate the g^(2) class label
|
| 10 |
+
to the mass × scope × independence label that orbital.mjs assigns today.
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import math
|
| 15 |
+
import numpy as np
|
| 16 |
+
from collections import Counter
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def g2_zero(tokens: list[str]) -> float:
|
| 20 |
+
"""g^(2)(0) = ⟨n(n-1)⟩ / ⟨n⟩² over per-word counts {n_i}.
|
| 21 |
+
Loudon Ch 6.4: coherent → 1, chaotic → 2, antibunched → <1."""
|
| 22 |
+
if not tokens:
|
| 23 |
+
return float("nan")
|
| 24 |
+
counts = np.asarray(list(Counter(tokens).values()), dtype=np.float64)
|
| 25 |
+
n_mean = counts.mean()
|
| 26 |
+
n_n_minus_1 = (counts * (counts - 1)).mean()
|
| 27 |
+
return n_n_minus_1 / (n_mean ** 2) if n_mean > 0 else float("nan")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def class_from_g2(g2: float) -> str:
|
| 31 |
+
"""Threshold rule from Loudon Ch 6.4."""
|
| 32 |
+
if g2 < 0.7:
|
| 33 |
+
return "asteroid" # antibunched / sparse / niche
|
| 34 |
+
elif g2 < 1.4:
|
| 35 |
+
return "planet" # ≈ 1 = coherent / focused
|
| 36 |
+
else:
|
| 37 |
+
return "comet" # > 1 = thermal / scattered
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def class_from_meridian(mass: float, scope: float, indep: float,
|
| 41 |
+
cross_domain: float, drag: float, fragmentation: float,
|
| 42 |
+
dep_ratio: float, has_parent: bool) -> str:
|
| 43 |
+
"""Mirror orbital.mjs:139-167 — compute the same scores and pick max."""
|
| 44 |
+
planet = min(mass, scope, indep) ** 1.5
|
| 45 |
+
moon = (max(0, 0.5 - indep) * 2 *
|
| 46 |
+
(1.0 if has_parent else 0.4) * (1 - 0.5 * mass))
|
| 47 |
+
trojan = dep_ratio * (1.0 if has_parent else 0.5) * (1 - fragmentation)
|
| 48 |
+
asteroid = max(0, 0.55 - mass) * 2.5 * scope * indep
|
| 49 |
+
comet = drag * cross_domain * (1 - dep_ratio)
|
| 50 |
+
irregular = cross_domain * fragmentation * 0.85
|
| 51 |
+
scores = {"planet": planet, "moon": moon, "trojan": trojan,
|
| 52 |
+
"asteroid": asteroid, "comet": comet, "irregular": irregular}
|
| 53 |
+
return max(scores, key=scores.get)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def physics_from_tokens(tokens: list[str]) -> dict:
|
| 57 |
+
"""Approximate the physics scalars meridian computes from text."""
|
| 58 |
+
body_len = sum(len(t) for t in tokens)
|
| 59 |
+
n_words = len(tokens)
|
| 60 |
+
mass = max(0, min(1, 0.6 * np.log10(max(50, body_len) / 200) /
|
| 61 |
+
np.log10(3000 / 200) + 0.4 * (n_words - 3) / 9))
|
| 62 |
+
distinct = len(set(tokens))
|
| 63 |
+
scope = min(0.7, distinct / 12) + 0.2 # rough proxy
|
| 64 |
+
scope = max(0, min(1, scope))
|
| 65 |
+
indep = 0.7 # synthetic candidates have no siblings; assume mid-high
|
| 66 |
+
drag = 0.3
|
| 67 |
+
fragmentation = 0.4
|
| 68 |
+
cross_domain = 0.5 # placeholder
|
| 69 |
+
dep_ratio = 0.2
|
| 70 |
+
return dict(mass=mass, scope=scope, indep=indep,
|
| 71 |
+
drag=drag, fragmentation=fragmentation,
|
| 72 |
+
cross_domain=cross_domain, dep_ratio=dep_ratio,
|
| 73 |
+
has_parent=False)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ─── Realistic synthetic candidates ─────────────────────────────────────────
|
| 77 |
+
def zipfian_words(prefix: str, n_distinct: int, length: int, alpha: float, seed: int):
|
| 78 |
+
rng = np.random.default_rng(seed)
|
| 79 |
+
vocab = [f"{prefix}-{i:02d}" for i in range(n_distinct)]
|
| 80 |
+
weights = 1.0 / (np.arange(1, n_distinct + 1) ** alpha)
|
| 81 |
+
weights /= weights.sum()
|
| 82 |
+
return list(rng.choice(vocab, size=length, p=weights))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def main():
|
| 86 |
+
# 9 archetypes covering the planet/comet/asteroid spectrum at realistic
|
| 87 |
+
# body length (200–400 tokens) — the regime Sim 4b proved relevant.
|
| 88 |
+
cases = [
|
| 89 |
+
# label, n_distinct, length, alpha (Zipf head), expected
|
| 90 |
+
("planet-tight-vocab", 20, 300, 1.0, "planet"), # coherent-shaped
|
| 91 |
+
("planet-medium", 15, 250, 0.8, "planet"),
|
| 92 |
+
("planet-broad-vocab", 50, 400, 1.2, "planet"),
|
| 93 |
+
("comet-thermal", 30, 300, 1.5, "comet"), # heavier head
|
| 94 |
+
("comet-very-heavy", 25, 300, 2.0, "comet"),
|
| 95 |
+
("comet-multimodal", 40, 350, 1.8, "comet"),
|
| 96 |
+
("asteroid-narrow", 5, 300, 1.0, "asteroid"), # too few distinct
|
| 97 |
+
("asteroid-fragments", 10, 100, 0.5, "asteroid"), # short body
|
| 98 |
+
("asteroid-uniform", 50, 300, 0.3, "asteroid"), # near-uniform
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
print(f"{'archetype':>22} {'len':>5} {'g^(2)':>7} {'g2_class':>10} "
|
| 102 |
+
f"{'mass×s×i_class':>16} {'expected':>10}")
|
| 103 |
+
print("-" * 90)
|
| 104 |
+
correct_g2 = 0
|
| 105 |
+
correct_meridian = 0
|
| 106 |
+
for label, n_distinct, length, alpha, expected in cases:
|
| 107 |
+
toks = zipfian_words(label, n_distinct, length, alpha, seed=hash(label) & 0xFFFF)
|
| 108 |
+
g2 = g2_zero(toks)
|
| 109 |
+
cls_g2 = class_from_g2(g2)
|
| 110 |
+
phys = physics_from_tokens(toks)
|
| 111 |
+
cls_m = class_from_meridian(**phys)
|
| 112 |
+
ok_g2 = cls_g2 == expected
|
| 113 |
+
ok_m = cls_m == expected
|
| 114 |
+
correct_g2 += int(ok_g2)
|
| 115 |
+
correct_meridian += int(ok_m)
|
| 116 |
+
marker_g2 = "✓" if ok_g2 else "✗"
|
| 117 |
+
marker_m = "✓" if ok_m else "✗"
|
| 118 |
+
print(f"{label:>22} {length:>5} {g2:>7.3f} "
|
| 119 |
+
f"{cls_g2:>9}{marker_g2} {cls_m:>15}{marker_m} {expected:>10}")
|
| 120 |
+
|
| 121 |
+
print(f"\n g^(2)-only classifier: {correct_g2}/{len(cases)} archetypes correct")
|
| 122 |
+
print(f" meridian's mass×scope×indep: {correct_meridian}/{len(cases)} archetypes correct")
|
| 123 |
+
if correct_g2 > correct_meridian:
|
| 124 |
+
print(" → B2 stands: g^(2) classifies more reliably on real-shape data")
|
| 125 |
+
elif correct_g2 < correct_meridian:
|
| 126 |
+
print(" → B2 fails: meridian's existing rule is better")
|
| 127 |
+
else:
|
| 128 |
+
print(" → B2 is a wash: both classifiers tied on archetype recovery")
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
if __name__ == "__main__":
|
| 132 |
+
main()
|
space/train.py
CHANGED
|
@@ -245,7 +245,7 @@ def train(args: argparse.Namespace) -> None:
|
|
| 245 |
torch.manual_seed(args.seed)
|
| 246 |
np.random.seed(args.seed)
|
| 247 |
|
| 248 |
-
rel_path = ROOT / "eval" / "relevance.json"
|
| 249 |
cids_path = ROOT / "eval" / "corpus_ids.json"
|
| 250 |
man_path = ROOT / "eval" / "manifest.json"
|
| 251 |
|
|
@@ -380,6 +380,8 @@ def train(args: argparse.Namespace) -> None:
|
|
| 380 |
def main() -> None:
|
| 381 |
ap = argparse.ArgumentParser()
|
| 382 |
ap.add_argument("--out", type=Path, default=ROOT / "weights.npz")
|
|
|
|
|
|
|
| 383 |
ap.add_argument("--steps", type=int, default=100)
|
| 384 |
ap.add_argument("--lr", type=float, default=5e-3)
|
| 385 |
# D-scale logits: with D in [0, 50], temp=0.1 made -D/temp logits up to
|
|
|
|
| 245 |
torch.manual_seed(args.seed)
|
| 246 |
np.random.seed(args.seed)
|
| 247 |
|
| 248 |
+
rel_path = Path(args.relevance) if args.relevance else ROOT / "eval" / "relevance.json"
|
| 249 |
cids_path = ROOT / "eval" / "corpus_ids.json"
|
| 250 |
man_path = ROOT / "eval" / "manifest.json"
|
| 251 |
|
|
|
|
| 380 |
def main() -> None:
|
| 381 |
ap = argparse.ArgumentParser()
|
| 382 |
ap.add_argument("--out", type=Path, default=ROOT / "weights.npz")
|
| 383 |
+
ap.add_argument("--relevance", type=str, default=None,
|
| 384 |
+
help="path to alternate relevance.json (e.g. for held-out splits)")
|
| 385 |
ap.add_argument("--steps", type=int, default=100)
|
| 386 |
ap.add_argument("--lr", type=float, default=5e-3)
|
| 387 |
# D-scale logits: with D in [0, 50], temp=0.1 made -D/temp logits up to
|
space/train_sbert.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SBERT-backed photon-route encoder. The language model does language;
|
| 2 |
+
the photonic gates do the structured projection.
|
| 3 |
+
|
| 4 |
+
Architecture:
|
| 5 |
+
text → frozen SentenceTransformer (all-MiniLM-L6-v2, 384-d, no grad)
|
| 6 |
+
→ Linear(384 → 4N + 2N) [trainable]
|
| 7 |
+
→ 4N displacement outputs (αq, αp per mode) + 2N squeezing outputs (r, φ per mode)
|
| 8 |
+
→ photonic gates (Sgate + Dgate per mode)
|
| 9 |
+
→ 2N-d Gaussian state (μ, σ) at hbar=2
|
| 10 |
+
|
| 11 |
+
Trainable surface:
|
| 12 |
+
Linear(384 → 6N) ≈ 384·6N + 6N params (6 numbers per mode: αq, αp, r, φ_s, plus 2 future)
|
| 13 |
+
For N=2: 384·8 + 8 = 3,080 params total
|
| 14 |
+
vs word-level photon-route: |V|·4 = 5,772 (and grows with vocab)
|
| 15 |
+
|
| 16 |
+
Loss is the same InfoNCE-on-Bhattacharyya as space/train.py so the comparison
|
| 17 |
+
is apples-to-apples on the encoder, not the loss.
|
| 18 |
+
|
| 19 |
+
Holdout discipline: load --relevance from a file. The eval driver
|
| 20 |
+
(eval/run.py) does NOT support sbert weights yet; this module ships its
|
| 21 |
+
own evaluator alongside the trainer for now.
|
| 22 |
+
"""
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import json
|
| 27 |
+
import math
|
| 28 |
+
import sys
|
| 29 |
+
import time
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
|
| 32 |
+
import numpy as np
|
| 33 |
+
import torch
|
| 34 |
+
import torch.nn as nn
|
| 35 |
+
import torch.nn.functional as F
|
| 36 |
+
from torch import Tensor
|
| 37 |
+
|
| 38 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 39 |
+
SRC = ROOT / "src"
|
| 40 |
+
if str(SRC) not in sys.path:
|
| 41 |
+
sys.path.insert(0, str(SRC))
|
| 42 |
+
|
| 43 |
+
from eval.fetch import fetch_all, verify_against_manifest # noqa: E402
|
| 44 |
+
|
| 45 |
+
N_MODES = 2
|
| 46 |
+
HBAR = 2.0
|
| 47 |
+
SBERT_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
| 48 |
+
SBERT_DIM = 384
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ─── differentiable single-mode squeezing in qqpp ─────────────────────────────
|
| 52 |
+
def _eye2N(n: int, ref: Tensor) -> Tensor:
|
| 53 |
+
return torch.eye(2 * n, dtype=ref.dtype, device=ref.device)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def squeezing_qqpp(n: int, k: int, r: Tensor, phi: Tensor) -> Tensor:
|
| 57 |
+
S = _eye2N(n, r).clone()
|
| 58 |
+
cr, sr = torch.cosh(r), torch.sinh(r)
|
| 59 |
+
cp, sp = torch.cos(phi), torch.sin(phi)
|
| 60 |
+
S[k, k ] = cr - sr * cp
|
| 61 |
+
S[k, k + n] = -sr * sp
|
| 62 |
+
S[k + n, k ] = -sr * sp
|
| 63 |
+
S[k + n, k + n] = cr + sr * cp
|
| 64 |
+
return S
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class SBERTPhoton(nn.Module):
|
| 68 |
+
"""Frozen SBERT → Linear → photonic state.
|
| 69 |
+
|
| 70 |
+
The Linear emits 6 numbers per mode (αq, αp, r, φ_s, plus 2 reserved).
|
| 71 |
+
Currently 4 are used; spare dims are zero'd by their learnable weight
|
| 72 |
+
converging to small values, so unused capacity self-prunes.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def __init__(self, n_modes: int = N_MODES, max_squeeze: float = 0.5,
|
| 76 |
+
max_displace: float = 1.0, no_squeeze: bool = False):
|
| 77 |
+
super().__init__()
|
| 78 |
+
from sentence_transformers import SentenceTransformer
|
| 79 |
+
self.n = n_modes
|
| 80 |
+
self.max_sq = max_squeeze
|
| 81 |
+
self.max_disp = max_displace
|
| 82 |
+
self.no_squeeze = no_squeeze
|
| 83 |
+
self.dgate_prefactor = math.sqrt(2.0 * HBAR)
|
| 84 |
+
self.sbert = SentenceTransformer(SBERT_MODEL_NAME)
|
| 85 |
+
for p in self.sbert.parameters():
|
| 86 |
+
p.requires_grad = False
|
| 87 |
+
# float32 throughout — MPS (Apple-Silicon GPU) doesn't support float64;
|
| 88 |
+
# cast to float64 at the eval-fidelity boundary (numpy + thewalrus).
|
| 89 |
+
# Squeezing magnitudes are bounded ≤ 0.5 so the covariance stays
|
| 90 |
+
# well-conditioned and float32 slogdet is numerically safe.
|
| 91 |
+
self.proj = nn.Linear(SBERT_DIM, 4 * n_modes, dtype=torch.float32)
|
| 92 |
+
# Small-random init (NOT zeros). Zero init puts every text at the
|
| 93 |
+
# same vacuum state, so all pairwise distances equal zero, gradients
|
| 94 |
+
# vanish, and loss stays at log(N+1) = saddle point forever.
|
| 95 |
+
nn.init.normal_(self.proj.weight, std=0.02)
|
| 96 |
+
nn.init.zeros_(self.proj.bias)
|
| 97 |
+
|
| 98 |
+
def encode_features(self, texts: list[str]) -> Tensor:
|
| 99 |
+
"""Run frozen SBERT, return (B, 384) float32 on CPU."""
|
| 100 |
+
with torch.no_grad():
|
| 101 |
+
emb = self.sbert.encode(
|
| 102 |
+
texts, normalize_embeddings=True, convert_to_numpy=False,
|
| 103 |
+
show_progress_bar=False,
|
| 104 |
+
)
|
| 105 |
+
emb = torch.stack([e for e in emb]) if isinstance(emb, list) else emb
|
| 106 |
+
return emb.to(torch.float32).cpu()
|
| 107 |
+
|
| 108 |
+
def state_from_features(self, feat: Tensor) -> tuple[Tensor, Tensor]:
|
| 109 |
+
"""Forward from a *precomputed* SBERT feature vector — used during
|
| 110 |
+
training when frozen-SBERT features are cached at start to avoid
|
| 111 |
+
re-running the transformer every step."""
|
| 112 |
+
out = self.proj(feat)
|
| 113 |
+
return self._gates_from_logits(out)
|
| 114 |
+
|
| 115 |
+
def state_from_text(self, text: str) -> tuple[Tensor, Tensor]:
|
| 116 |
+
feat = self.encode_features([text])[0] # (384,)
|
| 117 |
+
out = self.proj(feat) # (4N,)
|
| 118 |
+
return self._gates_from_logits(out)
|
| 119 |
+
|
| 120 |
+
def _gates_from_logits(self, out: Tensor) -> tuple[Tensor, Tensor]:
|
| 121 |
+
# Decompose: per-mode (αq, αp, raw_r, raw_phi).
|
| 122 |
+
# tanh-bound squeezing magnitude to [0, max_sq]; phi free.
|
| 123 |
+
per_mode = out.view(self.n, 4)
|
| 124 |
+
alpha_q = self.dgate_prefactor * torch.tanh(per_mode[:, 0])
|
| 125 |
+
alpha_p = self.dgate_prefactor * torch.tanh(per_mode[:, 1])
|
| 126 |
+
if self.no_squeeze:
|
| 127 |
+
r = torch.zeros(self.n, dtype=out.dtype)
|
| 128 |
+
phi_s = torch.zeros(self.n, dtype=out.dtype)
|
| 129 |
+
else:
|
| 130 |
+
r = self.max_sq * torch.sigmoid(per_mode[:, 2])
|
| 131 |
+
phi_s = (2 * math.pi) * torch.sigmoid(per_mode[:, 3])
|
| 132 |
+
|
| 133 |
+
mu = torch.zeros(2 * self.n, dtype=out.dtype)
|
| 134 |
+
sigma = _eye2N(self.n, out)
|
| 135 |
+
for k in range(self.n):
|
| 136 |
+
if not self.no_squeeze:
|
| 137 |
+
S = squeezing_qqpp(self.n, k, r[k], phi_s[k])
|
| 138 |
+
mu = S @ mu
|
| 139 |
+
sigma = S @ sigma @ S.T
|
| 140 |
+
shift = torch.zeros_like(mu)
|
| 141 |
+
shift[k] = alpha_q[k]
|
| 142 |
+
shift[k + self.n] = alpha_p[k]
|
| 143 |
+
mu = mu + shift
|
| 144 |
+
return mu, sigma
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def bhattacharyya_distance(mu_a, sg_a, mu_b, sg_b, ridge: float = 1e-3) -> Tensor:
|
| 148 |
+
d = sg_a.shape[0]
|
| 149 |
+
eye = torch.eye(d, dtype=sg_a.dtype, device=sg_a.device)
|
| 150 |
+
A = sg_a + ridge * eye
|
| 151 |
+
B = sg_b + ridge * eye
|
| 152 |
+
V = 0.5 * (A + B)
|
| 153 |
+
delta = mu_a - mu_b
|
| 154 |
+
quad = (delta * torch.linalg.solve(V, delta)).sum()
|
| 155 |
+
log_det_V = torch.linalg.slogdet(V)[1]
|
| 156 |
+
log_det_A = torch.linalg.slogdet(A)[1]
|
| 157 |
+
log_det_B = torch.linalg.slogdet(B)[1]
|
| 158 |
+
D = 0.125 * quad + 0.5 * (log_det_V - 0.5 * (log_det_A + log_det_B))
|
| 159 |
+
return torch.clamp(D, min=0.0, max=50.0)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def gaussian_fidelity_eval(mu_a: np.ndarray, sg_a: np.ndarray,
|
| 163 |
+
mu_b: np.ndarray, sg_b: np.ndarray) -> float:
|
| 164 |
+
"""thewalrus closed-form for eval-time scoring (not used in loss)."""
|
| 165 |
+
from thewalrus.quantum import fidelity as tw_fidelity
|
| 166 |
+
f = tw_fidelity(mu_a, sg_a, mu_b, sg_b, hbar=HBAR)
|
| 167 |
+
val = float(f.real if hasattr(f, "real") else f)
|
| 168 |
+
return max(0.0, min(1.0, val))
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# ── photon-number-distribution Bhattacharyya coefficient (A3-Simple) ──────
|
| 172 |
+
# Loudon Ch 6.10 — direct detection projects a state onto Fock basis and
|
| 173 |
+
# measures the photon-number distribution P(n_0, n_1, ...). A retrieval
|
| 174 |
+
# metric grounded in what the detector actually sees, rather than the
|
| 175 |
+
# Gaussian-state inner product. Closed-form computable from (μ, σ) for
|
| 176 |
+
# Gaussian states via thewalrus.quantum.probabilities; non-differentiable
|
| 177 |
+
# (numpy under the hood), eval-only — A3-Real would do this differentiably.
|
| 178 |
+
_PHOTON_PROB_CACHE: dict[int, np.ndarray] = {}
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def photon_prob_eval(mu_a: np.ndarray, sg_a: np.ndarray,
|
| 182 |
+
mu_b: np.ndarray, sg_b: np.ndarray, cutoff: int = 4) -> float:
|
| 183 |
+
"""Bhattacharyya coefficient between two photon-number distributions:
|
| 184 |
+
BC(P, Q) = Σ √(p_i q_i). ∈ [0, 1]; 1 = identical distributions.
|
| 185 |
+
|
| 186 |
+
Reuses the cutoff-sized P arrays via caching keyed by (id(mu), id(sg))
|
| 187 |
+
isn't viable across calls (μ, σ get re-allocated). Caller's responsibility
|
| 188 |
+
to dedup per-state; here we just compute fresh.
|
| 189 |
+
"""
|
| 190 |
+
from thewalrus.quantum import probabilities
|
| 191 |
+
P_a = np.asarray(probabilities(mu_a, sg_a, cutoff=cutoff, hbar=HBAR), dtype=np.float64).real
|
| 192 |
+
P_b = np.asarray(probabilities(mu_b, sg_b, cutoff=cutoff, hbar=HBAR), dtype=np.float64).real
|
| 193 |
+
# Truncation can leave a tail — renormalize so distributions sum to 1.
|
| 194 |
+
P_a = np.clip(P_a, 0.0, None) / max(P_a.sum(), 1e-12)
|
| 195 |
+
P_b = np.clip(P_b, 0.0, None) / max(P_b.sum(), 1e-12)
|
| 196 |
+
bc = float(np.sum(np.sqrt(P_a) * np.sqrt(P_b)))
|
| 197 |
+
return max(0.0, min(1.0, bc))
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def recall_at_k(ranked_ids, relevant, k):
|
| 201 |
+
if not relevant:
|
| 202 |
+
return float("nan")
|
| 203 |
+
return len(set(ranked_ids[:k]) & relevant) / len(relevant)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def ndcg_at_k(ranked_ids, relevant, k):
|
| 207 |
+
if not relevant:
|
| 208 |
+
return float("nan")
|
| 209 |
+
dcg = sum(1.0 / math.log2(i + 1) for i, a in enumerate(ranked_ids[:k], start=1) if a in relevant)
|
| 210 |
+
ideal = sum(1.0 / math.log2(i + 1) for i in range(1, min(k, len(relevant)) + 1))
|
| 211 |
+
return dcg / ideal if ideal > 0 else float("nan")
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def evaluate(model: SBERTPhoton, abstracts, ids, queries, ks=(1, 3, 5, 10),
|
| 215 |
+
metrics=("gaussian", "photon_prob"), photon_cutoff: int = 4) -> dict:
|
| 216 |
+
"""Evaluate retrieval under multiple metrics on the same trained encoder.
|
| 217 |
+
|
| 218 |
+
Returns a dict with one report per metric:
|
| 219 |
+
{"gaussian": {"per_query":[...], "aggregate":{...}}, "photon_prob": {...}}
|
| 220 |
+
|
| 221 |
+
A3-Simple test: do "gaussian" (BBP fidelity) and "photon_prob" (Loudon
|
| 222 |
+
Ch 6.10 direct-detection-grounded Bhattacharyya coefficient on the
|
| 223 |
+
photon-number distribution) give different rankings on the same encoder?
|
| 224 |
+
"""
|
| 225 |
+
model.eval()
|
| 226 |
+
# Encode all docs + queries once; convert to numpy float64 for thewalrus.
|
| 227 |
+
doc_np: dict[str, tuple[np.ndarray, np.ndarray]] = {}
|
| 228 |
+
q_np: list[tuple[dict, np.ndarray, np.ndarray]] = []
|
| 229 |
+
with torch.no_grad():
|
| 230 |
+
for arxiv_id, doc_text in abstracts.items():
|
| 231 |
+
mu_d, sg_d = model.state_from_text(doc_text)
|
| 232 |
+
doc_np[arxiv_id] = (
|
| 233 |
+
mu_d.cpu().numpy().astype(np.float64),
|
| 234 |
+
sg_d.cpu().numpy().astype(np.float64),
|
| 235 |
+
)
|
| 236 |
+
for q in queries:
|
| 237 |
+
mu_q, sg_q = model.state_from_text(q["query"])
|
| 238 |
+
q_np.append((
|
| 239 |
+
q,
|
| 240 |
+
mu_q.cpu().numpy().astype(np.float64),
|
| 241 |
+
sg_q.cpu().numpy().astype(np.float64),
|
| 242 |
+
))
|
| 243 |
+
|
| 244 |
+
score_fn = {
|
| 245 |
+
"gaussian": lambda mq, sq, md, sd: gaussian_fidelity_eval(mq, sq, md, sd),
|
| 246 |
+
"photon_prob": lambda mq, sq, md, sd: photon_prob_eval(mq, sq, md, sd, cutoff=photon_cutoff),
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
metric_rows: dict[str, list] = {m: [] for m in metrics}
|
| 250 |
+
for q, mu_q, sg_q in q_np:
|
| 251 |
+
for metric in metrics:
|
| 252 |
+
scored = []
|
| 253 |
+
for a in ids:
|
| 254 |
+
mu_d, sg_d = doc_np[a]
|
| 255 |
+
f = score_fn[metric](mu_q, sg_q, mu_d, sg_d)
|
| 256 |
+
scored.append((f, a))
|
| 257 |
+
scored.sort(key=lambda x: -x[0])
|
| 258 |
+
ranked_ids = [a for _, a in scored]
|
| 259 |
+
rel = set(q["relevant_ids"])
|
| 260 |
+
row = {"query": q["query"], "ranked": ranked_ids[: max(ks)]}
|
| 261 |
+
for k in ks:
|
| 262 |
+
row[f"recall@{k}"] = recall_at_k(ranked_ids, rel, k)
|
| 263 |
+
row[f"ndcg@{k}"] = ndcg_at_k(ranked_ids, rel, k)
|
| 264 |
+
metric_rows[metric].append(row)
|
| 265 |
+
|
| 266 |
+
out = {}
|
| 267 |
+
for metric in metrics:
|
| 268 |
+
rows = metric_rows[metric]
|
| 269 |
+
agg = {f"recall@{k}": float(np.mean([r[f"recall@{k}"] for r in rows])) for k in ks}
|
| 270 |
+
agg.update({f"ndcg@{k}": float(np.mean([r[f"ndcg@{k}"] for r in rows])) for k in ks})
|
| 271 |
+
out[metric] = {"per_query": rows, "aggregate": agg}
|
| 272 |
+
return out
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def train(args):
|
| 276 |
+
torch.manual_seed(args.seed)
|
| 277 |
+
np.random.seed(args.seed)
|
| 278 |
+
|
| 279 |
+
rel_path = Path(args.relevance) if args.relevance else ROOT / "eval" / "relevance.json"
|
| 280 |
+
cids_path = ROOT / "eval" / "corpus_ids.json"
|
| 281 |
+
man_path = ROOT / "eval" / "manifest.json"
|
| 282 |
+
|
| 283 |
+
train_relevance = json.loads(rel_path.read_text("utf-8"))["queries"]
|
| 284 |
+
ids = json.loads(cids_path.read_text("utf-8"))["ids"]
|
| 285 |
+
print(f"[sbert] fetching {len(ids)} abstracts...", flush=True)
|
| 286 |
+
abstracts = fetch_all(ids)
|
| 287 |
+
bad = verify_against_manifest(abstracts, man_path)
|
| 288 |
+
if bad:
|
| 289 |
+
sys.exit(f"manifest mismatch: {list(bad)[:3]}")
|
| 290 |
+
|
| 291 |
+
model = SBERTPhoton(n_modes=N_MODES, no_squeeze=args.no_squeeze)
|
| 292 |
+
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 293 |
+
print(f"[sbert] trainable params = {n_trainable} no_squeeze={args.no_squeeze}", flush=True)
|
| 294 |
+
|
| 295 |
+
# ── Cache frozen-SBERT features for every doc and every query ONCE.
|
| 296 |
+
# Without this we re-run the transformer for every (doc, step) pair —
|
| 297 |
+
# 20 docs × 200 steps = 4000 forward passes per training run, ~5 min
|
| 298 |
+
# of pure-inference waste on cloud CPU. Frozen features don't change.
|
| 299 |
+
print(f"[sbert] caching SBERT features for {len(abstracts)} docs + "
|
| 300 |
+
f"{len(train_relevance)} queries...", flush=True)
|
| 301 |
+
doc_feats = {a: model.encode_features([t])[0] for a, t in abstracts.items()}
|
| 302 |
+
query_feats = {q["query"]: model.encode_features([q["query"]])[0]
|
| 303 |
+
for q in train_relevance}
|
| 304 |
+
|
| 305 |
+
optim = torch.optim.AdamW(
|
| 306 |
+
[p for p in model.parameters() if p.requires_grad],
|
| 307 |
+
lr=args.lr, weight_decay=args.weight_decay,
|
| 308 |
+
)
|
| 309 |
+
rng = np.random.default_rng(args.seed)
|
| 310 |
+
queries = [(q["query"], set(q["relevant_ids"])) for q in train_relevance]
|
| 311 |
+
|
| 312 |
+
t0 = time.time()
|
| 313 |
+
for step in range(1, args.steps + 1):
|
| 314 |
+
optim.zero_grad()
|
| 315 |
+
# Re-run the trainable projection each step over CACHED features
|
| 316 |
+
# (instead of re-running SBERT). projection ∈ R^{384×4N}, cheap.
|
| 317 |
+
doc_states = {a: model.state_from_features(doc_feats[a]) for a in abstracts}
|
| 318 |
+
|
| 319 |
+
loss_sum = torch.zeros((), dtype=torch.float32)
|
| 320 |
+
for query_text, rel_set in queries:
|
| 321 |
+
mu_q, sg_q = model.state_from_features(query_feats[query_text])
|
| 322 |
+
pos_id = rng.choice(sorted(rel_set))
|
| 323 |
+
mu_p, sg_p = doc_states[pos_id]
|
| 324 |
+
negs = rng.choice(
|
| 325 |
+
[i for i in ids if i not in rel_set],
|
| 326 |
+
size=min(args.negatives, len(ids) - len(rel_set)), replace=False,
|
| 327 |
+
)
|
| 328 |
+
d_pos = bhattacharyya_distance(mu_q, sg_q, mu_p, sg_p)
|
| 329 |
+
d_negs = torch.stack([bhattacharyya_distance(mu_q, sg_q, *doc_states[n]) for n in negs])
|
| 330 |
+
logits = -torch.cat([d_pos.unsqueeze(0), d_negs]) / args.temperature
|
| 331 |
+
ce = F.cross_entropy(logits.unsqueeze(0), torch.zeros((), dtype=torch.long).unsqueeze(0))
|
| 332 |
+
loss_sum = loss_sum + ce
|
| 333 |
+
loss_sum = loss_sum / len(queries)
|
| 334 |
+
loss_sum.backward()
|
| 335 |
+
torch.nn.utils.clip_grad_norm_(
|
| 336 |
+
[p for p in model.parameters() if p.requires_grad], max_norm=args.clip,
|
| 337 |
+
)
|
| 338 |
+
optim.step()
|
| 339 |
+
if step == 1 or step % args.log_every == 0 or step == args.steps:
|
| 340 |
+
print(f"[sbert] step {step}/{args.steps} loss={loss_sum.item():.4f} "
|
| 341 |
+
f"elapsed={time.time()-t0:.1f}s", flush=True)
|
| 342 |
+
|
| 343 |
+
# final eval against whichever relevance file the user asks for
|
| 344 |
+
eval_paths = []
|
| 345 |
+
if args.eval_train_rel:
|
| 346 |
+
eval_paths.append(("train", Path(args.eval_train_rel)))
|
| 347 |
+
if args.eval_test_rel:
|
| 348 |
+
eval_paths.append(("test", Path(args.eval_test_rel)))
|
| 349 |
+
if not eval_paths:
|
| 350 |
+
eval_paths.append(("all", rel_path))
|
| 351 |
+
summary = {}
|
| 352 |
+
for label, p in eval_paths:
|
| 353 |
+
rels = json.loads(p.read_text("utf-8"))["queries"]
|
| 354 |
+
multi = evaluate(model, abstracts, ids, rels,
|
| 355 |
+
metrics=("gaussian", "photon_prob"))
|
| 356 |
+
for metric, report in multi.items():
|
| 357 |
+
print(f"\n=== {label.upper()} EVAL — metric={metric} ({len(rels)} queries) ===")
|
| 358 |
+
for r in report["per_query"]:
|
| 359 |
+
cells = " ".join(f"{m}={r[m]:.3f}" for m in r if m.startswith(("recall", "ndcg")))
|
| 360 |
+
print(f" {r['query'][:48]:<48s} {cells}")
|
| 361 |
+
print("aggregate: " + " ".join(
|
| 362 |
+
f"{m}={report['aggregate'][m]:.3f}" for m in report["aggregate"]
|
| 363 |
+
))
|
| 364 |
+
summary[f"{label}/{metric}"] = report["aggregate"]
|
| 365 |
+
# Sentinel line for downstream parsers (run_sweep.py).
|
| 366 |
+
print(f"\nSUMMARY_JSON={json.dumps(summary)}")
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def main():
|
| 370 |
+
ap = argparse.ArgumentParser()
|
| 371 |
+
ap.add_argument("--steps", type=int, default=200)
|
| 372 |
+
ap.add_argument("--lr", type=float, default=1e-2)
|
| 373 |
+
ap.add_argument("--weight-decay", type=float, default=1e-3)
|
| 374 |
+
ap.add_argument("--temperature", type=float, default=2.0)
|
| 375 |
+
ap.add_argument("--negatives", type=int, default=8)
|
| 376 |
+
ap.add_argument("--clip", type=float, default=1.0)
|
| 377 |
+
ap.add_argument("--seed", type=int, default=42)
|
| 378 |
+
ap.add_argument("--log-every", type=int, default=20)
|
| 379 |
+
ap.add_argument("--relevance", type=str, default=None,
|
| 380 |
+
help="training relevance.json (e.g. /tmp/rel_train.json)")
|
| 381 |
+
ap.add_argument("--eval-train-rel", type=str, default=None,
|
| 382 |
+
help="optional separate train-eval set for in-sample numbers")
|
| 383 |
+
ap.add_argument("--eval-test-rel", type=str, default=None,
|
| 384 |
+
help="optional held-out eval set for generalization numbers")
|
| 385 |
+
ap.add_argument("--no-squeeze", action="store_true",
|
| 386 |
+
help="ablation: force r=0 (displacement-only). Tests whether the squeezing layer specifically pays.")
|
| 387 |
+
args = ap.parse_args()
|
| 388 |
+
train(args)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
if __name__ == "__main__":
|
| 392 |
+
main()
|
space/train_sbert_fock.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""A3-Real — non-Gaussian SBERT-photon trainer.
|
| 2 |
+
|
| 3 |
+
Architecture:
|
| 4 |
+
text → frozen SBERT 384d → Linear(384, 6) → photonic params
|
| 5 |
+
→ 2-mode (signal + ancilla) cutoff-D Fock-basis encoder
|
| 6 |
+
→ unitary U = D_signal(α) S_signal(r,φ) S_2(τ,θ) applied to |0,0⟩
|
| 7 |
+
→ project ancilla onto |1⟩ (single-photon herald)
|
| 8 |
+
→ normalised pure state |ψ_sig⟩ ∈ ℂ^D, NON-GAUSSIAN
|
| 9 |
+
score(q, d) = |⟨ψ_q,sig | ψ_d,sig⟩|²
|
| 10 |
+
|
| 11 |
+
Why this is genuinely new vs space/train_sbert.py: with α small and r mild,
|
| 12 |
+
the heralded-on-|1⟩ signal mode contains a single-photon contribution. A
|
| 13 |
+
single-photon Fock state has Wigner-negative regions — non-Gaussian. The
|
| 14 |
+
similarity |⟨ψ_q|ψ_d⟩|² is *not* representable as a Gaussian-RBF kernel
|
| 15 |
+
on any finite-d projection of the inputs (Sim 1's negative result for the
|
| 16 |
+
Gaussian path).
|
| 17 |
+
|
| 18 |
+
Loss: InfoNCE on -log(score) (i.e. score is the affinity logit). Cached
|
| 19 |
+
SBERT features. Same eval/relevance as the Gaussian trainer for direct
|
| 20 |
+
head-to-head numbers.
|
| 21 |
+
"""
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import argparse
|
| 25 |
+
import json
|
| 26 |
+
import math
|
| 27 |
+
import sys
|
| 28 |
+
import time
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
|
| 31 |
+
import numpy as np
|
| 32 |
+
import torch
|
| 33 |
+
import torch.nn as nn
|
| 34 |
+
import torch.nn.functional as F
|
| 35 |
+
from torch import Tensor
|
| 36 |
+
|
| 37 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 38 |
+
SRC = ROOT / "src"
|
| 39 |
+
if str(SRC) not in sys.path:
|
| 40 |
+
sys.path.insert(0, str(SRC))
|
| 41 |
+
|
| 42 |
+
from eval.fetch import fetch_all, verify_against_manifest # noqa: E402
|
| 43 |
+
|
| 44 |
+
SBERT_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
| 45 |
+
SBERT_DIM = 384
|
| 46 |
+
HERALD_N = 1 # ancilla outcome we herald on; single-photon → genuinely non-Gaussian
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# ─── Truncated bosonic operators in Fock basis ────────────────────────────
|
| 50 |
+
def annihilation_op(D: int) -> Tensor:
|
| 51 |
+
"""a |n⟩ = √n |n-1⟩ (D-dim truncation; lossy at the top)."""
|
| 52 |
+
a = torch.zeros(D, D, dtype=torch.complex128)
|
| 53 |
+
for n in range(1, D):
|
| 54 |
+
a[n - 1, n] = math.sqrt(n)
|
| 55 |
+
return a
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def kron(A: Tensor, B: Tensor) -> Tensor:
|
| 59 |
+
return torch.kron(A, B)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ─── Generators of the unitaries ────────────────────────────────────────────
|
| 63 |
+
def displace_generator(a: Tensor, alpha: Tensor) -> Tensor:
|
| 64 |
+
"""G_D = α a† − α* a; applied as exp(G_D) gives D(α). α is complex scalar."""
|
| 65 |
+
return alpha * a.conj().T - torch.conj(alpha) * a
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def squeeze_generator(a: Tensor, zeta: Tensor) -> Tensor:
|
| 69 |
+
"""G_S = (1/2)(ζ* a² − ζ a†²); applied as exp(G_S) gives S(ζ). ζ = r e^{iφ}."""
|
| 70 |
+
a2 = a @ a
|
| 71 |
+
return 0.5 * (torch.conj(zeta) * a2 - zeta * a2.conj().T)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def two_mode_squeeze_generator(a: Tensor, b: Tensor, xi: Tensor) -> Tensor:
|
| 75 |
+
"""G_{TMS} = ξ* a b − ξ a† b†. Acts on joint signal⊗ancilla space.
|
| 76 |
+
Inputs are full-dim joint operators a, b (e.g. a = a_signal ⊗ I_anc)."""
|
| 77 |
+
return torch.conj(xi) * (a @ b) - xi * (a.conj().T @ b.conj().T)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# ─── Encoder ────────────────────────────────────────────────────────────────
|
| 81 |
+
class SBERTPhotonFock(nn.Module):
|
| 82 |
+
"""SBERT → Linear(384, 6) → 2-mode Fock encoder → herald → 1-mode pure state.
|
| 83 |
+
|
| 84 |
+
The 6 outputs decompose to (αq, αp, r, φ_s, τ, θ):
|
| 85 |
+
α = αq + i·αp (signal displacement)
|
| 86 |
+
ζ = r e^{iφ_s}, r ∈ [0, 0.5] (signal squeezing)
|
| 87 |
+
ξ = τ e^{iθ}, τ ∈ [0, 0.5] (two-mode squeezing toward ancilla)
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
def __init__(self, cutoff: int = 6, max_squeeze: float = 0.5,
|
| 91 |
+
max_displace: float = 1.5, herald_n: int = HERALD_N):
|
| 92 |
+
super().__init__()
|
| 93 |
+
from sentence_transformers import SentenceTransformer
|
| 94 |
+
self.D = cutoff
|
| 95 |
+
self.max_sq = max_squeeze
|
| 96 |
+
self.max_disp = max_displace
|
| 97 |
+
self.herald_n = herald_n
|
| 98 |
+
if herald_n >= cutoff:
|
| 99 |
+
raise ValueError(f"herald_n={herald_n} must be < cutoff={cutoff}")
|
| 100 |
+
self.sbert = SentenceTransformer(SBERT_MODEL_NAME)
|
| 101 |
+
for p in self.sbert.parameters():
|
| 102 |
+
p.requires_grad = False
|
| 103 |
+
# Trainable surface
|
| 104 |
+
self.proj = nn.Linear(SBERT_DIM, 6, dtype=torch.float32)
|
| 105 |
+
nn.init.normal_(self.proj.weight, std=0.02)
|
| 106 |
+
nn.init.zeros_(self.proj.bias)
|
| 107 |
+
# Pre-compute truncated bosonic operators (constants — no grad needed)
|
| 108 |
+
a = annihilation_op(cutoff)
|
| 109 |
+
I = torch.eye(cutoff, dtype=torch.complex128)
|
| 110 |
+
# Joint-space (signal ⊗ ancilla): a_s = a ⊗ I, b_a = I ⊗ a
|
| 111 |
+
self.register_buffer("a_signal_full", kron(a, I))
|
| 112 |
+
self.register_buffer("b_anc_full", kron(I, a))
|
| 113 |
+
self.register_buffer("a_signal_local", a) # for solo signal-side gates if ever needed
|
| 114 |
+
# Initial vacuum |0,0⟩ in joint Fock basis (D² vector)
|
| 115 |
+
psi0 = torch.zeros(cutoff * cutoff, dtype=torch.complex128)
|
| 116 |
+
psi0[0] = 1.0 # index (0, 0) → flat 0
|
| 117 |
+
self.register_buffer("vacuum", psi0)
|
| 118 |
+
|
| 119 |
+
def encode_features(self, texts: list[str]) -> Tensor:
|
| 120 |
+
with torch.no_grad():
|
| 121 |
+
emb = self.sbert.encode(
|
| 122 |
+
texts, normalize_embeddings=True, convert_to_numpy=False,
|
| 123 |
+
show_progress_bar=False,
|
| 124 |
+
)
|
| 125 |
+
emb = torch.stack([e for e in emb]) if isinstance(emb, list) else emb
|
| 126 |
+
return emb.to(torch.float32).cpu()
|
| 127 |
+
|
| 128 |
+
def state_from_features(self, feat: Tensor) -> Tensor:
|
| 129 |
+
"""Returns the heralded signal-mode state |ψ_sig⟩ ∈ ℂ^D, normalized.
|
| 130 |
+
Shape: (D,) complex128.
|
| 131 |
+
"""
|
| 132 |
+
out = self.proj(feat) # (6,) float32
|
| 133 |
+
out = out.to(torch.float64)
|
| 134 |
+
# decompose with bounded reparametrizations
|
| 135 |
+
alpha_q = self.max_disp * torch.tanh(out[0])
|
| 136 |
+
alpha_p = self.max_disp * torch.tanh(out[1])
|
| 137 |
+
r = self.max_sq * torch.sigmoid(out[2])
|
| 138 |
+
phi_s = (2 * math.pi) * torch.sigmoid(out[3])
|
| 139 |
+
tau = self.max_sq * torch.sigmoid(out[4])
|
| 140 |
+
theta = (2 * math.pi) * torch.sigmoid(out[5])
|
| 141 |
+
# Build complex parameters
|
| 142 |
+
alpha = torch.complex(alpha_q, alpha_p)
|
| 143 |
+
zeta = torch.complex(r * torch.cos(phi_s), r * torch.sin(phi_s))
|
| 144 |
+
xi = torch.complex(tau * torch.cos(theta), tau * torch.sin(theta))
|
| 145 |
+
# Generators in joint space
|
| 146 |
+
G_TMS = two_mode_squeeze_generator(
|
| 147 |
+
self.a_signal_full, self.b_anc_full, xi,
|
| 148 |
+
)
|
| 149 |
+
G_S = squeeze_generator(self.a_signal_full, zeta)
|
| 150 |
+
G_D = displace_generator(self.a_signal_full, alpha)
|
| 151 |
+
# Apply unitaries: |ψ⟩ = D · S · S_2 · |0,0⟩
|
| 152 |
+
U_TMS = torch.linalg.matrix_exp(G_TMS)
|
| 153 |
+
U_S = torch.linalg.matrix_exp(G_S)
|
| 154 |
+
U_D = torch.linalg.matrix_exp(G_D)
|
| 155 |
+
psi = U_TMS @ self.vacuum
|
| 156 |
+
psi = U_S @ psi
|
| 157 |
+
psi = U_D @ psi
|
| 158 |
+
# Project ancilla onto |herald_n⟩. Joint flat index = signal*D + ancilla.
|
| 159 |
+
# Pick rows where ancilla == herald_n: rows = [signal*D + herald_n for signal in 0..D-1]
|
| 160 |
+
D = self.D
|
| 161 |
+
idx = torch.arange(D, device=psi.device, dtype=torch.long) * D + self.herald_n
|
| 162 |
+
psi_sig = psi[idx]
|
| 163 |
+
# Normalize (heralding probability is the squared norm; we drop it)
|
| 164 |
+
norm = torch.linalg.vector_norm(psi_sig)
|
| 165 |
+
psi_sig = psi_sig / torch.clamp(norm, min=1e-12)
|
| 166 |
+
return psi_sig
|
| 167 |
+
|
| 168 |
+
def state_from_text(self, text: str) -> Tensor:
|
| 169 |
+
feat = self.encode_features([text])[0]
|
| 170 |
+
return self.state_from_features(feat)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def overlap_squared(psi_a: Tensor, psi_b: Tensor) -> Tensor:
|
| 174 |
+
"""|⟨ψ_a|ψ_b⟩|². Pure-state fidelity since both are heralded pure."""
|
| 175 |
+
inner = torch.vdot(psi_a, psi_b)
|
| 176 |
+
return (inner.real ** 2 + inner.imag ** 2)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def recall_at_k(ranked, relevant, k):
|
| 180 |
+
if not relevant:
|
| 181 |
+
return float("nan")
|
| 182 |
+
return len(set(ranked[:k]) & relevant) / len(relevant)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def ndcg_at_k(ranked, relevant, k):
|
| 186 |
+
if not relevant:
|
| 187 |
+
return float("nan")
|
| 188 |
+
dcg = sum(1.0 / math.log2(i + 1) for i, a in enumerate(ranked[:k], start=1) if a in relevant)
|
| 189 |
+
ideal = sum(1.0 / math.log2(i + 1) for i in range(1, min(k, len(relevant)) + 1))
|
| 190 |
+
return dcg / ideal if ideal > 0 else float("nan")
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def evaluate(model, abstracts, ids, queries, ks=(1, 3, 5, 10)) -> dict:
|
| 194 |
+
model.eval()
|
| 195 |
+
with torch.no_grad():
|
| 196 |
+
doc_states = {a: model.state_from_text(t) for a, t in abstracts.items()}
|
| 197 |
+
rows = []
|
| 198 |
+
for q in queries:
|
| 199 |
+
with torch.no_grad():
|
| 200 |
+
psi_q = model.state_from_text(q["query"])
|
| 201 |
+
scored = []
|
| 202 |
+
for a in ids:
|
| 203 |
+
psi_d = doc_states[a]
|
| 204 |
+
scored.append((float(overlap_squared(psi_q, psi_d).item()), a))
|
| 205 |
+
scored.sort(key=lambda x: -x[0])
|
| 206 |
+
ranked = [a for _, a in scored]
|
| 207 |
+
rel = set(q["relevant_ids"])
|
| 208 |
+
row = {"query": q["query"], "ranked": ranked[: max(ks)]}
|
| 209 |
+
for k in ks:
|
| 210 |
+
row[f"recall@{k}"] = recall_at_k(ranked, rel, k)
|
| 211 |
+
row[f"ndcg@{k}"] = ndcg_at_k(ranked, rel, k)
|
| 212 |
+
rows.append(row)
|
| 213 |
+
agg = {f"recall@{k}": float(np.mean([r[f"recall@{k}"] for r in rows])) for k in ks}
|
| 214 |
+
agg.update({f"ndcg@{k}": float(np.mean([r[f"ndcg@{k}"] for r in rows])) for k in ks})
|
| 215 |
+
return {"per_query": rows, "aggregate": agg}
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def train(args):
|
| 219 |
+
torch.manual_seed(args.seed)
|
| 220 |
+
np.random.seed(args.seed)
|
| 221 |
+
|
| 222 |
+
rel_path = Path(args.relevance) if args.relevance else ROOT / "eval" / "relevance.json"
|
| 223 |
+
cids_path = ROOT / "eval" / "corpus_ids.json"
|
| 224 |
+
man_path = ROOT / "eval" / "manifest.json"
|
| 225 |
+
train_relevance = json.loads(rel_path.read_text("utf-8"))["queries"]
|
| 226 |
+
ids = json.loads(cids_path.read_text("utf-8"))["ids"]
|
| 227 |
+
print(f"[fock] fetching {len(ids)} abstracts...", flush=True)
|
| 228 |
+
abstracts = fetch_all(ids)
|
| 229 |
+
bad = verify_against_manifest(abstracts, man_path)
|
| 230 |
+
if bad:
|
| 231 |
+
sys.exit(f"manifest mismatch: {list(bad)[:3]}")
|
| 232 |
+
|
| 233 |
+
model = SBERTPhotonFock(cutoff=args.cutoff, herald_n=args.herald_n)
|
| 234 |
+
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 235 |
+
print(f"[fock] cutoff={args.cutoff} herald_n={args.herald_n} trainable={n_trainable}", flush=True)
|
| 236 |
+
|
| 237 |
+
# Cache features
|
| 238 |
+
print(f"[fock] caching SBERT features ({len(abstracts)} docs + "
|
| 239 |
+
f"{len(train_relevance)} queries)...", flush=True)
|
| 240 |
+
doc_feats = {a: model.encode_features([t])[0] for a, t in abstracts.items()}
|
| 241 |
+
query_feats = {q["query"]: model.encode_features([q["query"]])[0]
|
| 242 |
+
for q in train_relevance}
|
| 243 |
+
|
| 244 |
+
optim = torch.optim.AdamW(
|
| 245 |
+
[p for p in model.parameters() if p.requires_grad],
|
| 246 |
+
lr=args.lr, weight_decay=args.weight_decay,
|
| 247 |
+
)
|
| 248 |
+
rng = np.random.default_rng(args.seed)
|
| 249 |
+
queries = [(q["query"], set(q["relevant_ids"])) for q in train_relevance]
|
| 250 |
+
|
| 251 |
+
t0 = time.time()
|
| 252 |
+
for step in range(1, args.steps + 1):
|
| 253 |
+
optim.zero_grad()
|
| 254 |
+
# Recompute every state (proj weights change every step). Cached
|
| 255 |
+
# SBERT features feed straight into state_from_features.
|
| 256 |
+
doc_psi = {a: model.state_from_features(doc_feats[a]) for a in abstracts}
|
| 257 |
+
|
| 258 |
+
loss_sum = torch.zeros((), dtype=torch.float64)
|
| 259 |
+
for query_text, rel_set in queries:
|
| 260 |
+
psi_q = model.state_from_features(query_feats[query_text])
|
| 261 |
+
pos_id = rng.choice(sorted(rel_set))
|
| 262 |
+
psi_pos = doc_psi[pos_id]
|
| 263 |
+
negs = rng.choice(
|
| 264 |
+
[i for i in ids if i not in rel_set],
|
| 265 |
+
size=min(args.negatives, len(ids) - len(rel_set)), replace=False,
|
| 266 |
+
)
|
| 267 |
+
f_pos = overlap_squared(psi_q, psi_pos)
|
| 268 |
+
f_negs = torch.stack([overlap_squared(psi_q, doc_psi[n]) for n in negs])
|
| 269 |
+
# InfoNCE: log P(pos) = log(f_pos / Σ f). Fidelities are in [0,1] so use as logits directly.
|
| 270 |
+
logits = torch.cat([f_pos.unsqueeze(0), f_negs]) / args.temperature
|
| 271 |
+
ce = F.cross_entropy(logits.unsqueeze(0), torch.zeros((), dtype=torch.long).unsqueeze(0))
|
| 272 |
+
loss_sum = loss_sum + ce
|
| 273 |
+
loss_sum = loss_sum / len(queries)
|
| 274 |
+
loss_sum.backward()
|
| 275 |
+
torch.nn.utils.clip_grad_norm_(
|
| 276 |
+
[p for p in model.parameters() if p.requires_grad], max_norm=args.clip,
|
| 277 |
+
)
|
| 278 |
+
optim.step()
|
| 279 |
+
if step == 1 or step % args.log_every == 0 or step == args.steps:
|
| 280 |
+
print(f"[fock] step {step}/{args.steps} loss={loss_sum.item():.4f} "
|
| 281 |
+
f"elapsed={time.time()-t0:.1f}s", flush=True)
|
| 282 |
+
|
| 283 |
+
eval_paths = []
|
| 284 |
+
if args.eval_train_rel:
|
| 285 |
+
eval_paths.append(("train", Path(args.eval_train_rel)))
|
| 286 |
+
if args.eval_test_rel:
|
| 287 |
+
eval_paths.append(("test", Path(args.eval_test_rel)))
|
| 288 |
+
if not eval_paths:
|
| 289 |
+
eval_paths.append(("all", rel_path))
|
| 290 |
+
summary = {}
|
| 291 |
+
for label, p in eval_paths:
|
| 292 |
+
rels = json.loads(p.read_text("utf-8"))["queries"]
|
| 293 |
+
report = evaluate(model, abstracts, ids, rels)
|
| 294 |
+
print(f"\n=== {label.upper()} EVAL ({len(rels)} queries) ===")
|
| 295 |
+
for r in report["per_query"]:
|
| 296 |
+
cells = " ".join(f"{m}={r[m]:.3f}" for m in r if m.startswith(("recall", "ndcg")))
|
| 297 |
+
print(f" {r['query'][:48]:<48s} {cells}")
|
| 298 |
+
print("aggregate: " + " ".join(f"{m}={report['aggregate'][m]:.3f}" for m in report["aggregate"]))
|
| 299 |
+
summary[f"{label}/fock"] = report["aggregate"]
|
| 300 |
+
print(f"\nSUMMARY_JSON={json.dumps(summary)}")
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def main():
|
| 304 |
+
ap = argparse.ArgumentParser()
|
| 305 |
+
ap.add_argument("--cutoff", type=int, default=6,
|
| 306 |
+
help="Fock-basis truncation per mode. Joint dim = cutoff².")
|
| 307 |
+
ap.add_argument("--herald-n", type=int, default=HERALD_N,
|
| 308 |
+
help="Ancilla photon-number outcome to project onto.")
|
| 309 |
+
ap.add_argument("--steps", type=int, default=200)
|
| 310 |
+
ap.add_argument("--lr", type=float, default=1e-2)
|
| 311 |
+
ap.add_argument("--weight-decay", type=float, default=1e-3)
|
| 312 |
+
ap.add_argument("--temperature", type=float, default=0.5)
|
| 313 |
+
ap.add_argument("--negatives", type=int, default=8)
|
| 314 |
+
ap.add_argument("--clip", type=float, default=1.0)
|
| 315 |
+
ap.add_argument("--seed", type=int, default=42)
|
| 316 |
+
ap.add_argument("--log-every", type=int, default=20)
|
| 317 |
+
ap.add_argument("--relevance", type=str, default=None)
|
| 318 |
+
ap.add_argument("--eval-train-rel", type=str, default=None)
|
| 319 |
+
ap.add_argument("--eval-test-rel", type=str, default=None)
|
| 320 |
+
args = ap.parse_args()
|
| 321 |
+
train(args)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
if __name__ == "__main__":
|
| 325 |
+
main()
|