Spaces:
Sleeping
Sleeping
iljung1106
commited on
Commit
·
89b3ad1
1
Parent(s):
059ebcb
Removed editing feature and make additional prototype temporary.
Browse files- app/proto_db.py +0 -29
- webui_gradio.py +8 -143
app/proto_db.py
CHANGED
|
@@ -43,35 +43,6 @@ class PrototypeDB:
|
|
| 43 |
self.labels = torch.cat([self.labels, torch.tensor([lid], dtype=torch.long)], dim=0)
|
| 44 |
return lid
|
| 45 |
|
| 46 |
-
def remove_prototype(self, proto_idx: int) -> bool:
|
| 47 |
-
"""Remove a prototype by its index. Returns True if removed."""
|
| 48 |
-
if proto_idx < 0 or proto_idx >= self.centers.shape[0]:
|
| 49 |
-
return False
|
| 50 |
-
mask = torch.ones(self.centers.shape[0], dtype=torch.bool)
|
| 51 |
-
mask[proto_idx] = False
|
| 52 |
-
self.centers = self.centers[mask]
|
| 53 |
-
self.labels = self.labels[mask]
|
| 54 |
-
return True
|
| 55 |
-
|
| 56 |
-
def list_prototypes(self) -> List[Tuple[int, str, int]]:
|
| 57 |
-
"""Returns list of (proto_idx, label_name, label_id) for all prototypes."""
|
| 58 |
-
result = []
|
| 59 |
-
for i in range(self.centers.shape[0]):
|
| 60 |
-
lid = int(self.labels[i].item())
|
| 61 |
-
name = self.id_to_name(lid)
|
| 62 |
-
result.append((i, name, lid))
|
| 63 |
-
return result
|
| 64 |
-
|
| 65 |
-
def get_label_summary(self) -> Dict[str, int]:
|
| 66 |
-
"""Returns {label_name: count} for prototype counts per label."""
|
| 67 |
-
from collections import Counter
|
| 68 |
-
counts: Dict[str, int] = {}
|
| 69 |
-
for i in range(self.labels.shape[0]):
|
| 70 |
-
lid = int(self.labels[i].item())
|
| 71 |
-
name = self.id_to_name(lid)
|
| 72 |
-
counts[name] = counts.get(name, 0) + 1
|
| 73 |
-
return counts
|
| 74 |
-
|
| 75 |
def save(self, path: Optional[str | Path] = None) -> Path:
|
| 76 |
out = Path(path) if path is not None else self.source_path
|
| 77 |
if out is None:
|
|
|
|
| 43 |
self.labels = torch.cat([self.labels, torch.tensor([lid], dtype=torch.long)], dim=0)
|
| 44 |
return lid
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
def save(self, path: Optional[str | Path] = None) -> Path:
|
| 47 |
out = Path(path) if path is not None else self.source_path
|
| 48 |
if out is None:
|
webui_gradio.py
CHANGED
|
@@ -18,20 +18,6 @@ except Exception: # noqa: BLE001
|
|
| 18 |
# Detect if running on HF Spaces (ZeroGPU requires special handling)
|
| 19 |
_ON_SPACES = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))
|
| 20 |
|
| 21 |
-
# ─────────────────────────────────────────────────────────────────────────────
|
| 22 |
-
# Authentication for prototype management
|
| 23 |
-
# Set ADMIN_PASSWORD env var, or it defaults to "admin" (change in production!)
|
| 24 |
-
# ─────────────────────────────────────────────────────────────────────────────
|
| 25 |
-
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD", "admin")
|
| 26 |
-
APPROVED_USERS: dict[str, str] = {
|
| 27 |
-
"admin": ADMIN_PASSWORD,
|
| 28 |
-
}
|
| 29 |
-
# Add more users via APPROVED_USERS_CSV env var: "user1:pass1,user2:pass2"
|
| 30 |
-
for pair in os.getenv("APPROVED_USERS_CSV", "").split(","):
|
| 31 |
-
if ":" in pair:
|
| 32 |
-
u, p = pair.split(":", 1)
|
| 33 |
-
APPROVED_USERS[u.strip()] = p.strip()
|
| 34 |
-
|
| 35 |
def _patch_fastapi_starlette_middleware_unpack() -> None:
|
| 36 |
"""
|
| 37 |
Work around FastAPI/Starlette version mismatches where Starlette's Middleware
|
|
@@ -333,8 +319,8 @@ def classify(
|
|
| 333 |
def add_prototype(
|
| 334 |
label_name: str,
|
| 335 |
images: List,
|
| 336 |
-
save_back: bool,
|
| 337 |
) -> str:
|
|
|
|
| 338 |
if APP_STATE.lm is None or APP_STATE.db is None:
|
| 339 |
return "❌ Click **Load** first."
|
| 340 |
lm = APP_STATE.lm
|
|
@@ -375,57 +361,7 @@ def add_prototype(
|
|
| 375 |
center = torch.stack(zs, dim=0).mean(dim=0)
|
| 376 |
lid = db.add_center(label_name, center)
|
| 377 |
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
if save_back:
|
| 381 |
-
out_path = db.save(APP_STATE.proto_path)
|
| 382 |
-
msg += f" Saved to `{out_path}`."
|
| 383 |
-
return msg
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
def save_db_as(path_text: str) -> str:
|
| 387 |
-
if APP_STATE.db is None:
|
| 388 |
-
return "❌ Nothing loaded."
|
| 389 |
-
out = (path_text or "").strip()
|
| 390 |
-
if not out:
|
| 391 |
-
return "❌ Provide an output path."
|
| 392 |
-
out_path = Path(out)
|
| 393 |
-
if not out_path.is_absolute():
|
| 394 |
-
out_path = (CKPT_DIR / out_path).resolve()
|
| 395 |
-
APP_STATE.db.save(out_path)
|
| 396 |
-
APP_STATE.proto_path = str(out_path)
|
| 397 |
-
return f"✅ Saved prototype DB to `{out_path}`"
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
def list_prototypes() -> Tuple[str, List[List]]:
|
| 401 |
-
"""Returns (status, table_rows) with all prototypes."""
|
| 402 |
-
if APP_STATE.db is None:
|
| 403 |
-
return "❌ Click **Load** first.", []
|
| 404 |
-
db = APP_STATE.db
|
| 405 |
-
protos = db.list_prototypes()
|
| 406 |
-
# Group by label for summary
|
| 407 |
-
summary = db.get_label_summary()
|
| 408 |
-
summary_text = ", ".join(f"{k}: {v}" for k, v in sorted(summary.items()))
|
| 409 |
-
rows = [[i, name, lid] for (i, name, lid) in protos]
|
| 410 |
-
return f"✅ {len(protos)} prototypes. Summary: {summary_text}", rows
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
def remove_prototype(proto_idx: int, save_back: bool) -> Tuple[str, List[List]]:
|
| 414 |
-
"""Remove a prototype by index. Returns (status, updated_table)."""
|
| 415 |
-
if APP_STATE.db is None:
|
| 416 |
-
return "❌ Click **Load** first.", []
|
| 417 |
-
db = APP_STATE.db
|
| 418 |
-
idx = int(proto_idx)
|
| 419 |
-
if not db.remove_prototype(idx):
|
| 420 |
-
return f"❌ Invalid index: {idx}", [[i, name, lid] for (i, name, lid) in db.list_prototypes()]
|
| 421 |
-
|
| 422 |
-
msg = f"✅ Removed prototype at index {idx}. DB now N={db.centers.shape[0]}."
|
| 423 |
-
if save_back and APP_STATE.proto_path:
|
| 424 |
-
db.save(APP_STATE.proto_path)
|
| 425 |
-
msg += f" Saved to `{APP_STATE.proto_path}`."
|
| 426 |
-
|
| 427 |
-
rows = [[i, name, lid] for (i, name, lid) in db.list_prototypes()]
|
| 428 |
-
return msg, rows
|
| 429 |
|
| 430 |
|
| 431 |
def build_ui() -> gr.Blocks:
|
|
@@ -457,63 +393,17 @@ def build_ui() -> gr.Blocks:
|
|
| 457 |
table = gr.Dataframe(headers=["label", "cosine_sim"], datatype=["str", "number"], interactive=False)
|
| 458 |
run_btn.click(classify, inputs=[whole, topk], outputs=[out_status, table, face_prev, eyes_prev])
|
| 459 |
|
| 460 |
-
with gr.Tab("
|
| 461 |
gr.Markdown(
|
| 462 |
-
"###
|
| 463 |
-
"
|
| 464 |
-
"
|
| 465 |
-
)
|
| 466 |
-
|
| 467 |
-
with gr.Row():
|
| 468 |
-
auth_user = gr.Textbox(label="Username", placeholder="admin")
|
| 469 |
-
auth_pass = gr.Textbox(label="Password", type="password", placeholder="password")
|
| 470 |
-
|
| 471 |
-
def _check_auth(user: str, pwd: str) -> bool:
|
| 472 |
-
u = (user or "").strip()
|
| 473 |
-
p = (pwd or "").strip()
|
| 474 |
-
return APPROVED_USERS.get(u) == p and p != ""
|
| 475 |
-
|
| 476 |
-
# ─────────────── List Prototypes ───────────────
|
| 477 |
-
gr.Markdown("---\n#### List Prototypes")
|
| 478 |
-
list_btn = gr.Button("Refresh List", variant="secondary")
|
| 479 |
-
list_status = gr.Markdown("")
|
| 480 |
-
proto_table = gr.Dataframe(
|
| 481 |
-
headers=["index", "label", "label_id"],
|
| 482 |
-
datatype=["number", "str", "number"],
|
| 483 |
-
interactive=False,
|
| 484 |
-
)
|
| 485 |
-
list_btn.click(list_prototypes, inputs=[], outputs=[list_status, proto_table])
|
| 486 |
-
|
| 487 |
-
# ─────────────── Remove Prototype ───────────────
|
| 488 |
-
gr.Markdown("---\n#### Remove Prototype")
|
| 489 |
-
with gr.Row():
|
| 490 |
-
remove_idx = gr.Number(label="Prototype index to remove", precision=0)
|
| 491 |
-
remove_save = gr.Checkbox(value=True, label="Save after removal")
|
| 492 |
-
remove_btn = gr.Button("Remove Prototype", variant="stop")
|
| 493 |
-
remove_status = gr.Markdown("")
|
| 494 |
-
|
| 495 |
-
def _remove_with_auth(user, pwd, idx, save_back):
|
| 496 |
-
if not _check_auth(user, pwd):
|
| 497 |
-
return "❌ Authentication failed. Check username/password.", []
|
| 498 |
-
return remove_prototype(idx, save_back)
|
| 499 |
-
|
| 500 |
-
remove_btn.click(
|
| 501 |
-
_remove_with_auth,
|
| 502 |
-
inputs=[auth_user, auth_pass, remove_idx, remove_save],
|
| 503 |
-
outputs=[remove_status, proto_table],
|
| 504 |
-
)
|
| 505 |
-
|
| 506 |
-
# ─────────────── Add Prototype ───────────────
|
| 507 |
-
gr.Markdown("---\n#### Add New Prototype")
|
| 508 |
-
gr.Markdown(
|
| 509 |
-
"Add a new prototype by averaging embeddings of uploaded whole images. "
|
| 510 |
"Multiple prototypes per label are allowed."
|
| 511 |
)
|
| 512 |
label = gr.Textbox(label="Label name (artist)", placeholder="e.g. new_artist")
|
| 513 |
imgs = gr.Gallery(label="Whole images (1+)", columns=4, rows=2, height=240, allow_preview=True)
|
| 514 |
uploader = gr.Files(label="Upload image files (whole)", file_types=["image"], file_count="multiple")
|
| 515 |
-
|
| 516 |
-
add_btn = gr.Button("Add Prototype", variant="primary")
|
| 517 |
add_status = gr.Markdown("")
|
| 518 |
|
| 519 |
def _files_to_gallery(files):
|
|
@@ -528,33 +418,8 @@ def build_ui() -> gr.Blocks:
|
|
| 528 |
continue
|
| 529 |
return out
|
| 530 |
|
| 531 |
-
def _add_with_auth(user, pwd, lbl, images, save_back):
|
| 532 |
-
if not _check_auth(user, pwd):
|
| 533 |
-
return "❌ Authentication failed. Check username/password."
|
| 534 |
-
return add_prototype(lbl, images, save_back)
|
| 535 |
-
|
| 536 |
uploader.change(_files_to_gallery, inputs=[uploader], outputs=[imgs])
|
| 537 |
-
add_btn.click(
|
| 538 |
-
_add_with_auth,
|
| 539 |
-
inputs=[auth_user, auth_pass, label, imgs, add_save_back],
|
| 540 |
-
outputs=[add_status],
|
| 541 |
-
)
|
| 542 |
-
|
| 543 |
-
# ─────────────── Save DB As ───────────────
|
| 544 |
-
gr.Markdown("---\n#### Save DB As (optional)")
|
| 545 |
-
with gr.Row():
|
| 546 |
-
save_path = gr.Textbox(
|
| 547 |
-
label="Output path (relative paths go under ./checkpoints_style/)",
|
| 548 |
-
placeholder="prototypes_custom.pt",
|
| 549 |
-
)
|
| 550 |
-
save_btn = gr.Button("Save As")
|
| 551 |
-
|
| 552 |
-
def _save_with_auth(user, pwd, path):
|
| 553 |
-
if not _check_auth(user, pwd):
|
| 554 |
-
return "❌ Authentication failed. Check username/password."
|
| 555 |
-
return save_db_as(path)
|
| 556 |
-
|
| 557 |
-
save_btn.click(_save_with_auth, inputs=[auth_user, auth_pass, save_path], outputs=[add_status])
|
| 558 |
|
| 559 |
return demo
|
| 560 |
|
|
|
|
| 18 |
# Detect if running on HF Spaces (ZeroGPU requires special handling)
|
| 19 |
_ON_SPACES = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def _patch_fastapi_starlette_middleware_unpack() -> None:
|
| 22 |
"""
|
| 23 |
Work around FastAPI/Starlette version mismatches where Starlette's Middleware
|
|
|
|
| 319 |
def add_prototype(
|
| 320 |
label_name: str,
|
| 321 |
images: List,
|
|
|
|
| 322 |
) -> str:
|
| 323 |
+
"""Add a temporary prototype (in-memory only, not persisted to disk)."""
|
| 324 |
if APP_STATE.lm is None or APP_STATE.db is None:
|
| 325 |
return "❌ Click **Load** first."
|
| 326 |
lm = APP_STATE.lm
|
|
|
|
| 361 |
center = torch.stack(zs, dim=0).mean(dim=0)
|
| 362 |
lid = db.add_center(label_name, center)
|
| 363 |
|
| 364 |
+
return f"✅ Added temporary prototype for `{label_name}` (label_id={lid}). DB now N={db.centers.shape[0]}. ⚠️ This is session-only and will be lost when the Space restarts."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
|
| 366 |
|
| 367 |
def build_ui() -> gr.Blocks:
|
|
|
|
| 393 |
table = gr.Dataframe(headers=["label", "cosine_sim"], datatype=["str", "number"], interactive=False)
|
| 394 |
run_btn.click(classify, inputs=[whole, topk], outputs=[out_status, table, face_prev, eyes_prev])
|
| 395 |
|
| 396 |
+
with gr.Tab("Add prototype (temporary)"):
|
| 397 |
gr.Markdown(
|
| 398 |
+
"### ⚠️ Temporary Prototypes Only\n"
|
| 399 |
+
"Add a new prototype by averaging embeddings of uploaded whole images.\n"
|
| 400 |
+
"**These prototypes are session-only** — they will be lost when the Space restarts or goes idle.\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
"Multiple prototypes per label are allowed."
|
| 402 |
)
|
| 403 |
label = gr.Textbox(label="Label name (artist)", placeholder="e.g. new_artist")
|
| 404 |
imgs = gr.Gallery(label="Whole images (1+)", columns=4, rows=2, height=240, allow_preview=True)
|
| 405 |
uploader = gr.Files(label="Upload image files (whole)", file_types=["image"], file_count="multiple")
|
| 406 |
+
add_btn = gr.Button("Add temporary prototype", variant="primary")
|
|
|
|
| 407 |
add_status = gr.Markdown("")
|
| 408 |
|
| 409 |
def _files_to_gallery(files):
|
|
|
|
| 418 |
continue
|
| 419 |
return out
|
| 420 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
uploader.change(_files_to_gallery, inputs=[uploader], outputs=[imgs])
|
| 422 |
+
add_btn.click(add_prototype, inputs=[label, imgs], outputs=[add_status])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
|
| 424 |
return demo
|
| 425 |
|