Spaces:
Running
space/train: fix NaN training + 4x speedup
Browse filesThree bugs caught from the first HF Space training run (loss NaN by
step 20, 5s/step):
1. Cross-entropy was being given F = exp(-D) values divided by
temperature, which is wrong: F is in (0, 1] post-exp and small
numerical drift in slogdet pushed F outside that range, causing
the softmax to NaN. Replaced with the standard contrastive form
`logits = -D / temp` (D = Bhattacharyya distance, lower is more
similar) and clamped D to [0, 50] for stability. `bhattacharyya`
renamed to `bhattacharyya_distance` to reflect the new return.
2. nn.Embedding defaulted to float32. Sequential Sgate composition
(one squeeze per word, ≥100 words/abstract) made the resulting
covariance ill-conditioned in float32, and slogdet emitted -inf/NaN
that propagated to the loss. Forced float64 throughout the embedding
and downstream tensors.
3. Each step encoded each abstract once per (query × neg/pos), i.e.
54 encodings/step for 6 queries × (1 pos + 8 negs). Cache the per-
abstract (mu, sigma) once per step (weights change every step so
the cache can't span steps). Drops to 26 encodings/step.
Also lowered default --steps 300 → 100 (with the speedup the
training step now takes ~2 min instead of ~25 min, and 100 steps is
plenty for a 20-doc, 6-query corpus).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Dockerfile +1 -1
- space/train.py +42 -21
|
@@ -25,7 +25,7 @@ RUN pip install --upgrade pip \
|
|
| 25 |
# runs the InfoNCE + Bhattacharyya trainer, dumps weights.npz consumable
|
| 26 |
# by the v2 numpy encoder. Falls back to SHA-init at serve time if the
|
| 27 |
# train step fails so the container always boots.
|
| 28 |
-
RUN python -m space.train --out /app/weights.npz --steps
|
| 29 |
|| (echo "[build] training failed; container will serve sha_init only" && rm -f /app/weights.npz)
|
| 30 |
|
| 31 |
EXPOSE 7860
|
|
|
|
| 25 |
# runs the InfoNCE + Bhattacharyya trainer, dumps weights.npz consumable
|
| 26 |
# by the v2 numpy encoder. Falls back to SHA-init at serve time if the
|
| 27 |
# train step fails so the container always boots.
|
| 28 |
+
RUN python -m space.train --out /app/weights.npz --steps 100 \
|
| 29 |
|| (echo "[build] training failed; container will serve sha_init only" && rm -f /app/weights.npz)
|
| 30 |
|
| 31 |
EXPOSE 7860
|
|
@@ -208,15 +208,25 @@ def encode_torch(
|
|
| 208 |
# Bhattacharyya surrogate fidelity (Gaussian-Gaussian)
|
| 209 |
# ---------------------------------------------------------------------------
|
| 210 |
|
| 211 |
-
def
|
| 212 |
-
mu_a: Tensor, sg_a: Tensor, mu_b: Tensor, sg_b: Tensor, ridge: float = 1e-
|
| 213 |
) -> Tensor:
|
| 214 |
-
"""
|
| 215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
d = sg_a.shape[0]
|
| 217 |
-
|
| 218 |
-
A = sg_a + ridge *
|
| 219 |
-
B = sg_b + ridge *
|
|
|
|
| 220 |
delta = mu_a - mu_b
|
| 221 |
sol = torch.linalg.solve(V, delta)
|
| 222 |
quad = (delta * sol).sum()
|
|
@@ -224,7 +234,7 @@ def bhattacharyya(
|
|
| 224 |
log_det_A = torch.linalg.slogdet(A)[1]
|
| 225 |
log_det_B = torch.linalg.slogdet(B)[1]
|
| 226 |
D = 0.125 * quad + 0.5 * (log_det_V - 0.5 * (log_det_A + log_det_B))
|
| 227 |
-
return torch.
|
| 228 |
|
| 229 |
|
| 230 |
# ---------------------------------------------------------------------------
|
|
@@ -257,10 +267,13 @@ def train(args: argparse.Namespace) -> None:
|
|
| 257 |
vocab = {w: i for i, w in enumerate(sorted(words))}
|
| 258 |
print(f"[train] vocab |V| = {len(vocab)}", flush=True)
|
| 259 |
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
| 261 |
with torch.no_grad():
|
| 262 |
for w, i in vocab.items():
|
| 263 |
-
embedding.weight[i] = torch.from_numpy(sha_init_raw(w))
|
| 264 |
|
| 265 |
optim = torch.optim.AdamW(embedding.parameters(), lr=args.lr, weight_decay=1e-4)
|
| 266 |
|
|
@@ -275,26 +288,34 @@ def train(args: argparse.Namespace) -> None:
|
|
| 275 |
loss_sum = torch.zeros((), dtype=torch.float64)
|
| 276 |
loss_components = {"info_nce": 0.0, "photon": 0.0}
|
| 277 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
for query, rel_set in queries:
|
| 279 |
mu_q, sg_q = encode_torch(query, vocab, embedding)
|
| 280 |
|
| 281 |
# one positive (random pick from relevant set)
|
| 282 |
pos_id = rng.choice(sorted(rel_set))
|
| 283 |
-
mu_p, sg_p =
|
| 284 |
|
| 285 |
# negatives: K random non-relevant ids
|
| 286 |
negs = rng.choice(
|
| 287 |
-
[i for i in all_ids if i not in rel_set],
|
|
|
|
|
|
|
| 288 |
)
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
for n in negs
|
| 293 |
])
|
| 294 |
-
#
|
| 295 |
-
|
|
|
|
|
|
|
| 296 |
target = torch.zeros((), dtype=torch.long)
|
| 297 |
-
ce = F.cross_entropy(
|
| 298 |
loss_sum = loss_sum + ce
|
| 299 |
|
| 300 |
loss_components["info_nce"] += ce.item()
|
|
@@ -359,14 +380,14 @@ def train(args: argparse.Namespace) -> None:
|
|
| 359 |
def main() -> None:
|
| 360 |
ap = argparse.ArgumentParser()
|
| 361 |
ap.add_argument("--out", type=Path, default=ROOT / "weights.npz")
|
| 362 |
-
ap.add_argument("--steps", type=int, default=
|
| 363 |
ap.add_argument("--lr", type=float, default=3e-2)
|
| 364 |
ap.add_argument("--temperature", type=float, default=0.1)
|
| 365 |
ap.add_argument("--photon-lambda", type=float, default=1e-2)
|
| 366 |
ap.add_argument("--negatives", type=int, default=8)
|
| 367 |
ap.add_argument("--clip", type=float, default=1.0)
|
| 368 |
ap.add_argument("--seed", type=int, default=42)
|
| 369 |
-
ap.add_argument("--log-every", type=int, default=
|
| 370 |
args = ap.parse_args()
|
| 371 |
train(args)
|
| 372 |
|
|
|
|
| 208 |
# Bhattacharyya surrogate fidelity (Gaussian-Gaussian)
|
| 209 |
# ---------------------------------------------------------------------------
|
| 210 |
|
| 211 |
+
def bhattacharyya_distance(
|
| 212 |
+
mu_a: Tensor, sg_a: Tensor, mu_b: Tensor, sg_b: Tensor, ridge: float = 1e-3,
|
| 213 |
) -> Tensor:
|
| 214 |
+
"""Bhattacharyya distance D_B between two Gaussians (means + covs).
|
| 215 |
+
|
| 216 |
+
D_B = (1/8) Δμᵀ V⁻¹ Δμ + (1/2) log(det V / sqrt(det A · det B)),
|
| 217 |
+
with V = (A + B)/2, A = Σ_a + ridge·I, B = Σ_b + ridge·I.
|
| 218 |
+
Lower = more similar; ≥ 0 for proper SPD inputs.
|
| 219 |
+
Returned clamped to [0, 50] for downstream softmax/exp stability.
|
| 220 |
+
|
| 221 |
+
Used as a contrastive *logit* (-D / temperature) — cheaper and far
|
| 222 |
+
more numerically stable than F_B = exp(-D), which underflows for
|
| 223 |
+
well-separated Gaussians and amplifies slogdet noise.
|
| 224 |
+
"""
|
| 225 |
d = sg_a.shape[0]
|
| 226 |
+
eye = torch.eye(d, dtype=sg_a.dtype, device=sg_a.device)
|
| 227 |
+
A = sg_a + ridge * eye
|
| 228 |
+
B = sg_b + ridge * eye
|
| 229 |
+
V = 0.5 * (A + B)
|
| 230 |
delta = mu_a - mu_b
|
| 231 |
sol = torch.linalg.solve(V, delta)
|
| 232 |
quad = (delta * sol).sum()
|
|
|
|
| 234 |
log_det_A = torch.linalg.slogdet(A)[1]
|
| 235 |
log_det_B = torch.linalg.slogdet(B)[1]
|
| 236 |
D = 0.125 * quad + 0.5 * (log_det_V - 0.5 * (log_det_A + log_det_B))
|
| 237 |
+
return torch.clamp(D, min=0.0, max=50.0)
|
| 238 |
|
| 239 |
|
| 240 |
# ---------------------------------------------------------------------------
|
|
|
|
| 267 |
vocab = {w: i for i, w in enumerate(sorted(words))}
|
| 268 |
print(f"[train] vocab |V| = {len(vocab)}", flush=True)
|
| 269 |
|
| 270 |
+
# float64 throughout — slogdet of a near-singular squeezed-state covariance
|
| 271 |
+
# in float32 emits NaN that propagates through cross_entropy. Float64 absorbs
|
| 272 |
+
# the conditioning loss from many sequential Sgate compositions.
|
| 273 |
+
embedding = nn.Embedding(len(vocab), 4, dtype=torch.float64)
|
| 274 |
with torch.no_grad():
|
| 275 |
for w, i in vocab.items():
|
| 276 |
+
embedding.weight[i] = torch.from_numpy(sha_init_raw(w))
|
| 277 |
|
| 278 |
optim = torch.optim.AdamW(embedding.parameters(), lr=args.lr, weight_decay=1e-4)
|
| 279 |
|
|
|
|
| 288 |
loss_sum = torch.zeros((), dtype=torch.float64)
|
| 289 |
loss_components = {"info_nce": 0.0, "photon": 0.0}
|
| 290 |
|
| 291 |
+
# Encode each abstract once per step (was 9× per query before — 54
|
| 292 |
+
# encodings/step → 26). Weights change every step so the cache is
|
| 293 |
+
# per-step only, not amortized across steps.
|
| 294 |
+
doc_states = {a: encode_torch(t, vocab, embedding) for a, t in abstracts.items()}
|
| 295 |
+
|
| 296 |
for query, rel_set in queries:
|
| 297 |
mu_q, sg_q = encode_torch(query, vocab, embedding)
|
| 298 |
|
| 299 |
# one positive (random pick from relevant set)
|
| 300 |
pos_id = rng.choice(sorted(rel_set))
|
| 301 |
+
mu_p, sg_p = doc_states[pos_id]
|
| 302 |
|
| 303 |
# negatives: K random non-relevant ids
|
| 304 |
negs = rng.choice(
|
| 305 |
+
[i for i in all_ids if i not in rel_set],
|
| 306 |
+
size=min(args.negatives, len(all_ids) - len(rel_set)),
|
| 307 |
+
replace=False,
|
| 308 |
)
|
| 309 |
+
d_pos = bhattacharyya_distance(mu_q, sg_q, mu_p, sg_p)
|
| 310 |
+
d_negs = torch.stack([
|
| 311 |
+
bhattacharyya_distance(mu_q, sg_q, *doc_states[n]) for n in negs
|
|
|
|
| 312 |
])
|
| 313 |
+
# Use distance directly as a (negative) logit. Smaller D → larger
|
| 314 |
+
# logit → higher probability for that class. Standard contrastive
|
| 315 |
+
# form: cross_entropy(-D / temp, target=positive).
|
| 316 |
+
logits = -torch.cat([d_pos.unsqueeze(0), d_negs]) / args.temperature
|
| 317 |
target = torch.zeros((), dtype=torch.long)
|
| 318 |
+
ce = F.cross_entropy(logits.unsqueeze(0), target.unsqueeze(0))
|
| 319 |
loss_sum = loss_sum + ce
|
| 320 |
|
| 321 |
loss_components["info_nce"] += ce.item()
|
|
|
|
| 380 |
def main() -> None:
|
| 381 |
ap = argparse.ArgumentParser()
|
| 382 |
ap.add_argument("--out", type=Path, default=ROOT / "weights.npz")
|
| 383 |
+
ap.add_argument("--steps", type=int, default=100)
|
| 384 |
ap.add_argument("--lr", type=float, default=3e-2)
|
| 385 |
ap.add_argument("--temperature", type=float, default=0.1)
|
| 386 |
ap.add_argument("--photon-lambda", type=float, default=1e-2)
|
| 387 |
ap.add_argument("--negatives", type=int, default=8)
|
| 388 |
ap.add_argument("--clip", type=float, default=1.0)
|
| 389 |
ap.add_argument("--seed", type=int, default=42)
|
| 390 |
+
ap.add_argument("--log-every", type=int, default=10)
|
| 391 |
args = ap.parse_args()
|
| 392 |
train(args)
|
| 393 |
|