Spaces:
Sleeping
Sleeping
iljung1106
commited on
Commit
Β·
059ebcb
1
Parent(s):
b04d768
add prototype editing feature
Browse files- app/model_io.py +21 -4
- app/proto_db.py +29 -0
- webui_gradio.py +121 -9
app/model_io.py
CHANGED
|
@@ -1,11 +1,15 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
| 3 |
from dataclasses import dataclass
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import Optional, Tuple
|
| 6 |
|
| 7 |
import torch
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
@dataclass(frozen=True)
|
| 11 |
class LoadedModel:
|
|
@@ -21,6 +25,9 @@ class LoadedModel:
|
|
| 21 |
def _pick_device(device: str) -> torch.device:
|
| 22 |
if device.strip().lower() == "cpu":
|
| 23 |
return torch.device("cpu")
|
|
|
|
|
|
|
|
|
|
| 24 |
if torch.cuda.is_available():
|
| 25 |
return torch.device("cuda")
|
| 26 |
return torch.device("cpu")
|
|
@@ -41,9 +48,13 @@ def load_style_model(
|
|
| 41 |
if not ckpt_path.exists():
|
| 42 |
raise FileNotFoundError(str(ckpt_path))
|
| 43 |
|
| 44 |
-
|
| 45 |
-
if
|
|
|
|
|
|
|
| 46 |
dev = _pick_device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
| 47 |
|
| 48 |
ck = torch.load(str(ckpt_path), map_location="cpu")
|
| 49 |
meta = ck.get("meta", {}) if isinstance(ck, dict) else {}
|
|
@@ -103,8 +114,14 @@ def embed_triview(
|
|
| 103 |
# Use lazy dtype detection to avoid CUDA init at import time (ZeroGPU compatibility)
|
| 104 |
import train_style_ddp as _ts
|
| 105 |
_dtype = _ts._get_amp_dtype() if hasattr(_ts, "_get_amp_dtype") else torch.float16
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
z = torch.nn.functional.normalize(z.float(), dim=1)
|
| 109 |
return z.squeeze(0).detach().cpu()
|
| 110 |
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
import os
|
| 4 |
from dataclasses import dataclass
|
| 5 |
from pathlib import Path
|
| 6 |
from typing import Optional, Tuple
|
| 7 |
|
| 8 |
import torch
|
| 9 |
|
| 10 |
+
# ZeroGPU on HF Spaces: CUDA must not be initialized in the main process
|
| 11 |
+
_ON_SPACES = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))
|
| 12 |
+
|
| 13 |
|
| 14 |
@dataclass(frozen=True)
|
| 15 |
class LoadedModel:
|
|
|
|
| 25 |
def _pick_device(device: str) -> torch.device:
|
| 26 |
if device.strip().lower() == "cpu":
|
| 27 |
return torch.device("cpu")
|
| 28 |
+
if _ON_SPACES:
|
| 29 |
+
# ZeroGPU: can't init CUDA in main process
|
| 30 |
+
return torch.device("cpu")
|
| 31 |
if torch.cuda.is_available():
|
| 32 |
return torch.device("cuda")
|
| 33 |
return torch.device("cpu")
|
|
|
|
| 48 |
if not ckpt_path.exists():
|
| 49 |
raise FileNotFoundError(str(ckpt_path))
|
| 50 |
|
| 51 |
+
# On Spaces, always use CPU (ZeroGPU forbids CUDA in main process)
|
| 52 |
+
if _ON_SPACES:
|
| 53 |
+
dev = torch.device("cpu")
|
| 54 |
+
elif device == "auto":
|
| 55 |
dev = _pick_device("cuda" if torch.cuda.is_available() else "cpu")
|
| 56 |
+
else:
|
| 57 |
+
dev = _pick_device(device)
|
| 58 |
|
| 59 |
ck = torch.load(str(ckpt_path), map_location="cpu")
|
| 60 |
meta = ck.get("meta", {}) if isinstance(ck, dict) else {}
|
|
|
|
| 114 |
# Use lazy dtype detection to avoid CUDA init at import time (ZeroGPU compatibility)
|
| 115 |
import train_style_ddp as _ts
|
| 116 |
_dtype = _ts._get_amp_dtype() if hasattr(_ts, "_get_amp_dtype") else torch.float16
|
| 117 |
+
# On CPU or Spaces, skip autocast entirely to avoid touching CUDA
|
| 118 |
+
use_amp = (lm.device.type == "cuda") and not _ON_SPACES
|
| 119 |
+
if use_amp:
|
| 120 |
+
with torch.no_grad(), torch.amp.autocast("cuda", dtype=_dtype, enabled=True):
|
| 121 |
+
z, _, _ = lm.model(views, masks)
|
| 122 |
+
else:
|
| 123 |
+
with torch.no_grad():
|
| 124 |
+
z, _, _ = lm.model(views, masks)
|
| 125 |
z = torch.nn.functional.normalize(z.float(), dim=1)
|
| 126 |
return z.squeeze(0).detach().cpu()
|
| 127 |
|
app/proto_db.py
CHANGED
|
@@ -43,6 +43,35 @@ class PrototypeDB:
|
|
| 43 |
self.labels = torch.cat([self.labels, torch.tensor([lid], dtype=torch.long)], dim=0)
|
| 44 |
return lid
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
def save(self, path: Optional[str | Path] = None) -> Path:
|
| 47 |
out = Path(path) if path is not None else self.source_path
|
| 48 |
if out is None:
|
|
|
|
| 43 |
self.labels = torch.cat([self.labels, torch.tensor([lid], dtype=torch.long)], dim=0)
|
| 44 |
return lid
|
| 45 |
|
| 46 |
+
def remove_prototype(self, proto_idx: int) -> bool:
|
| 47 |
+
"""Remove a prototype by its index. Returns True if removed."""
|
| 48 |
+
if proto_idx < 0 or proto_idx >= self.centers.shape[0]:
|
| 49 |
+
return False
|
| 50 |
+
mask = torch.ones(self.centers.shape[0], dtype=torch.bool)
|
| 51 |
+
mask[proto_idx] = False
|
| 52 |
+
self.centers = self.centers[mask]
|
| 53 |
+
self.labels = self.labels[mask]
|
| 54 |
+
return True
|
| 55 |
+
|
| 56 |
+
def list_prototypes(self) -> List[Tuple[int, str, int]]:
|
| 57 |
+
"""Returns list of (proto_idx, label_name, label_id) for all prototypes."""
|
| 58 |
+
result = []
|
| 59 |
+
for i in range(self.centers.shape[0]):
|
| 60 |
+
lid = int(self.labels[i].item())
|
| 61 |
+
name = self.id_to_name(lid)
|
| 62 |
+
result.append((i, name, lid))
|
| 63 |
+
return result
|
| 64 |
+
|
| 65 |
+
def get_label_summary(self) -> Dict[str, int]:
|
| 66 |
+
"""Returns {label_name: count} for prototype counts per label."""
|
| 67 |
+
from collections import Counter
|
| 68 |
+
counts: Dict[str, int] = {}
|
| 69 |
+
for i in range(self.labels.shape[0]):
|
| 70 |
+
lid = int(self.labels[i].item())
|
| 71 |
+
name = self.id_to_name(lid)
|
| 72 |
+
counts[name] = counts.get(name, 0) + 1
|
| 73 |
+
return counts
|
| 74 |
+
|
| 75 |
def save(self, path: Optional[str | Path] = None) -> Path:
|
| 76 |
out = Path(path) if path is not None else self.source_path
|
| 77 |
if out is None:
|
webui_gradio.py
CHANGED
|
@@ -18,6 +18,20 @@ except Exception: # noqa: BLE001
|
|
| 18 |
# Detect if running on HF Spaces (ZeroGPU requires special handling)
|
| 19 |
_ON_SPACES = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def _patch_fastapi_starlette_middleware_unpack() -> None:
|
| 22 |
"""
|
| 23 |
Work around FastAPI/Starlette version mismatches where Starlette's Middleware
|
|
@@ -383,6 +397,37 @@ def save_db_as(path_text: str) -> str:
|
|
| 383 |
return f"β
Saved prototype DB to `{out_path}`"
|
| 384 |
|
| 385 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
def build_ui() -> gr.Blocks:
|
| 387 |
ckpts = _list_ckpt_files(CKPT_DIR)
|
| 388 |
protos = _list_proto_files(CKPT_DIR)
|
|
@@ -412,16 +457,63 @@ def build_ui() -> gr.Blocks:
|
|
| 412 |
table = gr.Dataframe(headers=["label", "cosine_sim"], datatype=["str", "number"], interactive=False)
|
| 413 |
run_btn.click(classify, inputs=[whole, topk], outputs=[out_status, table, face_prev, eyes_prev])
|
| 414 |
|
| 415 |
-
with gr.Tab("
|
| 416 |
gr.Markdown(
|
| 417 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
"Multiple prototypes per label are allowed."
|
| 419 |
)
|
| 420 |
label = gr.Textbox(label="Label name (artist)", placeholder="e.g. new_artist")
|
| 421 |
imgs = gr.Gallery(label="Whole images (1+)", columns=4, rows=2, height=240, allow_preview=True)
|
| 422 |
uploader = gr.Files(label="Upload image files (whole)", file_types=["image"], file_count="multiple")
|
| 423 |
-
|
| 424 |
-
add_btn = gr.Button("Add
|
| 425 |
add_status = gr.Markdown("")
|
| 426 |
|
| 427 |
def _files_to_gallery(files):
|
|
@@ -436,13 +528,33 @@ def build_ui() -> gr.Blocks:
|
|
| 436 |
continue
|
| 437 |
return out
|
| 438 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
uploader.change(_files_to_gallery, inputs=[uploader], outputs=[imgs])
|
| 440 |
-
add_btn.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
|
| 447 |
return demo
|
| 448 |
|
|
|
|
| 18 |
# Detect if running on HF Spaces (ZeroGPU requires special handling)
|
| 19 |
_ON_SPACES = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))
|
| 20 |
|
| 21 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
+
# Authentication for prototype management
|
| 23 |
+
# Set ADMIN_PASSWORD env var, or it defaults to "admin" (change in production!)
|
| 24 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 25 |
+
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD", "admin")
|
| 26 |
+
APPROVED_USERS: dict[str, str] = {
|
| 27 |
+
"admin": ADMIN_PASSWORD,
|
| 28 |
+
}
|
| 29 |
+
# Add more users via APPROVED_USERS_CSV env var: "user1:pass1,user2:pass2"
|
| 30 |
+
for pair in os.getenv("APPROVED_USERS_CSV", "").split(","):
|
| 31 |
+
if ":" in pair:
|
| 32 |
+
u, p = pair.split(":", 1)
|
| 33 |
+
APPROVED_USERS[u.strip()] = p.strip()
|
| 34 |
+
|
| 35 |
def _patch_fastapi_starlette_middleware_unpack() -> None:
|
| 36 |
"""
|
| 37 |
Work around FastAPI/Starlette version mismatches where Starlette's Middleware
|
|
|
|
| 397 |
return f"β
Saved prototype DB to `{out_path}`"
|
| 398 |
|
| 399 |
|
| 400 |
+
def list_prototypes() -> Tuple[str, List[List]]:
|
| 401 |
+
"""Returns (status, table_rows) with all prototypes."""
|
| 402 |
+
if APP_STATE.db is None:
|
| 403 |
+
return "β Click **Load** first.", []
|
| 404 |
+
db = APP_STATE.db
|
| 405 |
+
protos = db.list_prototypes()
|
| 406 |
+
# Group by label for summary
|
| 407 |
+
summary = db.get_label_summary()
|
| 408 |
+
summary_text = ", ".join(f"{k}: {v}" for k, v in sorted(summary.items()))
|
| 409 |
+
rows = [[i, name, lid] for (i, name, lid) in protos]
|
| 410 |
+
return f"β
{len(protos)} prototypes. Summary: {summary_text}", rows
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def remove_prototype(proto_idx: int, save_back: bool) -> Tuple[str, List[List]]:
|
| 414 |
+
"""Remove a prototype by index. Returns (status, updated_table)."""
|
| 415 |
+
if APP_STATE.db is None:
|
| 416 |
+
return "β Click **Load** first.", []
|
| 417 |
+
db = APP_STATE.db
|
| 418 |
+
idx = int(proto_idx)
|
| 419 |
+
if not db.remove_prototype(idx):
|
| 420 |
+
return f"β Invalid index: {idx}", [[i, name, lid] for (i, name, lid) in db.list_prototypes()]
|
| 421 |
+
|
| 422 |
+
msg = f"β
Removed prototype at index {idx}. DB now N={db.centers.shape[0]}."
|
| 423 |
+
if save_back and APP_STATE.proto_path:
|
| 424 |
+
db.save(APP_STATE.proto_path)
|
| 425 |
+
msg += f" Saved to `{APP_STATE.proto_path}`."
|
| 426 |
+
|
| 427 |
+
rows = [[i, name, lid] for (i, name, lid) in db.list_prototypes()]
|
| 428 |
+
return msg, rows
|
| 429 |
+
|
| 430 |
+
|
| 431 |
def build_ui() -> gr.Blocks:
|
| 432 |
ckpts = _list_ckpt_files(CKPT_DIR)
|
| 433 |
protos = _list_proto_files(CKPT_DIR)
|
|
|
|
| 457 |
table = gr.Dataframe(headers=["label", "cosine_sim"], datatype=["str", "number"], interactive=False)
|
| 458 |
run_btn.click(classify, inputs=[whole, topk], outputs=[out_status, table, face_prev, eyes_prev])
|
| 459 |
|
| 460 |
+
with gr.Tab("Manage Prototypes"):
|
| 461 |
gr.Markdown(
|
| 462 |
+
"### Prototype Management\n"
|
| 463 |
+
"**Authentication required** to add or remove prototypes. "
|
| 464 |
+
"Enter your credentials below."
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
with gr.Row():
|
| 468 |
+
auth_user = gr.Textbox(label="Username", placeholder="admin")
|
| 469 |
+
auth_pass = gr.Textbox(label="Password", type="password", placeholder="password")
|
| 470 |
+
|
| 471 |
+
def _check_auth(user: str, pwd: str) -> bool:
|
| 472 |
+
u = (user or "").strip()
|
| 473 |
+
p = (pwd or "").strip()
|
| 474 |
+
return APPROVED_USERS.get(u) == p and p != ""
|
| 475 |
+
|
| 476 |
+
# βββββββββββββββ List Prototypes βββββββββββββββ
|
| 477 |
+
gr.Markdown("---\n#### List Prototypes")
|
| 478 |
+
list_btn = gr.Button("Refresh List", variant="secondary")
|
| 479 |
+
list_status = gr.Markdown("")
|
| 480 |
+
proto_table = gr.Dataframe(
|
| 481 |
+
headers=["index", "label", "label_id"],
|
| 482 |
+
datatype=["number", "str", "number"],
|
| 483 |
+
interactive=False,
|
| 484 |
+
)
|
| 485 |
+
list_btn.click(list_prototypes, inputs=[], outputs=[list_status, proto_table])
|
| 486 |
+
|
| 487 |
+
# βββββββββββββββ Remove Prototype βββββββββββββββ
|
| 488 |
+
gr.Markdown("---\n#### Remove Prototype")
|
| 489 |
+
with gr.Row():
|
| 490 |
+
remove_idx = gr.Number(label="Prototype index to remove", precision=0)
|
| 491 |
+
remove_save = gr.Checkbox(value=True, label="Save after removal")
|
| 492 |
+
remove_btn = gr.Button("Remove Prototype", variant="stop")
|
| 493 |
+
remove_status = gr.Markdown("")
|
| 494 |
+
|
| 495 |
+
def _remove_with_auth(user, pwd, idx, save_back):
|
| 496 |
+
if not _check_auth(user, pwd):
|
| 497 |
+
return "β Authentication failed. Check username/password.", []
|
| 498 |
+
return remove_prototype(idx, save_back)
|
| 499 |
+
|
| 500 |
+
remove_btn.click(
|
| 501 |
+
_remove_with_auth,
|
| 502 |
+
inputs=[auth_user, auth_pass, remove_idx, remove_save],
|
| 503 |
+
outputs=[remove_status, proto_table],
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
# βββββββββββββββ Add Prototype βββββββββββββββ
|
| 507 |
+
gr.Markdown("---\n#### Add New Prototype")
|
| 508 |
+
gr.Markdown(
|
| 509 |
+
"Add a new prototype by averaging embeddings of uploaded whole images. "
|
| 510 |
"Multiple prototypes per label are allowed."
|
| 511 |
)
|
| 512 |
label = gr.Textbox(label="Label name (artist)", placeholder="e.g. new_artist")
|
| 513 |
imgs = gr.Gallery(label="Whole images (1+)", columns=4, rows=2, height=240, allow_preview=True)
|
| 514 |
uploader = gr.Files(label="Upload image files (whole)", file_types=["image"], file_count="multiple")
|
| 515 |
+
add_save_back = gr.Checkbox(value=True, label="Save back to prototype DB after adding")
|
| 516 |
+
add_btn = gr.Button("Add Prototype", variant="primary")
|
| 517 |
add_status = gr.Markdown("")
|
| 518 |
|
| 519 |
def _files_to_gallery(files):
|
|
|
|
| 528 |
continue
|
| 529 |
return out
|
| 530 |
|
| 531 |
+
def _add_with_auth(user, pwd, lbl, images, save_back):
|
| 532 |
+
if not _check_auth(user, pwd):
|
| 533 |
+
return "β Authentication failed. Check username/password."
|
| 534 |
+
return add_prototype(lbl, images, save_back)
|
| 535 |
+
|
| 536 |
uploader.change(_files_to_gallery, inputs=[uploader], outputs=[imgs])
|
| 537 |
+
add_btn.click(
|
| 538 |
+
_add_with_auth,
|
| 539 |
+
inputs=[auth_user, auth_pass, label, imgs, add_save_back],
|
| 540 |
+
outputs=[add_status],
|
| 541 |
+
)
|
| 542 |
|
| 543 |
+
# βββββββββββββββ Save DB As βββββββββββββββ
|
| 544 |
+
gr.Markdown("---\n#### Save DB As (optional)")
|
| 545 |
+
with gr.Row():
|
| 546 |
+
save_path = gr.Textbox(
|
| 547 |
+
label="Output path (relative paths go under ./checkpoints_style/)",
|
| 548 |
+
placeholder="prototypes_custom.pt",
|
| 549 |
+
)
|
| 550 |
+
save_btn = gr.Button("Save As")
|
| 551 |
+
|
| 552 |
+
def _save_with_auth(user, pwd, path):
|
| 553 |
+
if not _check_auth(user, pwd):
|
| 554 |
+
return "β Authentication failed. Check username/password."
|
| 555 |
+
return save_db_as(path)
|
| 556 |
+
|
| 557 |
+
save_btn.click(_save_with_auth, inputs=[auth_user, auth_pass, save_path], outputs=[add_status])
|
| 558 |
|
| 559 |
return demo
|
| 560 |
|