Namhyun Kim commited on
Commit
5b413ef
·
1 Parent(s): 1a85ed1

Fix t-SNE blank plots (infer embedding dims)

Browse files
Files changed (1) hide show
  1. app.py +51 -7
app.py CHANGED
@@ -361,17 +361,47 @@ def apply_filters(
361
 
362
 
363
  def _select_tech_embedding(flat_embedding: np.ndarray | None, tech: str, embed_dim: Optional[int]) -> Optional[np.ndarray]:
364
- if flat_embedding is None or embed_dim is None:
 
 
 
 
 
 
365
  return None
 
 
366
  total = flat_embedding.size
367
  blocks = len(TECH_EXPERT_ORDER)
368
- if total % blocks != 0:
 
 
 
 
 
 
 
 
 
 
 
369
  return None
 
 
 
 
 
 
 
 
 
 
370
  try:
371
- arr = flat_embedding.reshape(blocks, embed_dim)
372
  except ValueError:
373
  return None
374
- tech_idx = TECH_TO_EXPERT_IDX.get(tech)
 
375
  if tech_idx is None or tech_idx >= arr.shape[0]:
376
  return arr.mean(axis=0)
377
  return arr[tech_idx]
@@ -416,7 +446,13 @@ def plot_tsne(
416
  filtered_df = apply_filters(df, tech_filter, snr_filter, mod_filter, mob_filter)
417
  sampled_df = _sample_balanced_by_snr(filtered_df, samples_per_snr, sampling_seed)
418
  if len(sampled_df) < 5:
419
- return None
 
 
 
 
 
 
420
 
421
  sampled_df = sampled_df.copy()
422
  color_column = COLOR_OPTIONS.get(color_label, "snr")
@@ -424,14 +460,22 @@ def plot_tsne(
424
  if representation == "LWM Embedding":
425
  embed_mask = sampled_df["tech_embedding"].apply(lambda x: x is not None)
426
  if embed_mask.sum() < 5:
427
- return None
 
 
 
 
 
 
428
  sampled_df = sampled_df.loc[embed_mask].reset_index(drop=True)
429
  features = np.stack(sampled_df["tech_embedding"].values)
430
  else:
431
  features = build_tsne_raw_vectors(sampled_df["spectrogram"])
432
 
433
  if features.size == 0:
434
- return None
 
 
435
 
436
  features = _standardize_for_tsne(features)
437
 
 
361
 
362
 
363
  def _select_tech_embedding(flat_embedding: np.ndarray | None, tech: str, embed_dim: Optional[int]) -> Optional[np.ndarray]:
364
+ """Extract the technology-specific expert embedding.
365
+
366
+ Some artifacts don't include an explicit embedding dimension hint. In that case,
367
+ infer `embed_dim = total_dim / num_experts` when divisible.
368
+ """
369
+
370
+ if flat_embedding is None:
371
  return None
372
+
373
+ flat_embedding = np.asarray(flat_embedding).reshape(-1)
374
  total = flat_embedding.size
375
  blocks = len(TECH_EXPERT_ORDER)
376
+ if blocks <= 0:
377
+ return None
378
+
379
+ inferred_dim = embed_dim
380
+ if inferred_dim is None:
381
+ if total % blocks != 0:
382
+ return None
383
+ inferred_dim = total // blocks
384
+
385
+ try:
386
+ inferred_dim = int(inferred_dim)
387
+ except (TypeError, ValueError):
388
  return None
389
+ if inferred_dim <= 0:
390
+ return None
391
+
392
+ expected = blocks * inferred_dim
393
+ if expected != total:
394
+ # If metadata is wrong, don't crash; fall back to an even split only if possible.
395
+ if total % blocks != 0:
396
+ return None
397
+ inferred_dim = total // blocks
398
+
399
  try:
400
+ arr = flat_embedding.reshape(blocks, inferred_dim)
401
  except ValueError:
402
  return None
403
+
404
+ tech_idx = TECH_TO_EXPERT_IDX.get(str(tech))
405
  if tech_idx is None or tech_idx >= arr.shape[0]:
406
  return arr.mean(axis=0)
407
  return arr[tech_idx]
 
446
  filtered_df = apply_filters(df, tech_filter, snr_filter, mod_filter, mob_filter)
447
  sampled_df = _sample_balanced_by_snr(filtered_df, samples_per_snr, sampling_seed)
448
  if len(sampled_df) < 5:
449
+ fig = go.Figure()
450
+ fig.update_layout(
451
+ title=f"Not enough samples to plot (n={len(sampled_df)}). Widen filters or increase samples.",
452
+ xaxis=dict(visible=False),
453
+ yaxis=dict(visible=False),
454
+ )
455
+ return fig
456
 
457
  sampled_df = sampled_df.copy()
458
  color_column = COLOR_OPTIONS.get(color_label, "snr")
 
460
  if representation == "LWM Embedding":
461
  embed_mask = sampled_df["tech_embedding"].apply(lambda x: x is not None)
462
  if embed_mask.sum() < 5:
463
+ fig = go.Figure()
464
+ fig.update_layout(
465
+ title="No per-technology embeddings found for the selected filters.",
466
+ xaxis=dict(visible=False),
467
+ yaxis=dict(visible=False),
468
+ )
469
+ return fig
470
  sampled_df = sampled_df.loc[embed_mask].reset_index(drop=True)
471
  features = np.stack(sampled_df["tech_embedding"].values)
472
  else:
473
  features = build_tsne_raw_vectors(sampled_df["spectrogram"])
474
 
475
  if features.size == 0:
476
+ fig = go.Figure()
477
+ fig.update_layout(title="No features available for t-SNE.", xaxis=dict(visible=False), yaxis=dict(visible=False))
478
+ return fig
479
 
480
  features = _standardize_for_tsne(features)
481