luuow Claude Opus 4.7 (1M context) commited on
Commit
6038606
·
1 Parent(s): dda3229

space/train: fix NaN training + 4x speedup

Browse files

Three bugs caught from the first HF Space training run (loss NaN by
step 20, 5s/step):

1. Cross-entropy was being given F = exp(-D) values divided by
temperature, which is wrong: F is in (0, 1] post-exp and small
numerical drift in slogdet pushed F outside that range, causing
the softmax to NaN. Replaced with the standard contrastive form
`logits = -D / temp` (D = Bhattacharyya distance, lower is more
similar) and clamped D to [0, 50] for stability. `bhattacharyya`
renamed to `bhattacharyya_distance` to reflect the new return.

2. nn.Embedding defaulted to float32. Sequential Sgate composition
(one squeeze per word, ≥100 words/abstract) made the resulting
covariance ill-conditioned in float32, and slogdet emitted -inf/NaN
that propagated to the loss. Forced float64 throughout the embedding
and downstream tensors.

3. Each step encoded each abstract once per (query × neg/pos), i.e.
54 encodings/step for 6 queries × (1 pos + 8 negs). Cache the per-
abstract (mu, sigma) once per step (weights change every step so
the cache can't span steps). Drops to 26 encodings/step.

Also lowered default --steps 300 → 100 (with the speedup the
training step now takes ~2 min instead of ~25 min, and 100 steps is
plenty for a 20-doc, 6-query corpus).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Files changed (2) hide show
  1. Dockerfile +1 -1
  2. space/train.py +42 -21
Dockerfile CHANGED
@@ -25,7 +25,7 @@ RUN pip install --upgrade pip \
25
  # runs the InfoNCE + Bhattacharyya trainer, dumps weights.npz consumable
26
  # by the v2 numpy encoder. Falls back to SHA-init at serve time if the
27
  # train step fails so the container always boots.
28
- RUN python -m space.train --out /app/weights.npz --steps 300 \
29
  || (echo "[build] training failed; container will serve sha_init only" && rm -f /app/weights.npz)
30
 
31
  EXPOSE 7860
 
25
  # runs the InfoNCE + Bhattacharyya trainer, dumps weights.npz consumable
26
  # by the v2 numpy encoder. Falls back to SHA-init at serve time if the
27
  # train step fails so the container always boots.
28
+ RUN python -m space.train --out /app/weights.npz --steps 100 \
29
  || (echo "[build] training failed; container will serve sha_init only" && rm -f /app/weights.npz)
30
 
31
  EXPOSE 7860
space/train.py CHANGED
@@ -208,15 +208,25 @@ def encode_torch(
208
  # Bhattacharyya surrogate fidelity (Gaussian-Gaussian)
209
  # ---------------------------------------------------------------------------
210
 
211
- def bhattacharyya(
212
- mu_a: Tensor, sg_a: Tensor, mu_b: Tensor, sg_b: Tensor, ridge: float = 1e-4,
213
  ) -> Tensor:
214
- """F_B(rho_A, rho_B) = exp(-D_B), D_B = (1/8)Δμᵀ V⁻¹ Δμ + 0.5 log(det V / sqrt(det A * det B))
215
- with V = (A + B)/2. Returns scalar in (0, 1]. Always differentiable."""
 
 
 
 
 
 
 
 
 
216
  d = sg_a.shape[0]
217
- V = 0.5 * (sg_a + sg_b) + ridge * torch.eye(d, dtype=sg_a.dtype, device=sg_a.device)
218
- A = sg_a + ridge * torch.eye(d, dtype=sg_a.dtype, device=sg_a.device)
219
- B = sg_b + ridge * torch.eye(d, dtype=sg_a.dtype, device=sg_a.device)
 
220
  delta = mu_a - mu_b
221
  sol = torch.linalg.solve(V, delta)
222
  quad = (delta * sol).sum()
@@ -224,7 +234,7 @@ def bhattacharyya(
224
  log_det_A = torch.linalg.slogdet(A)[1]
225
  log_det_B = torch.linalg.slogdet(B)[1]
226
  D = 0.125 * quad + 0.5 * (log_det_V - 0.5 * (log_det_A + log_det_B))
227
- return torch.exp(-D)
228
 
229
 
230
  # ---------------------------------------------------------------------------
@@ -257,10 +267,13 @@ def train(args: argparse.Namespace) -> None:
257
  vocab = {w: i for i, w in enumerate(sorted(words))}
258
  print(f"[train] vocab |V| = {len(vocab)}", flush=True)
259
 
260
- embedding = nn.Embedding(len(vocab), 4)
 
 
 
261
  with torch.no_grad():
262
  for w, i in vocab.items():
263
- embedding.weight[i] = torch.from_numpy(sha_init_raw(w)).to(embedding.weight.dtype)
264
 
265
  optim = torch.optim.AdamW(embedding.parameters(), lr=args.lr, weight_decay=1e-4)
266
 
@@ -275,26 +288,34 @@ def train(args: argparse.Namespace) -> None:
275
  loss_sum = torch.zeros((), dtype=torch.float64)
276
  loss_components = {"info_nce": 0.0, "photon": 0.0}
277
 
 
 
 
 
 
278
  for query, rel_set in queries:
279
  mu_q, sg_q = encode_torch(query, vocab, embedding)
280
 
281
  # one positive (random pick from relevant set)
282
  pos_id = rng.choice(sorted(rel_set))
283
- mu_p, sg_p = encode_torch(abstracts[pos_id], vocab, embedding)
284
 
285
  # negatives: K random non-relevant ids
286
  negs = rng.choice(
287
- [i for i in all_ids if i not in rel_set], size=min(args.negatives, len(all_ids) - len(rel_set)), replace=False,
 
 
288
  )
289
- f_pos = bhattacharyya(mu_q, sg_q, mu_p, sg_p)
290
- f_negs = torch.stack([
291
- bhattacharyya(mu_q, sg_q, *encode_torch(abstracts[n], vocab, embedding))
292
- for n in negs
293
  ])
294
- # InfoNCE: maximize f_pos / (f_pos + sum f_negs) -> minimize -log(...)
295
- sims = torch.cat([f_pos.unsqueeze(0), f_negs]) / args.temperature
 
 
296
  target = torch.zeros((), dtype=torch.long)
297
- ce = F.cross_entropy(sims.unsqueeze(0), target.unsqueeze(0))
298
  loss_sum = loss_sum + ce
299
 
300
  loss_components["info_nce"] += ce.item()
@@ -359,14 +380,14 @@ def train(args: argparse.Namespace) -> None:
359
  def main() -> None:
360
  ap = argparse.ArgumentParser()
361
  ap.add_argument("--out", type=Path, default=ROOT / "weights.npz")
362
- ap.add_argument("--steps", type=int, default=300)
363
  ap.add_argument("--lr", type=float, default=3e-2)
364
  ap.add_argument("--temperature", type=float, default=0.1)
365
  ap.add_argument("--photon-lambda", type=float, default=1e-2)
366
  ap.add_argument("--negatives", type=int, default=8)
367
  ap.add_argument("--clip", type=float, default=1.0)
368
  ap.add_argument("--seed", type=int, default=42)
369
- ap.add_argument("--log-every", type=int, default=20)
370
  args = ap.parse_args()
371
  train(args)
372
 
 
208
  # Bhattacharyya surrogate fidelity (Gaussian-Gaussian)
209
  # ---------------------------------------------------------------------------
210
 
211
+ def bhattacharyya_distance(
212
+ mu_a: Tensor, sg_a: Tensor, mu_b: Tensor, sg_b: Tensor, ridge: float = 1e-3,
213
  ) -> Tensor:
214
+ """Bhattacharyya distance D_B between two Gaussians (means + covs).
215
+
216
+ D_B = (1/8) Δμᵀ V⁻¹ Δμ + (1/2) log(det V / sqrt(det A · det B)),
217
+ with V = (A + B)/2, A = Σ_a + ridge·I, B = Σ_b + ridge·I.
218
+ Lower = more similar; ≥ 0 for proper SPD inputs.
219
+ Returned clamped to [0, 50] for downstream softmax/exp stability.
220
+
221
+ Used as a contrastive *logit* (-D / temperature) — cheaper and far
222
+ more numerically stable than F_B = exp(-D), which underflows for
223
+ well-separated Gaussians and amplifies slogdet noise.
224
+ """
225
  d = sg_a.shape[0]
226
+ eye = torch.eye(d, dtype=sg_a.dtype, device=sg_a.device)
227
+ A = sg_a + ridge * eye
228
+ B = sg_b + ridge * eye
229
+ V = 0.5 * (A + B)
230
  delta = mu_a - mu_b
231
  sol = torch.linalg.solve(V, delta)
232
  quad = (delta * sol).sum()
 
234
  log_det_A = torch.linalg.slogdet(A)[1]
235
  log_det_B = torch.linalg.slogdet(B)[1]
236
  D = 0.125 * quad + 0.5 * (log_det_V - 0.5 * (log_det_A + log_det_B))
237
+ return torch.clamp(D, min=0.0, max=50.0)
238
 
239
 
240
  # ---------------------------------------------------------------------------
 
267
  vocab = {w: i for i, w in enumerate(sorted(words))}
268
  print(f"[train] vocab |V| = {len(vocab)}", flush=True)
269
 
270
+ # float64 throughout — slogdet of a near-singular squeezed-state covariance
271
+ # in float32 emits NaN that propagates through cross_entropy. Float64 absorbs
272
+ # the conditioning loss from many sequential Sgate compositions.
273
+ embedding = nn.Embedding(len(vocab), 4, dtype=torch.float64)
274
  with torch.no_grad():
275
  for w, i in vocab.items():
276
+ embedding.weight[i] = torch.from_numpy(sha_init_raw(w))
277
 
278
  optim = torch.optim.AdamW(embedding.parameters(), lr=args.lr, weight_decay=1e-4)
279
 
 
288
  loss_sum = torch.zeros((), dtype=torch.float64)
289
  loss_components = {"info_nce": 0.0, "photon": 0.0}
290
 
291
+ # Encode each abstract once per step (was 9× per query before — 54
292
+ # encodings/step → 26). Weights change every step so the cache is
293
+ # per-step only, not amortized across steps.
294
+ doc_states = {a: encode_torch(t, vocab, embedding) for a, t in abstracts.items()}
295
+
296
  for query, rel_set in queries:
297
  mu_q, sg_q = encode_torch(query, vocab, embedding)
298
 
299
  # one positive (random pick from relevant set)
300
  pos_id = rng.choice(sorted(rel_set))
301
+ mu_p, sg_p = doc_states[pos_id]
302
 
303
  # negatives: K random non-relevant ids
304
  negs = rng.choice(
305
+ [i for i in all_ids if i not in rel_set],
306
+ size=min(args.negatives, len(all_ids) - len(rel_set)),
307
+ replace=False,
308
  )
309
+ d_pos = bhattacharyya_distance(mu_q, sg_q, mu_p, sg_p)
310
+ d_negs = torch.stack([
311
+ bhattacharyya_distance(mu_q, sg_q, *doc_states[n]) for n in negs
 
312
  ])
313
+ # Use distance directly as a (negative) logit. Smaller D larger
314
+ # logit higher probability for that class. Standard contrastive
315
+ # form: cross_entropy(-D / temp, target=positive).
316
+ logits = -torch.cat([d_pos.unsqueeze(0), d_negs]) / args.temperature
317
  target = torch.zeros((), dtype=torch.long)
318
+ ce = F.cross_entropy(logits.unsqueeze(0), target.unsqueeze(0))
319
  loss_sum = loss_sum + ce
320
 
321
  loss_components["info_nce"] += ce.item()
 
380
  def main() -> None:
381
  ap = argparse.ArgumentParser()
382
  ap.add_argument("--out", type=Path, default=ROOT / "weights.npz")
383
+ ap.add_argument("--steps", type=int, default=100)
384
  ap.add_argument("--lr", type=float, default=3e-2)
385
  ap.add_argument("--temperature", type=float, default=0.1)
386
  ap.add_argument("--photon-lambda", type=float, default=1e-2)
387
  ap.add_argument("--negatives", type=int, default=8)
388
  ap.add_argument("--clip", type=float, default=1.0)
389
  ap.add_argument("--seed", type=int, default=42)
390
+ ap.add_argument("--log-every", type=int, default=10)
391
  args = ap.parse_args()
392
  train(args)
393