iljung1106 commited on
Commit
178daad
Β·
1 Parent(s): 0ff521b

Make temporary prototype do mix and k-means embedding.

Browse files
Files changed (1) hide show
  1. webui_gradio.py +120 -21
webui_gradio.py CHANGED
@@ -343,11 +343,59 @@ def _gallery_item_to_pil(item) -> Optional[Image.Image]:
343
  return None
344
 
345
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  def add_prototype(
347
  label_name: str,
348
  images: List,
 
 
349
  ) -> str:
350
- """Add a temporary prototype (in-memory only, not persisted to disk)."""
 
 
 
 
 
351
  if APP_STATE.lm is None or APP_STATE.db is None:
352
  return "❌ Click **Load** first."
353
  lm = APP_STATE.lm
@@ -360,8 +408,15 @@ def add_prototype(
360
  if not images:
361
  return "❌ Upload at least 1 image."
362
 
363
- zs: List[torch.Tensor] = []
 
 
 
 
 
 
364
  errors: List[str] = []
 
365
  for i, x in enumerate(images):
366
  try:
367
  im = _gallery_item_to_pil(x)
@@ -369,33 +424,71 @@ def add_prototype(
369
  errors.append(f"Image {i}: could not parse format {type(x)}")
370
  continue
371
 
372
- face_pil = None
373
- eyes_pil = None
 
374
  if ex is not None:
375
  rgb = np.array(im.convert("RGB"))
376
  face_rgb, eyes_rgb = ex.extract(rgb)
377
  if face_rgb is not None:
378
- face_pil = Image.fromarray(face_rgb)
379
  if eyes_rgb is not None:
380
- eyes_pil = Image.fromarray(eyes_rgb)
 
 
 
381
 
382
- wt = _pil_to_tensor(im, lm.T_w)
383
- ft = _pil_to_tensor(face_pil, lm.T_f) if face_pil is not None else None
384
- et = _pil_to_tensor(eyes_pil, lm.T_e) if eyes_pil is not None else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
  z = embed_triview(lm, whole=wt, face=ft, eyes=et)
386
  zs.append(z)
387
- except Exception as e:
388
- errors.append(f"Image {i}: {e}")
389
  continue
390
 
391
  if not zs:
392
- err_detail = "; ".join(errors[:3]) if errors else "unknown error"
393
- return f"❌ Could not embed any uploaded images. Details: {err_detail}"
 
 
394
 
395
- center = torch.stack(zs, dim=0).mean(dim=0)
396
- lid = db.add_center(label_name, center)
 
 
 
397
 
398
- return f"βœ… Added temporary prototype for `{label_name}` (label_id={lid}). DB now N={db.centers.shape[0]}. ⚠️ This is session-only and will be lost when the Space restarts."
 
 
 
 
 
 
 
 
 
 
 
 
 
399
 
400
 
401
  def build_ui() -> gr.Blocks:
@@ -430,14 +523,20 @@ def build_ui() -> gr.Blocks:
430
  with gr.Tab("Add prototype (temporary)"):
431
  gr.Markdown(
432
  "### ⚠️ Temporary Prototypes Only\n"
433
- "Add a new prototype by averaging embeddings of uploaded whole images.\n"
434
- "**These prototypes are session-only** β€” they will be lost when the Space restarts or goes idle.\n"
435
- "Multiple prototypes per label are allowed."
 
 
 
436
  )
437
  label = gr.Textbox(label="Label name (artist)", placeholder="e.g. new_artist")
438
  imgs = gr.Gallery(label="Whole images (1+)", columns=4, rows=2, height=240, allow_preview=True)
439
  uploader = gr.Files(label="Upload image files (whole)", file_types=["image"], file_count="multiple")
440
- add_btn = gr.Button("Add temporary prototype", variant="primary")
 
 
 
441
  add_status = gr.Markdown("")
442
 
443
  def _files_to_gallery(files):
@@ -453,7 +552,7 @@ def build_ui() -> gr.Blocks:
453
  return out
454
 
455
  uploader.change(_files_to_gallery, inputs=[uploader], outputs=[imgs])
456
- add_btn.click(add_prototype, inputs=[label, imgs], outputs=[add_status])
457
 
458
  return demo
459
 
 
343
  return None
344
 
345
 
346
+ def _kmeans_cosine(Z: torch.Tensor, K: int, iters: int = 20, seed: int = 42) -> torch.Tensor:
347
+ """
348
+ K-means clustering in cosine space (CPU only).
349
+ Returns K cluster centers (normalized).
350
+ """
351
+ Z = torch.nn.functional.normalize(Z, dim=1)
352
+ N, D = Z.shape
353
+ if N <= K:
354
+ return Z.clone()
355
+
356
+ # Initialize centers randomly
357
+ import random
358
+ random.seed(seed)
359
+ init_idx = random.sample(range(N), K)
360
+ C = Z[init_idx].clone()
361
+
362
+ for _ in range(iters):
363
+ # Assign each point to nearest center
364
+ sim = Z @ C.t()
365
+ assign = sim.argmax(dim=1)
366
+
367
+ # Recompute centers
368
+ new_C = torch.zeros(K, D, dtype=Z.dtype)
369
+ counts = torch.zeros(K, dtype=torch.long)
370
+ for i, c in enumerate(assign.tolist()):
371
+ new_C[c] += Z[i]
372
+ counts[c] += 1
373
+
374
+ # Handle empty clusters
375
+ for k in range(K):
376
+ if counts[k] == 0:
377
+ # Reinitialize from a random point
378
+ new_C[k] = Z[random.randint(0, N - 1)]
379
+ counts[k] = 1
380
+
381
+ C = new_C / counts.unsqueeze(1).clamp_min(1).float()
382
+ C = torch.nn.functional.normalize(C, dim=1)
383
+
384
+ return C
385
+
386
+
387
  def add_prototype(
388
  label_name: str,
389
  images: List,
390
+ k_prototypes: int,
391
+ n_triplets: int,
392
  ) -> str:
393
+ """
394
+ Add temporary prototypes using random triplet combinations and K-means clustering.
395
+ Similar to the eval process: extract views, create random triplets, embed, cluster.
396
+ """
397
+ import random
398
+
399
  if APP_STATE.lm is None or APP_STATE.db is None:
400
  return "❌ Click **Load** first."
401
  lm = APP_STATE.lm
 
408
  if not images:
409
  return "❌ Upload at least 1 image."
410
 
411
+ k_prototypes = max(1, int(k_prototypes))
412
+ n_triplets = max(1, int(n_triplets))
413
+
414
+ # Step 1: Extract whole/face/eyes from all uploaded images
415
+ wholes: List[Image.Image] = []
416
+ faces: List[Image.Image] = []
417
+ eyes_list: List[Image.Image] = []
418
  errors: List[str] = []
419
+
420
  for i, x in enumerate(images):
421
  try:
422
  im = _gallery_item_to_pil(x)
 
424
  errors.append(f"Image {i}: could not parse format {type(x)}")
425
  continue
426
 
427
+ wholes.append(im)
428
+
429
+ # Extract face and eyes
430
  if ex is not None:
431
  rgb = np.array(im.convert("RGB"))
432
  face_rgb, eyes_rgb = ex.extract(rgb)
433
  if face_rgb is not None:
434
+ faces.append(Image.fromarray(face_rgb))
435
  if eyes_rgb is not None:
436
+ eyes_list.append(Image.fromarray(eyes_rgb))
437
+ except Exception as e:
438
+ errors.append(f"Image {i}: {e}")
439
+ continue
440
 
441
+ if not wholes:
442
+ err_detail = "; ".join(errors[:3]) if errors else "unknown error"
443
+ return f"❌ Could not process any images. Details: {err_detail}"
444
+
445
+ # Step 2: Create random triplet combinations
446
+ # If we have fewer faces/eyes than wholes, we still try to make triplets
447
+ triplets: List[Tuple[Image.Image, Optional[Image.Image], Optional[Image.Image]]] = []
448
+ for _ in range(n_triplets):
449
+ w = random.choice(wholes)
450
+ f = random.choice(faces) if faces else None
451
+ e = random.choice(eyes_list) if eyes_list else None
452
+ triplets.append((w, f, e))
453
+
454
+ # Step 3: Embed all triplets
455
+ zs: List[torch.Tensor] = []
456
+ for w, f, e in triplets:
457
+ try:
458
+ wt = _pil_to_tensor(w, lm.T_w)
459
+ ft = _pil_to_tensor(f, lm.T_f) if f is not None else None
460
+ et = _pil_to_tensor(e, lm.T_e) if e is not None else None
461
  z = embed_triview(lm, whole=wt, face=ft, eyes=et)
462
  zs.append(z)
463
+ except Exception:
 
464
  continue
465
 
466
  if not zs:
467
+ return "❌ Could not embed any triplets."
468
+
469
+ Z = torch.stack(zs, dim=0)
470
+ Z = torch.nn.functional.normalize(Z, dim=1)
471
 
472
+ # Step 4: Run K-means to get K prototype centers
473
+ actual_k = min(k_prototypes, len(zs))
474
+ if actual_k < k_prototypes:
475
+ # Not enough embeddings for requested K
476
+ pass
477
 
478
+ centers = _kmeans_cosine(Z, actual_k, iters=20, seed=42)
479
+
480
+ # Step 5: Add all K prototypes to the DB
481
+ added_ids = []
482
+ for center in centers:
483
+ lid = db.add_center(label_name, center)
484
+ added_ids.append(lid)
485
+
486
+ return (
487
+ f"βœ… Added {len(added_ids)} temporary prototype(s) for `{label_name}` "
488
+ f"(from {len(wholes)} images, {len(triplets)} triplets, K-means K={actual_k}). "
489
+ f"DB now N={db.centers.shape[0]}. "
490
+ f"⚠️ Session-only β€” lost on Space restart."
491
+ )
492
 
493
 
494
  def build_ui() -> gr.Blocks:
 
523
  with gr.Tab("Add prototype (temporary)"):
524
  gr.Markdown(
525
  "### ⚠️ Temporary Prototypes Only\n"
526
+ "Add prototypes using random triplet combinations and K-means clustering (same as eval process).\n"
527
+ "1. Upload multiple whole images\n"
528
+ "2. Face/eyes are auto-extracted from each\n"
529
+ "3. Random triplets (whole + face + eyes) are created\n"
530
+ "4. K-means clustering creates K prototype centers\n\n"
531
+ "**These prototypes are session-only** β€” lost when the Space restarts."
532
  )
533
  label = gr.Textbox(label="Label name (artist)", placeholder="e.g. new_artist")
534
  imgs = gr.Gallery(label="Whole images (1+)", columns=4, rows=2, height=240, allow_preview=True)
535
  uploader = gr.Files(label="Upload image files (whole)", file_types=["image"], file_count="multiple")
536
+ with gr.Row():
537
+ k_proto = gr.Slider(1, 8, value=4, step=1, label="K (prototypes to create)")
538
+ n_trips = gr.Slider(4, 64, value=16, step=4, label="N (random triplets to sample)")
539
+ add_btn = gr.Button("Add temporary prototypes", variant="primary")
540
  add_status = gr.Markdown("")
541
 
542
  def _files_to_gallery(files):
 
552
  return out
553
 
554
  uploader.change(_files_to_gallery, inputs=[uploader], outputs=[imgs])
555
+ add_btn.click(add_prototype, inputs=[label, imgs, k_proto, n_trips], outputs=[add_status])
556
 
557
  return demo
558