Spaces:
Running
http_server: /weights.npz download + train: temp 0.1->5.0, lr 3e-2->5e-3
Browse filesThe first successful build (commit 6038606) produced trained weights but
the InfoNCE loss oscillated wildly (174 -> 84 -> 140 -> 0.6 -> 68 -> 31
across 100 steps), worse than random init. Root cause: temperature=0.1
with the D-distance contrastive form — D values are 0-50, so logits
-D/0.1 reach ±500, the softmax saturates, gradients explode through
AdamW (lr=3e-2). New scale: temp=5.0 keeps logits in [-10, 0],
lr=5e-3 takes smaller steps from the SHA-init starting point.
Also adds GET /weights.npz to FastAPI so the eval harness on the VM
can pull the baked weights from the live Space and run recall/nDCG
locally (the Space's /rank ranks the 5-doc fixture corpus, not the
20-doc eval corpus the model was trained on).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- space/train.py +5 -2
- src/photon_route/http_server.py +15 -0
|
@@ -381,8 +381,11 @@ def main() -> None:
|
|
| 381 |
ap = argparse.ArgumentParser()
|
| 382 |
ap.add_argument("--out", type=Path, default=ROOT / "weights.npz")
|
| 383 |
ap.add_argument("--steps", type=int, default=100)
|
| 384 |
-
ap.add_argument("--lr", type=float, default=
|
| 385 |
-
|
|
|
|
|
|
|
|
|
|
| 386 |
ap.add_argument("--photon-lambda", type=float, default=1e-2)
|
| 387 |
ap.add_argument("--negatives", type=int, default=8)
|
| 388 |
ap.add_argument("--clip", type=float, default=1.0)
|
|
|
|
| 381 |
ap = argparse.ArgumentParser()
|
| 382 |
ap.add_argument("--out", type=Path, default=ROOT / "weights.npz")
|
| 383 |
ap.add_argument("--steps", type=int, default=100)
|
| 384 |
+
ap.add_argument("--lr", type=float, default=5e-3)
|
| 385 |
+
# D-scale logits: with D in [0, 50], temp=0.1 made -D/temp logits up to
|
| 386 |
+
# -500, exploding cross_entropy + gradients. temp=5 keeps logit magnitudes
|
| 387 |
+
# in a sensible range (~0-10) so AdamW can converge instead of oscillate.
|
| 388 |
+
ap.add_argument("--temperature", type=float, default=5.0)
|
| 389 |
ap.add_argument("--photon-lambda", type=float, default=1e-2)
|
| 390 |
ap.add_argument("--negatives", type=int, default=8)
|
| 391 |
ap.add_argument("--clip", type=float, default=1.0)
|
|
@@ -25,6 +25,7 @@ from typing import Any
|
|
| 25 |
import numpy as np
|
| 26 |
from fastapi import FastAPI, HTTPException, Query
|
| 27 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
| 28 |
|
| 29 |
from photon_route import __version__
|
| 30 |
from photon_route.corpus import Document, load_fixture
|
|
@@ -177,6 +178,20 @@ def version() -> dict[str, str]:
|
|
| 177 |
return {"version": __version__, "default_backend": DEFAULT_BACKEND}
|
| 178 |
|
| 179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
@app.get("/rank")
|
| 181 |
def rank(
|
| 182 |
q: str = Query(..., min_length=1, max_length=500, description="query text"),
|
|
|
|
| 25 |
import numpy as np
|
| 26 |
from fastapi import FastAPI, HTTPException, Query
|
| 27 |
from fastapi.middleware.cors import CORSMiddleware
|
| 28 |
+
from fastapi.responses import FileResponse
|
| 29 |
|
| 30 |
from photon_route import __version__
|
| 31 |
from photon_route.corpus import Document, load_fixture
|
|
|
|
| 178 |
return {"version": __version__, "default_backend": DEFAULT_BACKEND}
|
| 179 |
|
| 180 |
|
| 181 |
+
@app.get("/weights.npz")
|
| 182 |
+
def weights_download():
|
| 183 |
+
"""Serve the trained weights.npz so the eval harness on the VM can
|
| 184 |
+
score `--backend v2 --weights weights.npz` without rebuilding the
|
| 185 |
+
whole training pipeline locally. Read-only, baked at build time."""
|
| 186 |
+
if not WEIGHTS_PATH.exists():
|
| 187 |
+
raise HTTPException(status_code=404, detail="no trained weights on this build")
|
| 188 |
+
return FileResponse(
|
| 189 |
+
WEIGHTS_PATH,
|
| 190 |
+
media_type="application/octet-stream",
|
| 191 |
+
filename="weights.npz",
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
@app.get("/rank")
|
| 196 |
def rank(
|
| 197 |
q: str = Query(..., min_length=1, max_length=500, description="query text"),
|