iljung1106 commited on
Commit
89b3ad1
·
1 Parent(s): 059ebcb

Removed editing feature and make additional prototype temporary.

Browse files
Files changed (2) hide show
  1. app/proto_db.py +0 -29
  2. webui_gradio.py +8 -143
app/proto_db.py CHANGED
@@ -43,35 +43,6 @@ class PrototypeDB:
43
  self.labels = torch.cat([self.labels, torch.tensor([lid], dtype=torch.long)], dim=0)
44
  return lid
45
 
46
- def remove_prototype(self, proto_idx: int) -> bool:
47
- """Remove a prototype by its index. Returns True if removed."""
48
- if proto_idx < 0 or proto_idx >= self.centers.shape[0]:
49
- return False
50
- mask = torch.ones(self.centers.shape[0], dtype=torch.bool)
51
- mask[proto_idx] = False
52
- self.centers = self.centers[mask]
53
- self.labels = self.labels[mask]
54
- return True
55
-
56
- def list_prototypes(self) -> List[Tuple[int, str, int]]:
57
- """Returns list of (proto_idx, label_name, label_id) for all prototypes."""
58
- result = []
59
- for i in range(self.centers.shape[0]):
60
- lid = int(self.labels[i].item())
61
- name = self.id_to_name(lid)
62
- result.append((i, name, lid))
63
- return result
64
-
65
- def get_label_summary(self) -> Dict[str, int]:
66
- """Returns {label_name: count} for prototype counts per label."""
67
- from collections import Counter
68
- counts: Dict[str, int] = {}
69
- for i in range(self.labels.shape[0]):
70
- lid = int(self.labels[i].item())
71
- name = self.id_to_name(lid)
72
- counts[name] = counts.get(name, 0) + 1
73
- return counts
74
-
75
  def save(self, path: Optional[str | Path] = None) -> Path:
76
  out = Path(path) if path is not None else self.source_path
77
  if out is None:
 
43
  self.labels = torch.cat([self.labels, torch.tensor([lid], dtype=torch.long)], dim=0)
44
  return lid
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def save(self, path: Optional[str | Path] = None) -> Path:
47
  out = Path(path) if path is not None else self.source_path
48
  if out is None:
webui_gradio.py CHANGED
@@ -18,20 +18,6 @@ except Exception: # noqa: BLE001
18
  # Detect if running on HF Spaces (ZeroGPU requires special handling)
19
  _ON_SPACES = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))
20
 
21
- # ─────────────────────────────────────────────────────────────────────────────
22
- # Authentication for prototype management
23
- # Set ADMIN_PASSWORD env var, or it defaults to "admin" (change in production!)
24
- # ─────────────────────────────────────────────────────────────────────────────
25
- ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD", "admin")
26
- APPROVED_USERS: dict[str, str] = {
27
- "admin": ADMIN_PASSWORD,
28
- }
29
- # Add more users via APPROVED_USERS_CSV env var: "user1:pass1,user2:pass2"
30
- for pair in os.getenv("APPROVED_USERS_CSV", "").split(","):
31
- if ":" in pair:
32
- u, p = pair.split(":", 1)
33
- APPROVED_USERS[u.strip()] = p.strip()
34
-
35
  def _patch_fastapi_starlette_middleware_unpack() -> None:
36
  """
37
  Work around FastAPI/Starlette version mismatches where Starlette's Middleware
@@ -333,8 +319,8 @@ def classify(
333
  def add_prototype(
334
  label_name: str,
335
  images: List,
336
- save_back: bool,
337
  ) -> str:
 
338
  if APP_STATE.lm is None or APP_STATE.db is None:
339
  return "❌ Click **Load** first."
340
  lm = APP_STATE.lm
@@ -375,57 +361,7 @@ def add_prototype(
375
  center = torch.stack(zs, dim=0).mean(dim=0)
376
  lid = db.add_center(label_name, center)
377
 
378
- msg = f"✅ Added prototype for `{label_name}` (label_id={lid}). DB now N={db.centers.shape[0]}."
379
-
380
- if save_back:
381
- out_path = db.save(APP_STATE.proto_path)
382
- msg += f" Saved to `{out_path}`."
383
- return msg
384
-
385
-
386
- def save_db_as(path_text: str) -> str:
387
- if APP_STATE.db is None:
388
- return "❌ Nothing loaded."
389
- out = (path_text or "").strip()
390
- if not out:
391
- return "❌ Provide an output path."
392
- out_path = Path(out)
393
- if not out_path.is_absolute():
394
- out_path = (CKPT_DIR / out_path).resolve()
395
- APP_STATE.db.save(out_path)
396
- APP_STATE.proto_path = str(out_path)
397
- return f"✅ Saved prototype DB to `{out_path}`"
398
-
399
-
400
- def list_prototypes() -> Tuple[str, List[List]]:
401
- """Returns (status, table_rows) with all prototypes."""
402
- if APP_STATE.db is None:
403
- return "❌ Click **Load** first.", []
404
- db = APP_STATE.db
405
- protos = db.list_prototypes()
406
- # Group by label for summary
407
- summary = db.get_label_summary()
408
- summary_text = ", ".join(f"{k}: {v}" for k, v in sorted(summary.items()))
409
- rows = [[i, name, lid] for (i, name, lid) in protos]
410
- return f"✅ {len(protos)} prototypes. Summary: {summary_text}", rows
411
-
412
-
413
- def remove_prototype(proto_idx: int, save_back: bool) -> Tuple[str, List[List]]:
414
- """Remove a prototype by index. Returns (status, updated_table)."""
415
- if APP_STATE.db is None:
416
- return "❌ Click **Load** first.", []
417
- db = APP_STATE.db
418
- idx = int(proto_idx)
419
- if not db.remove_prototype(idx):
420
- return f"❌ Invalid index: {idx}", [[i, name, lid] for (i, name, lid) in db.list_prototypes()]
421
-
422
- msg = f"✅ Removed prototype at index {idx}. DB now N={db.centers.shape[0]}."
423
- if save_back and APP_STATE.proto_path:
424
- db.save(APP_STATE.proto_path)
425
- msg += f" Saved to `{APP_STATE.proto_path}`."
426
-
427
- rows = [[i, name, lid] for (i, name, lid) in db.list_prototypes()]
428
- return msg, rows
429
 
430
 
431
  def build_ui() -> gr.Blocks:
@@ -457,63 +393,17 @@ def build_ui() -> gr.Blocks:
457
  table = gr.Dataframe(headers=["label", "cosine_sim"], datatype=["str", "number"], interactive=False)
458
  run_btn.click(classify, inputs=[whole, topk], outputs=[out_status, table, face_prev, eyes_prev])
459
 
460
- with gr.Tab("Manage Prototypes"):
461
  gr.Markdown(
462
- "### Prototype Management\n"
463
- "**Authentication required** to add or remove prototypes. "
464
- "Enter your credentials below."
465
- )
466
-
467
- with gr.Row():
468
- auth_user = gr.Textbox(label="Username", placeholder="admin")
469
- auth_pass = gr.Textbox(label="Password", type="password", placeholder="password")
470
-
471
- def _check_auth(user: str, pwd: str) -> bool:
472
- u = (user or "").strip()
473
- p = (pwd or "").strip()
474
- return APPROVED_USERS.get(u) == p and p != ""
475
-
476
- # ─────────────── List Prototypes ───────────────
477
- gr.Markdown("---\n#### List Prototypes")
478
- list_btn = gr.Button("Refresh List", variant="secondary")
479
- list_status = gr.Markdown("")
480
- proto_table = gr.Dataframe(
481
- headers=["index", "label", "label_id"],
482
- datatype=["number", "str", "number"],
483
- interactive=False,
484
- )
485
- list_btn.click(list_prototypes, inputs=[], outputs=[list_status, proto_table])
486
-
487
- # ─────────────── Remove Prototype ───────────────
488
- gr.Markdown("---\n#### Remove Prototype")
489
- with gr.Row():
490
- remove_idx = gr.Number(label="Prototype index to remove", precision=0)
491
- remove_save = gr.Checkbox(value=True, label="Save after removal")
492
- remove_btn = gr.Button("Remove Prototype", variant="stop")
493
- remove_status = gr.Markdown("")
494
-
495
- def _remove_with_auth(user, pwd, idx, save_back):
496
- if not _check_auth(user, pwd):
497
- return "❌ Authentication failed. Check username/password.", []
498
- return remove_prototype(idx, save_back)
499
-
500
- remove_btn.click(
501
- _remove_with_auth,
502
- inputs=[auth_user, auth_pass, remove_idx, remove_save],
503
- outputs=[remove_status, proto_table],
504
- )
505
-
506
- # ─────────────── Add Prototype ───────────────
507
- gr.Markdown("---\n#### Add New Prototype")
508
- gr.Markdown(
509
- "Add a new prototype by averaging embeddings of uploaded whole images. "
510
  "Multiple prototypes per label are allowed."
511
  )
512
  label = gr.Textbox(label="Label name (artist)", placeholder="e.g. new_artist")
513
  imgs = gr.Gallery(label="Whole images (1+)", columns=4, rows=2, height=240, allow_preview=True)
514
  uploader = gr.Files(label="Upload image files (whole)", file_types=["image"], file_count="multiple")
515
- add_save_back = gr.Checkbox(value=True, label="Save back to prototype DB after adding")
516
- add_btn = gr.Button("Add Prototype", variant="primary")
517
  add_status = gr.Markdown("")
518
 
519
  def _files_to_gallery(files):
@@ -528,33 +418,8 @@ def build_ui() -> gr.Blocks:
528
  continue
529
  return out
530
 
531
- def _add_with_auth(user, pwd, lbl, images, save_back):
532
- if not _check_auth(user, pwd):
533
- return "❌ Authentication failed. Check username/password."
534
- return add_prototype(lbl, images, save_back)
535
-
536
  uploader.change(_files_to_gallery, inputs=[uploader], outputs=[imgs])
537
- add_btn.click(
538
- _add_with_auth,
539
- inputs=[auth_user, auth_pass, label, imgs, add_save_back],
540
- outputs=[add_status],
541
- )
542
-
543
- # ─────────────── Save DB As ───────────────
544
- gr.Markdown("---\n#### Save DB As (optional)")
545
- with gr.Row():
546
- save_path = gr.Textbox(
547
- label="Output path (relative paths go under ./checkpoints_style/)",
548
- placeholder="prototypes_custom.pt",
549
- )
550
- save_btn = gr.Button("Save As")
551
-
552
- def _save_with_auth(user, pwd, path):
553
- if not _check_auth(user, pwd):
554
- return "❌ Authentication failed. Check username/password."
555
- return save_db_as(path)
556
-
557
- save_btn.click(_save_with_auth, inputs=[auth_user, auth_pass, save_path], outputs=[add_status])
558
 
559
  return demo
560
 
 
18
  # Detect if running on HF Spaces (ZeroGPU requires special handling)
19
  _ON_SPACES = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def _patch_fastapi_starlette_middleware_unpack() -> None:
22
  """
23
  Work around FastAPI/Starlette version mismatches where Starlette's Middleware
 
319
  def add_prototype(
320
  label_name: str,
321
  images: List,
 
322
  ) -> str:
323
+ """Add a temporary prototype (in-memory only, not persisted to disk)."""
324
  if APP_STATE.lm is None or APP_STATE.db is None:
325
  return "❌ Click **Load** first."
326
  lm = APP_STATE.lm
 
361
  center = torch.stack(zs, dim=0).mean(dim=0)
362
  lid = db.add_center(label_name, center)
363
 
364
+ return f"✅ Added temporary prototype for `{label_name}` (label_id={lid}). DB now N={db.centers.shape[0]}. ⚠️ This is session-only and will be lost when the Space restarts."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
 
366
 
367
  def build_ui() -> gr.Blocks:
 
393
  table = gr.Dataframe(headers=["label", "cosine_sim"], datatype=["str", "number"], interactive=False)
394
  run_btn.click(classify, inputs=[whole, topk], outputs=[out_status, table, face_prev, eyes_prev])
395
 
396
+ with gr.Tab("Add prototype (temporary)"):
397
  gr.Markdown(
398
+ "### ⚠️ Temporary Prototypes Only\n"
399
+ "Add a new prototype by averaging embeddings of uploaded whole images.\n"
400
+ "**These prototypes are session-only** — they will be lost when the Space restarts or goes idle.\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
  "Multiple prototypes per label are allowed."
402
  )
403
  label = gr.Textbox(label="Label name (artist)", placeholder="e.g. new_artist")
404
  imgs = gr.Gallery(label="Whole images (1+)", columns=4, rows=2, height=240, allow_preview=True)
405
  uploader = gr.Files(label="Upload image files (whole)", file_types=["image"], file_count="multiple")
406
+ add_btn = gr.Button("Add temporary prototype", variant="primary")
 
407
  add_status = gr.Markdown("")
408
 
409
  def _files_to_gallery(files):
 
418
  continue
419
  return out
420
 
 
 
 
 
 
421
  uploader.change(_files_to_gallery, inputs=[uploader], outputs=[imgs])
422
+ add_btn.click(add_prototype, inputs=[label, imgs], outputs=[add_status])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
 
424
  return demo
425