File size: 8,228 Bytes
b77332f
 
102bb6a
 
 
 
 
 
 
 
 
 
 
 
 
 
b77332f
 
 
 
 
102bb6a
b77332f
 
102bb6a
b77332f
 
e0480d0
b77332f
 
 
 
102bb6a
3755f0a
102bb6a
 
 
 
 
 
b77332f
 
 
 
102bb6a
b77332f
102bb6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b77332f
3755f0a
 
102bb6a
 
 
 
 
 
 
 
 
 
b77332f
 
 
 
 
 
102bb6a
 
 
b77332f
 
 
 
 
 
 
 
 
 
 
102bb6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b77332f
 
 
 
 
 
 
102bb6a
 
b77332f
102bb6a
b77332f
 
 
 
 
 
 
102bb6a
 
 
 
 
 
3755f0a
 
 
b77332f
 
 
 
102bb6a
b77332f
 
e0480d0
 
 
 
 
 
 
 
 
 
 
 
 
 
b77332f
 
 
 
102bb6a
 
 
 
b77332f
102bb6a
b77332f
 
 
102bb6a
b77332f
 
 
 
 
 
102bb6a
 
 
b77332f
102bb6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b77332f
 
 
102bb6a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""FastAPI server for photon-route.

Three retrieval backends, all sharing the same 5-doc fixture corpus:

  v1        — original SHA-256 → SF gaussian backend (StrawberryFields)
  sha_init  — pure-numpy v2 encoder, identical SHA-256 init (Step 0
              equivalence: produces (mu, sigma) byte-identical to v1
              modulo dtype). Useful as a no-torch baseline.
  trained   — pure-numpy v2 encoder loaded from /app/weights.npz, the
              artifact produced by `space.train`. Only available when
              the Docker image was built with the training stage.

Importing the CV stack (strawberryfields + thewalrus) is required for
v1/sha_init/trained to function — fidelity scoring is closed-form
Gaussian-state fidelity (Banchi-Braunstein-Pirandola). If the import
fails, /rank degrades to the stub mode (deployment plumbing only).
"""

from __future__ import annotations

import os
from pathlib import Path
from typing import Any

import numpy as np
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse

from photon_route import __version__
from photon_route.corpus import Document, load_fixture

BACKEND_AVAILABLE: list[str] = []
_IMPORT_ERROR: str | None = None
_v1_corpus: list[Any] | None = None
_v2_sha_corpus: list[tuple[np.ndarray, np.ndarray, str, dict]] | None = None
_v2_trained_corpus: list[tuple[np.ndarray, np.ndarray, str, dict]] | None = None
_v2_sha_encoder = None
_v2_trained_encoder = None
WEIGHTS_PATH = Path(os.environ.get("PHOTON_ROUTE_WEIGHTS", "/app/weights.npz"))

try:
    import strawberryfields  # noqa: F401
    import thewalrus  # noqa: F401
    from thewalrus.quantum import fidelity as _tw_fidelity

    from photon_route.encode import encode_corpus as _v1_encode_corpus
    from photon_route.retrieve import rank_against as _v1_rank_against
    from photon_route.v2 import Encoder, sha_params_v1_compat
    from photon_route.v2.encoder import dict_params_fn

    BACKEND_AVAILABLE.append("v1")
    BACKEND_AVAILABLE.append("sha_init")

    _v2_sha_encoder = Encoder(params_fn=sha_params_v1_compat)

    if WEIGHTS_PATH.exists():
        _data = np.load(WEIGHTS_PATH, allow_pickle=True)
        _table = {
            str(w): np.asarray(p, dtype=np.float64)
            for w, p in zip(_data["words"], _data["params"])
        }
        _unk = np.asarray(_data["unk"], dtype=np.float64) if "unk" in _data.files else None
        _v2_trained_encoder = Encoder(params_fn=dict_params_fn(_table, unk=_unk))
        BACKEND_AVAILABLE.append("trained")
        print(f"[photon-route] loaded trained weights: |V|={len(_table)}", flush=True)
    else:
        print(f"[photon-route] no trained weights at {WEIGHTS_PATH}; trained backend disabled", flush=True)

except Exception as _e:
    _IMPORT_ERROR = f"{type(_e).__name__}: {_e}"
    print(f"[photon-route] CV stack failed → backend=stub: {_IMPORT_ERROR}", flush=True)


def _has_cv() -> bool:
    return "v1" in BACKEND_AVAILABLE


DEFAULT_BACKEND = "trained" if "trained" in BACKEND_AVAILABLE else (
    "sha_init" if "sha_init" in BACKEND_AVAILABLE else "stub"
)


app = FastAPI(
    title="photon-route",
    description=(
        "Continuous-variable photonic retrieval. Each document is encoded "
        "as a Gaussian state; ranking is closed-form Gaussian-state "
        "fidelity (Banchi-Braunstein-Pirandola). Three swappable encoders: "
        "v1 (SF + SHA), sha_init (numpy + SHA), trained (numpy + learned)."
    ),
    version=__version__,
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET"],
    allow_headers=["*"],
)


def _ensure_v1_corpus() -> list[Any]:
    global _v1_corpus
    if _v1_corpus is None:
        _v1_corpus = _v1_encode_corpus(load_fixture())
    return _v1_corpus


def _encode_corpus_numpy(encoder) -> list[tuple[np.ndarray, np.ndarray, str, dict]]:
    return [
        (*encoder.encode(d.text), d.text, d.meta)
        for d in load_fixture()
    ]


def _ensure_v2_sha_corpus():
    global _v2_sha_corpus
    if _v2_sha_corpus is None:
        _v2_sha_corpus = _encode_corpus_numpy(_v2_sha_encoder)
    return _v2_sha_corpus


def _ensure_v2_trained_corpus():
    global _v2_trained_corpus
    if _v2_trained_corpus is None and _v2_trained_encoder is not None:
        _v2_trained_corpus = _encode_corpus_numpy(_v2_trained_encoder)
    return _v2_trained_corpus


def _safe_fidelity(mu_a, sg_a, mu_b, sg_b) -> float:
    try:
        f = _tw_fidelity(mu_a, sg_a, mu_b, sg_b)
        val = float(f.real if hasattr(f, "real") else f)
        return max(0.0, min(1.0, val))
    except (ValueError, RuntimeError, np.linalg.LinAlgError):
        return 0.0


def _rank_v2(encoder, encoded_corpus, query: str, top_k: int):
    mu_q, sg_q = encoder.encode(query)
    scored = [
        (_safe_fidelity(mu_q, sg_q, mu_d, sg_d), text, meta)
        for (mu_d, sg_d, text, meta) in encoded_corpus
    ]
    scored.sort(key=lambda x: x[0], reverse=True)
    return scored[:top_k]


@app.get("/")
def root() -> dict[str, Any]:
    return {
        "name": "photon-route",
        "version": __version__,
        "backends_available": BACKEND_AVAILABLE or ["stub"],
        "default_backend": DEFAULT_BACKEND,
        "n_modes": int(os.environ.get("PHOTON_ROUTE_N_MODES", "2")),
        "weights_path": str(WEIGHTS_PATH) if WEIGHTS_PATH.exists() else None,
        "endpoints": ["/", "/health", "/version", "/rank", "/docs"],
        "docs": "/docs",
    }


@app.get("/health")
def health() -> dict[str, Any]:
    out: dict[str, Any] = {
        "ok": True,
        "backends_available": BACKEND_AVAILABLE or ["stub"],
        "default_backend": DEFAULT_BACKEND,
        "weights_loaded": "trained" in BACKEND_AVAILABLE,
    }
    if _IMPORT_ERROR:
        out["import_error"] = _IMPORT_ERROR
    return out


@app.get("/version")
def version() -> dict[str, str]:
    return {"version": __version__, "default_backend": DEFAULT_BACKEND}


@app.get("/weights.npz")
def weights_download():
    """Serve the trained weights.npz so the eval harness on the VM can
    score `--backend v2 --weights weights.npz` without rebuilding the
    whole training pipeline locally. Read-only, baked at build time."""
    if not WEIGHTS_PATH.exists():
        raise HTTPException(status_code=404, detail="no trained weights on this build")
    return FileResponse(
        WEIGHTS_PATH,
        media_type="application/octet-stream",
        filename="weights.npz",
    )


@app.get("/rank")
def rank(
    q: str = Query(..., min_length=1, max_length=500, description="query text"),
    top_k: int = Query(5, ge=1, le=20),
    backend: str = Query(
        DEFAULT_BACKEND,
        description="encoder: v1 | sha_init | trained (falls back to default if unavailable)",
    ),
) -> dict[str, Any]:
    if not _has_cv():
        docs: list[Document] = load_fixture()[:top_k]
        return {
            "query": q,
            "backend": "stub",
            "results": [
                {"rank": i + 1, "score": 0.0, "text": d.text, "meta": d.meta}
                for i, d in enumerate(docs)
            ],
        }

    if backend not in BACKEND_AVAILABLE:
        backend = DEFAULT_BACKEND

    try:
        if backend == "v1":
            scored = _v1_rank_against(_ensure_v1_corpus(), q, top_k=top_k)
            results = [
                {"rank": i + 1, "score": round(r.score, 6),
                 "text": r.doc.doc.text, "meta": r.doc.doc.meta}
                for i, r in enumerate(scored)
            ]
        else:
            enc = _v2_trained_encoder if backend == "trained" else _v2_sha_encoder
            corpus = (
                _ensure_v2_trained_corpus() if backend == "trained"
                else _ensure_v2_sha_corpus()
            )
            scored = _rank_v2(enc, corpus, q, top_k)
            results = [
                {"rank": i + 1, "score": round(s, 6), "text": text, "meta": meta}
                for i, (s, text, meta) in enumerate(scored)
            ]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"rank failed: {e!r}") from e

    return {"query": q, "backend": backend, "results": results}