iljung1106 committed on
Commit
059ebcb
·
1 Parent(s): b04d768

add prototype editing feature

Browse files
Files changed (3) hide show
  1. app/model_io.py +21 -4
  2. app/proto_db.py +29 -0
  3. webui_gradio.py +121 -9
app/model_io.py CHANGED
@@ -1,11 +1,15 @@
1
  from __future__ import annotations
2
 
 
3
  from dataclasses import dataclass
4
  from pathlib import Path
5
  from typing import Optional, Tuple
6
 
7
  import torch
8
 
 
 
 
9
 
10
  @dataclass(frozen=True)
11
  class LoadedModel:
@@ -21,6 +25,9 @@ class LoadedModel:
21
  def _pick_device(device: str) -> torch.device:
22
  if device.strip().lower() == "cpu":
23
  return torch.device("cpu")
 
 
 
24
  if torch.cuda.is_available():
25
  return torch.device("cuda")
26
  return torch.device("cpu")
@@ -41,9 +48,13 @@ def load_style_model(
41
  if not ckpt_path.exists():
42
  raise FileNotFoundError(str(ckpt_path))
43
 
44
- dev = _pick_device("cpu" if device == "auto" else device)
45
- if device == "auto":
 
 
46
  dev = _pick_device("cuda" if torch.cuda.is_available() else "cpu")
 
 
47
 
48
  ck = torch.load(str(ckpt_path), map_location="cpu")
49
  meta = ck.get("meta", {}) if isinstance(ck, dict) else {}
@@ -103,8 +114,14 @@ def embed_triview(
103
  # Use lazy dtype detection to avoid CUDA init at import time (ZeroGPU compatibility)
104
  import train_style_ddp as _ts
105
  _dtype = _ts._get_amp_dtype() if hasattr(_ts, "_get_amp_dtype") else torch.float16
106
- with torch.no_grad(), torch.amp.autocast("cuda", dtype=_dtype, enabled=(lm.device.type == "cuda")):
107
- z, _, _ = lm.model(views, masks)
 
 
 
 
 
 
108
  z = torch.nn.functional.normalize(z.float(), dim=1)
109
  return z.squeeze(0).detach().cpu()
110
 
 
1
  from __future__ import annotations
2
 
3
+ import os
4
  from dataclasses import dataclass
5
  from pathlib import Path
6
  from typing import Optional, Tuple
7
 
8
  import torch
9
 
10
+ # ZeroGPU on HF Spaces: CUDA must not be initialized in the main process
11
+ _ON_SPACES = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))
12
+
13
 
14
  @dataclass(frozen=True)
15
  class LoadedModel:
 
25
  def _pick_device(device: str) -> torch.device:
26
  if device.strip().lower() == "cpu":
27
  return torch.device("cpu")
28
+ if _ON_SPACES:
29
+ # ZeroGPU: can't init CUDA in main process
30
+ return torch.device("cpu")
31
  if torch.cuda.is_available():
32
  return torch.device("cuda")
33
  return torch.device("cpu")
 
48
  if not ckpt_path.exists():
49
  raise FileNotFoundError(str(ckpt_path))
50
 
51
+ # On Spaces, always use CPU (ZeroGPU forbids CUDA in main process)
52
+ if _ON_SPACES:
53
+ dev = torch.device("cpu")
54
+ elif device == "auto":
55
  dev = _pick_device("cuda" if torch.cuda.is_available() else "cpu")
56
+ else:
57
+ dev = _pick_device(device)
58
 
59
  ck = torch.load(str(ckpt_path), map_location="cpu")
60
  meta = ck.get("meta", {}) if isinstance(ck, dict) else {}
 
114
  # Use lazy dtype detection to avoid CUDA init at import time (ZeroGPU compatibility)
115
  import train_style_ddp as _ts
116
  _dtype = _ts._get_amp_dtype() if hasattr(_ts, "_get_amp_dtype") else torch.float16
117
+ # On CPU or Spaces, skip autocast entirely to avoid touching CUDA
118
+ use_amp = (lm.device.type == "cuda") and not _ON_SPACES
119
+ if use_amp:
120
+ with torch.no_grad(), torch.amp.autocast("cuda", dtype=_dtype, enabled=True):
121
+ z, _, _ = lm.model(views, masks)
122
+ else:
123
+ with torch.no_grad():
124
+ z, _, _ = lm.model(views, masks)
125
  z = torch.nn.functional.normalize(z.float(), dim=1)
126
  return z.squeeze(0).detach().cpu()
127
 
app/proto_db.py CHANGED
@@ -43,6 +43,35 @@ class PrototypeDB:
43
  self.labels = torch.cat([self.labels, torch.tensor([lid], dtype=torch.long)], dim=0)
44
  return lid
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def save(self, path: Optional[str | Path] = None) -> Path:
47
  out = Path(path) if path is not None else self.source_path
48
  if out is None:
 
43
  self.labels = torch.cat([self.labels, torch.tensor([lid], dtype=torch.long)], dim=0)
44
  return lid
45
 
46
+ def remove_prototype(self, proto_idx: int) -> bool:
47
+ """Remove a prototype by its index. Returns True if removed."""
48
+ if proto_idx < 0 or proto_idx >= self.centers.shape[0]:
49
+ return False
50
+ mask = torch.ones(self.centers.shape[0], dtype=torch.bool)
51
+ mask[proto_idx] = False
52
+ self.centers = self.centers[mask]
53
+ self.labels = self.labels[mask]
54
+ return True
55
+
56
+ def list_prototypes(self) -> List[Tuple[int, str, int]]:
57
+ """Returns list of (proto_idx, label_name, label_id) for all prototypes."""
58
+ result = []
59
+ for i in range(self.centers.shape[0]):
60
+ lid = int(self.labels[i].item())
61
+ name = self.id_to_name(lid)
62
+ result.append((i, name, lid))
63
+ return result
64
+
65
+ def get_label_summary(self) -> Dict[str, int]:
66
+ """Returns {label_name: count} for prototype counts per label."""
67
+ from collections import Counter
68
+ counts: Dict[str, int] = {}
69
+ for i in range(self.labels.shape[0]):
70
+ lid = int(self.labels[i].item())
71
+ name = self.id_to_name(lid)
72
+ counts[name] = counts.get(name, 0) + 1
73
+ return counts
74
+
75
  def save(self, path: Optional[str | Path] = None) -> Path:
76
  out = Path(path) if path is not None else self.source_path
77
  if out is None:
webui_gradio.py CHANGED
@@ -18,6 +18,20 @@ except Exception: # noqa: BLE001
18
  # Detect if running on HF Spaces (ZeroGPU requires special handling)
19
  _ON_SPACES = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def _patch_fastapi_starlette_middleware_unpack() -> None:
22
  """
23
  Work around FastAPI/Starlette version mismatches where Starlette's Middleware
@@ -383,6 +397,37 @@ def save_db_as(path_text: str) -> str:
383
  return f"βœ… Saved prototype DB to `{out_path}`"
384
 
385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  def build_ui() -> gr.Blocks:
387
  ckpts = _list_ckpt_files(CKPT_DIR)
388
  protos = _list_proto_files(CKPT_DIR)
@@ -412,16 +457,63 @@ def build_ui() -> gr.Blocks:
412
  table = gr.Dataframe(headers=["label", "cosine_sim"], datatype=["str", "number"], interactive=False)
413
  run_btn.click(classify, inputs=[whole, topk], outputs=[out_status, table, face_prev, eyes_prev])
414
 
415
- with gr.Tab("Add prototype"):
416
  gr.Markdown(
417
- "Add a new prototype to the loaded prototype DB by averaging embeddings of uploaded whole images.\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
  "Multiple prototypes per label are allowed."
419
  )
420
  label = gr.Textbox(label="Label name (artist)", placeholder="e.g. new_artist")
421
  imgs = gr.Gallery(label="Whole images (1+)", columns=4, rows=2, height=240, allow_preview=True)
422
  uploader = gr.Files(label="Upload image files (whole)", file_types=["image"], file_count="multiple")
423
- save_back = gr.Checkbox(value=True, label="Save back to selected prototype DB file after adding")
424
- add_btn = gr.Button("Add prototype", variant="primary")
425
  add_status = gr.Markdown("")
426
 
427
  def _files_to_gallery(files):
@@ -436,13 +528,33 @@ def build_ui() -> gr.Blocks:
436
  continue
437
  return out
438
 
 
 
 
 
 
439
  uploader.change(_files_to_gallery, inputs=[uploader], outputs=[imgs])
440
- add_btn.click(add_prototype, inputs=[label, imgs, save_back], outputs=[add_status])
 
 
 
 
441
 
442
- gr.Markdown("Save DB as (optional):")
443
- save_path = gr.Textbox(label="Output path (relative paths go under ./checkpoints_style/)", placeholder="prototypes_custom.pt")
444
- save_btn = gr.Button("Save As")
445
- save_btn.click(save_db_as, inputs=[save_path], outputs=[add_status])
 
 
 
 
 
 
 
 
 
 
 
446
 
447
  return demo
448
 
 
18
  # Detect if running on HF Spaces (ZeroGPU requires special handling)
19
  _ON_SPACES = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))
20
 
21
+ # ─────────────────────────────────────────────────────────────────────────────
22
+ # Authentication for prototype management
23
+ # Set ADMIN_PASSWORD env var, or it defaults to "admin" (change in production!)
24
+ # ─────────────────────────────────────────────────────────────────────────────
25
+ ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD", "admin")
26
+ APPROVED_USERS: dict[str, str] = {
27
+ "admin": ADMIN_PASSWORD,
28
+ }
29
+ # Add more users via APPROVED_USERS_CSV env var: "user1:pass1,user2:pass2"
30
+ for pair in os.getenv("APPROVED_USERS_CSV", "").split(","):
31
+ if ":" in pair:
32
+ u, p = pair.split(":", 1)
33
+ APPROVED_USERS[u.strip()] = p.strip()
34
+
35
  def _patch_fastapi_starlette_middleware_unpack() -> None:
36
  """
37
  Work around FastAPI/Starlette version mismatches where Starlette's Middleware
 
397
  return f"βœ… Saved prototype DB to `{out_path}`"
398
 
399
 
400
+ def list_prototypes() -> Tuple[str, List[List]]:
401
+ """Returns (status, table_rows) with all prototypes."""
402
+ if APP_STATE.db is None:
403
+ return "❌ Click **Load** first.", []
404
+ db = APP_STATE.db
405
+ protos = db.list_prototypes()
406
+ # Group by label for summary
407
+ summary = db.get_label_summary()
408
+ summary_text = ", ".join(f"{k}: {v}" for k, v in sorted(summary.items()))
409
+ rows = [[i, name, lid] for (i, name, lid) in protos]
410
+ return f"βœ… {len(protos)} prototypes. Summary: {summary_text}", rows
411
+
412
+
413
+ def remove_prototype(proto_idx: int, save_back: bool) -> Tuple[str, List[List]]:
414
+ """Remove a prototype by index. Returns (status, updated_table)."""
415
+ if APP_STATE.db is None:
416
+ return "❌ Click **Load** first.", []
417
+ db = APP_STATE.db
418
+ idx = int(proto_idx)
419
+ if not db.remove_prototype(idx):
420
+ return f"❌ Invalid index: {idx}", [[i, name, lid] for (i, name, lid) in db.list_prototypes()]
421
+
422
+ msg = f"βœ… Removed prototype at index {idx}. DB now N={db.centers.shape[0]}."
423
+ if save_back and APP_STATE.proto_path:
424
+ db.save(APP_STATE.proto_path)
425
+ msg += f" Saved to `{APP_STATE.proto_path}`."
426
+
427
+ rows = [[i, name, lid] for (i, name, lid) in db.list_prototypes()]
428
+ return msg, rows
429
+
430
+
431
  def build_ui() -> gr.Blocks:
432
  ckpts = _list_ckpt_files(CKPT_DIR)
433
  protos = _list_proto_files(CKPT_DIR)
 
457
  table = gr.Dataframe(headers=["label", "cosine_sim"], datatype=["str", "number"], interactive=False)
458
  run_btn.click(classify, inputs=[whole, topk], outputs=[out_status, table, face_prev, eyes_prev])
459
 
460
+ with gr.Tab("Manage Prototypes"):
461
  gr.Markdown(
462
+ "### Prototype Management\n"
463
+ "**Authentication required** to add or remove prototypes. "
464
+ "Enter your credentials below."
465
+ )
466
+
467
+ with gr.Row():
468
+ auth_user = gr.Textbox(label="Username", placeholder="admin")
469
+ auth_pass = gr.Textbox(label="Password", type="password", placeholder="password")
470
+
471
+ def _check_auth(user: str, pwd: str) -> bool:
472
+ u = (user or "").strip()
473
+ p = (pwd or "").strip()
474
+ return APPROVED_USERS.get(u) == p and p != ""
475
+
476
+ # ─────────────── List Prototypes ───────────────
477
+ gr.Markdown("---\n#### List Prototypes")
478
+ list_btn = gr.Button("Refresh List", variant="secondary")
479
+ list_status = gr.Markdown("")
480
+ proto_table = gr.Dataframe(
481
+ headers=["index", "label", "label_id"],
482
+ datatype=["number", "str", "number"],
483
+ interactive=False,
484
+ )
485
+ list_btn.click(list_prototypes, inputs=[], outputs=[list_status, proto_table])
486
+
487
+ # ─────────────── Remove Prototype ───────────────
488
+ gr.Markdown("---\n#### Remove Prototype")
489
+ with gr.Row():
490
+ remove_idx = gr.Number(label="Prototype index to remove", precision=0)
491
+ remove_save = gr.Checkbox(value=True, label="Save after removal")
492
+ remove_btn = gr.Button("Remove Prototype", variant="stop")
493
+ remove_status = gr.Markdown("")
494
+
495
+ def _remove_with_auth(user, pwd, idx, save_back):
496
+ if not _check_auth(user, pwd):
497
+ return "❌ Authentication failed. Check username/password.", []
498
+ return remove_prototype(idx, save_back)
499
+
500
+ remove_btn.click(
501
+ _remove_with_auth,
502
+ inputs=[auth_user, auth_pass, remove_idx, remove_save],
503
+ outputs=[remove_status, proto_table],
504
+ )
505
+
506
+ # ─────────────── Add Prototype ───────────────
507
+ gr.Markdown("---\n#### Add New Prototype")
508
+ gr.Markdown(
509
+ "Add a new prototype by averaging embeddings of uploaded whole images. "
510
  "Multiple prototypes per label are allowed."
511
  )
512
  label = gr.Textbox(label="Label name (artist)", placeholder="e.g. new_artist")
513
  imgs = gr.Gallery(label="Whole images (1+)", columns=4, rows=2, height=240, allow_preview=True)
514
  uploader = gr.Files(label="Upload image files (whole)", file_types=["image"], file_count="multiple")
515
+ add_save_back = gr.Checkbox(value=True, label="Save back to prototype DB after adding")
516
+ add_btn = gr.Button("Add Prototype", variant="primary")
517
  add_status = gr.Markdown("")
518
 
519
  def _files_to_gallery(files):
 
528
  continue
529
  return out
530
 
531
+ def _add_with_auth(user, pwd, lbl, images, save_back):
532
+ if not _check_auth(user, pwd):
533
+ return "❌ Authentication failed. Check username/password."
534
+ return add_prototype(lbl, images, save_back)
535
+
536
  uploader.change(_files_to_gallery, inputs=[uploader], outputs=[imgs])
537
+ add_btn.click(
538
+ _add_with_auth,
539
+ inputs=[auth_user, auth_pass, label, imgs, add_save_back],
540
+ outputs=[add_status],
541
+ )
542
 
543
+ # ─────────────── Save DB As ───────────────
544
+ gr.Markdown("---\n#### Save DB As (optional)")
545
+ with gr.Row():
546
+ save_path = gr.Textbox(
547
+ label="Output path (relative paths go under ./checkpoints_style/)",
548
+ placeholder="prototypes_custom.pt",
549
+ )
550
+ save_btn = gr.Button("Save As")
551
+
552
+ def _save_with_auth(user, pwd, path):
553
+ if not _check_auth(user, pwd):
554
+ return "❌ Authentication failed. Check username/password."
555
+ return save_db_as(path)
556
+
557
+ save_btn.click(_save_with_auth, inputs=[auth_user, auth_pass, save_path], outputs=[add_status])
558
 
559
  return demo
560