iljung1106
Fix contribution bar rendering
0ca0b8b
from __future__ import annotations
import json
import os
import re
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
import web_ui
from artist_style_dinov3.explain import _encode_query
try:
    import spaces
except Exception:
    # The optional ``spaces`` package (presumably the Hugging Face Spaces
    # helper providing ``@spaces.GPU``) is unavailable, so substitute a
    # stand-in whose decorator leaves functions untouched.
    class spaces:  # noqa: N801 - deliberately shadows the module name
        """Minimal replacement exposing a no-op ``GPU`` decorator factory."""

        @staticmethod
        def GPU(*_args, **_kwargs):
            """Accept any arguments and return an identity decorator."""
            return lambda fn: fn
ROOT = Path(__file__).resolve().parent


def _resolve_store_root() -> Path:
    """Choose the directory that holds user-added artists.

    Precedence: a non-empty ``CUSTOM_STORE_DIR`` environment variable, then
    ``/data/custom_store`` when a ``/data`` path exists, otherwise
    ``custom_store`` next to this file.
    """
    override = os.environ.get("CUSTOM_STORE_DIR")
    if override:
        return Path(override)
    if Path("/data").exists():
        return Path("/data/custom_store")
    return ROOT / "custom_store"


STORE_ROOT = _resolve_store_root()
INDEX_PATH = STORE_ROOT / "index.json"
ARTISTS_DIR = STORE_ROOT / "artists"
# Last storage failure message; set by _load_index, cleared by _ensure_store.
STORE_ERROR: Optional[str] = None
def _env_path(name: str, default: Path) -> str:
    """Return the path named by environment variable ``name``, or ``default``.

    A variable that is unset or contains only whitespace falls back to
    ``default``; otherwise its (unstripped) value is ``~``-expanded.
    """
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return str(default)
    return str(Path(raw).expanduser())
def _ensure_store() -> None:
    """Create the store layout if needed and clear the recorded storage error.

    Makes ``ARTISTS_DIR`` (including parents) and, when absent, writes an
    empty index file. Filesystem errors propagate to the caller.
    """
    global STORE_ERROR
    ARTISTS_DIR.mkdir(parents=True, exist_ok=True)
    index_missing = not INDEX_PATH.exists()
    if index_missing:
        empty_index = {"schema_version": 1, "artists": []}
        INDEX_PATH.write_text(json.dumps(empty_index, indent=2), encoding="utf-8")
    STORE_ERROR = None
def _load_index() -> dict:
    """Read the added-artist index, falling back to an empty index.

    Returns:
        A dict that always has ``schema_version`` and an ``artists`` list.

    If the store cannot be created, ``STORE_ERROR`` is set and an empty index
    is returned. An unreadable or invalid JSON file, a top-level JSON value
    that is not an object, or an ``artists`` value that is not a list are all
    treated as empty. This keeps ``setdefault`` here and ``append`` in callers
    from crashing.
    """
    global STORE_ERROR
    try:
        _ensure_store()
    except Exception as exc:
        STORE_ERROR = f"Added-artist storage is not writable: {exc}"
        return {"schema_version": 1, "artists": []}
    try:
        payload = json.loads(INDEX_PATH.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # Missing/unreadable file or invalid JSON (JSONDecodeError and
        # UnicodeDecodeError are both ValueError subclasses).
        payload = None
    if not isinstance(payload, dict):
        payload = {}
    payload.setdefault("schema_version", 1)
    if not isinstance(payload.get("artists"), list):
        payload["artists"] = []
    return payload
def _save_index(payload: dict) -> None:
    """Persist ``payload`` as the index file.

    The JSON is written to a sibling ``.tmp`` file first and then moved over
    ``INDEX_PATH`` with ``Path.replace``, so a crash mid-write does not leave a
    truncated index.

    Raises:
        gr.Error: If the store directory/index cannot be ensured.
    """
    try:
        _ensure_store()
    except Exception as exc:
        raise gr.Error(f"Added-artist storage is not writable: {exc}") from exc
    staging_path = INDEX_PATH.with_suffix(".tmp")
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    staging_path.write_text(serialized, encoding="utf-8")
    staging_path.replace(INDEX_PATH)
_SLUG_UNSAFE_RUN = re.compile(r"[^a-zA-Z0-9._-]+")


def _slugify(name: str) -> str:
    """Turn ``name`` into a lowercase, filesystem-friendly slug.

    Runs of characters outside ``[a-zA-Z0-9._-]`` collapse to a single ``-``;
    leading/trailing ``-``, ``.`` and ``_`` are trimmed. Names with no usable
    characters (e.g. entirely non-ASCII) become ``"artist"``.
    """
    collapsed = _SLUG_UNSAFE_RUN.sub("-", name.strip())
    slug = collapsed.strip("-._").lower()
    return slug if slug else "artist"
def _now() -> str:
    """Return the current UTC time as an ISO-8601 string with offset."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
# Descriptor fields averaged into a prototype, in output order.
_DESCRIPTOR_KEYS = (
    "embedding",
    "branch_embeddings",
    "branch_projected_embeddings",
    "weighted_branch_projected_embeddings",
    "stacked_view_embeddings",
    "stacked_view_weights",
    "branch_weights",
    "view_mask",
    "effective_view_mask",
    "effective_branch_mask",
)
# Averages re-projected to unit L2 norm along the last dim.
_RENORMALIZED_KEYS = frozenset({"embedding", "branch_embeddings", "branch_projected_embeddings", "stacked_view_embeddings"})
# Averages rescaled to sum to 1 along the last dim.
_WEIGHT_KEYS = frozenset({"stacked_view_weights", "branch_weights"})
# Averages turned back into 0/1 masks (1 where any member was non-zero).
_MASK_KEYS = frozenset({"view_mask", "effective_view_mask", "effective_branch_mask"})


def _mean_descriptor(descriptors: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Average several per-image descriptors into one prototype descriptor.

    Each known field present in at least one descriptor is stacked (over the
    descriptors that have it) and averaged on CPU in float32, then
    post-processed by kind: embeddings are L2-normalized, weights are
    renormalized to sum to 1, masks become ``(mean > 0)`` floats. When
    ``weighted_branch_projected_embeddings`` is present, ``embedding`` is
    replaced by its flattened, L2-normalized average.

    Raises:
        ValueError: If ``descriptors`` is empty.
    """
    if not descriptors:
        raise ValueError("no descriptors")
    averaged: Dict[str, torch.Tensor] = {}
    for field in _DESCRIPTOR_KEYS:
        present = [entry[field].detach().float().cpu() for entry in descriptors if field in entry]
        if not present:
            continue
        avg = torch.stack(present, dim=0).mean(dim=0)
        if field in _RENORMALIZED_KEYS:
            avg = F.normalize(avg.float(), dim=-1)
        elif field in _WEIGHT_KEYS:
            avg = avg / avg.sum(dim=-1, keepdim=True).clamp_min(1e-6)
        elif field in _MASK_KEYS:
            avg = (avg > 0).float()
        averaged[field] = avg
    weighted = averaged.get("weighted_branch_projected_embeddings")
    if weighted is not None:
        averaged["embedding"] = F.normalize(weighted.reshape(-1).float(), dim=0)
    return averaged
def _spherical_kmeans(vectors: torch.Tensor, k: int, iterations: int = 25) -> Tuple[torch.Tensor, torch.Tensor]:
    """Cluster row vectors by cosine similarity (spherical k-means).

    Rows are L2-normalized and ``k`` is clamped to ``[1, n]``. Initial centers
    are rows at evenly spaced indices, so the result is deterministic. Each
    iteration replaces every center by the normalized mean of its members
    (an empty cluster keeps its previous center) and then reassigns rows to
    the most similar center. The loop stops early once assignments stop
    changing.

    Args:
        vectors: ``(n, d)`` tensor of embeddings with ``n >= 1``.
        k: Requested number of clusters.
        iterations: Maximum number of center updates.

    Returns:
        ``(centers, assignments)``: unit-norm ``(k', d)`` centers and an
        ``(n,)`` long tensor of cluster indices computed against exactly the
        returned centers.

    Raises:
        ValueError: If ``vectors`` has no rows.
    """
    vectors = F.normalize(vectors.float(), dim=-1)
    count = vectors.size(0)
    if count == 0:
        raise ValueError("cannot cluster an empty set of vectors")
    k = max(1, min(int(k), count))
    seeds = torch.linspace(0, count - 1, steps=k).round().long()
    centers = vectors[seeds].clone()
    # Assign against the current centers. Every return path keeps
    # ``assignments`` consistent with the ``centers`` returned alongside it.
    assignments = (vectors @ centers.t()).argmax(dim=1)
    for _ in range(iterations):
        updated = []
        for idx in range(k):
            members = vectors[assignments == idx]
            if members.numel() == 0:
                updated.append(centers[idx])
            else:
                updated.append(F.normalize(members.mean(dim=0), dim=0))
        centers = torch.stack(updated, dim=0)
        reassigned = (vectors @ centers.t()).argmax(dim=1)
        if torch.equal(reassigned, assignments):
            break  # Converged: another update would reproduce the same centers.
        assignments = reassigned
    return centers, assignments
def _custom_display_key(display_name: str, artist_id: str, existing: set[str]) -> str:
    """Return a label for an added artist that does not collide with ``existing``.

    The plain display name is used when free. Otherwise the established
    ``"<name> (<first 8 chars of id>)"`` form is tried.

    Artist ids are ``<slug>-<hex>``, so for slugs of 8+ characters the first
    8 characters contain no random part. Two added artists with the same long
    name would then receive the same label, and the later one would
    overwrite the earlier one in the prototype bank. In that case the full
    artist id is used, with a numeric suffix as a last resort.
    """
    if display_name not in existing:
        return display_name
    candidate = f"{display_name} ({artist_id[:8]})"
    if candidate not in existing:
        return candidate
    candidate = f"{display_name} ({artist_id})"
    counter = 2
    while candidate in existing:
        candidate = f"{display_name} ({artist_id}) #{counter}"
        counter += 1
    return candidate
def load_custom_bank(existing_names: Optional[set[str]] = None) -> Tuple[List[str], Dict[str, torch.Tensor], Dict[str, list], Dict[str, str]]:
    """Load every enabled added artist from the on-disk store.

    Index entries that are disabled, or whose ``.pt`` file is missing, are
    skipped. Each loaded artist gets a display key that avoids
    ``existing_names`` and keys already produced in this call.

    Returns:
        ``(names, bank, descriptors, key_to_artist_id)``: display keys in
        index order, key -> saved ``prototype_bank``, key -> saved
        ``prototype_descriptors``, and key -> artist id.
    """
    # Names that a new display key must not reuse; grows as artists load.
    taken = set(existing_names or ())
    names: List[str] = []
    bank: Dict[str, torch.Tensor] = {}
    descriptors: Dict[str, list] = {}
    key_to_artist_id: Dict[str, str] = {}
    active_entries = (entry for entry in _load_index().get("artists", []) if entry.get("enabled", True))
    for entry in active_entries:
        checkpoint = STORE_ROOT / entry["path"]
        if not checkpoint.exists():
            continue
        # NOTE(review): torch.load unpickles; these files are written by
        # add_artist. Do not point the store at untrusted files.
        stored = torch.load(checkpoint, map_location="cpu")
        key = _custom_display_key(entry["display_name"], entry["artist_id"], taken)
        names.append(key)
        taken.add(key)
        bank[key] = stored["prototype_bank"]
        descriptors[key] = stored["prototype_descriptors"]
        key_to_artist_id[key] = entry["artist_id"]
    return names, bank, descriptors, key_to_artist_id
def load_combined_runtime(device_name: str) -> dict:
    """Load the base model runtime and merge user-added artists into it.

    Model, prototype, and DINOv3 paths come from environment variables with
    defaults under this file's directory. Added artists are re-read via
    ``load_custom_bank`` on every call.

    Returns:
        The base runtime dict with ``artists`` (base names, then added keys),
        a re-stacked, L2-normalized ``prototype_tensor`` on
        ``runtime["device"]``, merged ``prototype_descriptors``, and
        ``custom_key_to_artist_id``.
    """
    trained_dir = ROOT / "artifacts" / "style_training_dinov3"
    dinov3_weights = ROOT / "artifacts" / "pretrained" / "dinov3" / "dinov3_vits16_pretrain_lvd1689m-08c60483.pth"
    runtime = web_ui.load_runtime(
        _env_path("ARTIST_STYLE_CHECKPOINT", trained_dir / "best.pt"),
        _env_path("ARTIST_PROTOTYPES", trained_dir / "artist_prototypes.pt"),
        _env_path("DINOV3_ROOT", ROOT / "third_party" / "dinov3"),
        _env_path("DINOV3_WEIGHTS", dinov3_weights),
        device_name,
    )
    base_names = list(runtime["artists"])
    base_prototypes = runtime["prototype_tensor"]
    merged_bank = {name: base_prototypes[i].detach().cpu() for i, name in enumerate(base_names)}
    merged_descriptors = dict(runtime["prototype_descriptors"])

    custom_names, custom_bank, custom_descriptors, key_to_artist_id = load_custom_bank(set(base_names))
    merged_bank.update(custom_bank)
    merged_descriptors.update(custom_descriptors)

    # NOTE(review): torch.stack requires every artist's prototype block to
    # share one shape; added artists save up to 4 prototypes. Confirm the
    # base bank uses the same count.
    ordered_names = base_names + custom_names
    stacked = torch.stack([merged_bank[name] for name in ordered_names], dim=0).float()
    return {
        **runtime,
        "artists": ordered_names,
        "prototype_tensor": F.normalize(stacked, dim=-1).to(runtime["device"]),
        "prototype_descriptors": merged_descriptors,
        "custom_key_to_artist_id": key_to_artist_id,
    }
def _auto_extract_views(image: Image.Image, device_name: str) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[Tuple[int, int, int, int]], Optional[Tuple[int, int, int, int]], str]:
    """Detect face and eye crops in ``image`` with the web_ui extractor.

    Returns:
        ``(face, eye, face_box, eye_box, status)`` from the extractor. If
        extraction raises, returns four ``None`` values and a status message
        saying only the whole image will be used. Errors raised while loading
        the extractor itself are not caught.
    """
    yolo_root = ROOT / "yolov5_anime"
    extractor = web_ui.load_extractor(
        _env_path("YOLO_ANIME_ROOT", yolo_root),
        _env_path("YOLO_WEIGHTS", yolo_root / "weights" / "yolov5s_anime.pt"),
        _env_path("EYE_CASCADE", ROOT / "anime-eyes-cascade.xml"),
        device_name,
        # NOTE(review): positional detector settings; their meaning is
        # defined by web_ui.load_extractor.
        0.5,
        0.5,
        640,
        9,
        0.6,
    )
    try:
        views = extractor.extract(image)
    except Exception as exc:
        fallback_status = f"Auto crop failed ({exc}); using whole image only."
        return None, None, None, None, fallback_status
    else:
        return views.face, views.eye, views.face_box, views.eye_box, views.status
# NOTE(review): presumably requests a GPU slot (up to 90 s) when the real
# `spaces` package is installed; with the local fallback it is a no-op.
@spaces.GPU(duration=90)
def analyze(image: Image.Image, top_k: int, use_tta: bool, device_name: str):
    """Rank artists for one image and explain the best match.

    Merges base and added artists into one runtime and auto-crops face/eye
    views. It then encodes the query, ranks prototypes with
    ``web_ui.rank_artists``, and builds a heatmap explanation against the
    winning prototype's descriptor.

    Args:
        image: Query image from the Search tab.
        top_k: Number of ranked rows requested (cast to ``int``).
        use_tta: Forwarded to both the encoder and the explainer.
        device_name: Device choice from the UI (``"auto"``, ``"cpu"``, ``"cuda"``).

    Returns:
        ``(face, eye, overlay, top_match_html, rows, branch_html, view_html,
        summary)``, in the order of the Search tab's ``outputs`` list.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Image is required.")
    # Added artists are re-read from storage on every call.
    runtime = load_combined_runtime(device_name)
    face, eye, face_box, eye_box, status = _auto_extract_views(image, device_name)
    full, face_tensor, eye_tensor, view_mask = web_ui.prepare_inputs(image, face, eye, runtime["checkpoint_args"], runtime["device"])
    with torch.no_grad():
        query_outputs = _encode_query(runtime["model"], full, face_tensor, eye_tensor, view_mask=view_mask, use_tta=use_tta)
    query = F.normalize(query_outputs["embedding"][0], dim=0)
    rows, best_artist, best_proto_idx, best_score = web_ui.rank_artists(query, runtime["artists"], runtime["prototype_tensor"], int(top_k))
    # Descriptor of the best-scoring prototype; the explanation is computed against it.
    descriptor = runtime["prototype_descriptors"][best_artist][best_proto_idx]
    # Normalized crop boxes, only for views the detector actually found.
    view_attention_boxes = {}
    face_attention_box = web_ui.normalized_box(face_box, image.size)
    eye_attention_box = web_ui.normalized_box(eye_box, image.size)
    if face_attention_box is not None:
        view_attention_boxes["face"] = face_attention_box
    if eye_attention_box is not None:
        view_attention_boxes["eye"] = eye_attention_box
    # NOTE(review): runs outside torch.no_grad(), presumably because the
    # explanation needs autograd; confirm.
    explanation = web_ui.explain_against_reference(
        runtime["model"],
        full,
        face_tensor,
        eye_tensor,
        view_mask,
        descriptor,
        view_attention_boxes=view_attention_boxes,
        use_tta=use_tta,
    )
    overlay = web_ui.make_overlay(image, explanation["combined_view_heatmaps"], face_box=face_box, eye_box=eye_box)
    branch_html, view_html, summary = web_ui.contribution_bars(explanation)
    # Prefix the crop-detection status so users can see whether face/eye were used.
    summary = f"{status} {summary}"
    return face, eye, overlay, web_ui.top_match_html(best_artist, best_score), rows, branch_html, view_html, summary
def _load_uploaded_image(file_obj) -> Image.Image:
    """Open an uploaded file and return it converted to RGB.

    ``file_obj`` may be an object exposing a ``name`` path (presumably the
    upload wrapper Gradio passes) or a plain path. The source file is closed
    before returning.
    """
    source_path = getattr(file_obj, "name", file_obj)
    with Image.open(source_path) as opened:
        rgb = opened.convert("RGB")
    return rgb
# NOTE(review): presumably requests a GPU slot (up to 180 s) when the real
# `spaces` package is installed; with the local fallback it is a no-op.
@spaces.GPU(duration=180)
def add_artist(display_name: str, files: List[object], use_tta: bool, device_name: str):
    """Encode uploaded reference images and register them as a new artist.

    Each image is auto-cropped and encoded the same way as a search query.
    The per-image embeddings are grouped with spherical k-means into at most
    4 clusters, and each cluster's descriptors are averaged into one
    prototype. The result is saved to ``artists/<artist_id>.pt`` under the
    store root, and an entry is appended to the index.

    Args:
        display_name: Name shown in results; surrounding whitespace is stripped.
        files: Uploaded files (objects with a ``name`` path, or plain paths); at least 5.
        use_tta: Forwarded to the encoder.
        device_name: Device choice from the UI (``"auto"``, ``"cpu"``, ``"cuda"``).

    Returns:
        A status message for the Add Artist tab.

    Raises:
        gr.Error: On a missing name, fewer than 5 files, or a storage failure.
    """
    display_name = (display_name or "").strip()
    if not display_name:
        raise gr.Error("Artist name is required.")
    if not files or len(files) < 5:
        raise gr.Error("Upload at least 5 images.")
    runtime = load_combined_runtime(device_name)
    encoded_descriptors: List[Dict[str, torch.Tensor]] = []
    embeddings = []
    statuses = []
    for file_obj in files:
        image = _load_uploaded_image(file_obj)
        face, eye, _face_box, _eye_box, status = _auto_extract_views(image, device_name)
        statuses.append(status)
        full, face_tensor, eye_tensor, view_mask = web_ui.prepare_inputs(image, face, eye, runtime["checkpoint_args"], runtime["device"])
        with torch.no_grad():
            outputs = _encode_query(runtime["model"], full, face_tensor, eye_tensor, view_mask=view_mask, use_tta=use_tta)
        # Keep element [0] of each tensor output (presumably the single batch
        # item), detached on CPU; the same field list is averaged by _mean_descriptor.
        descriptor = {}
        for key in [
            "embedding",
            "branch_embeddings",
            "branch_projected_embeddings",
            "weighted_branch_projected_embeddings",
            "stacked_view_embeddings",
            "stacked_view_weights",
            "branch_weights",
            "view_mask",
            "effective_view_mask",
            "effective_branch_mask",
        ]:
            value = outputs[key]
            descriptor[key] = value[0].detach().cpu() if torch.is_tensor(value) else value
        encoded_descriptors.append(descriptor)
        embeddings.append(descriptor["embedding"].float())
    stack = torch.stack(embeddings, dim=0)
    # Up to 4 prototypes; given the 5-image minimum above this is always 4.
    prototype_count = min(4, len(encoded_descriptors))
    centers, assignments = _spherical_kmeans(stack, prototype_count)
    prototype_descriptors = []
    prototype_vectors = []
    for idx in range(prototype_count):
        members = [desc for desc, cluster in zip(encoded_descriptors, assignments.tolist()) if cluster == idx]
        if not members:
            # Empty cluster: use the single image whose embedding scores highest against this center.
            members = [encoded_descriptors[int((stack @ centers[idx]).argmax().item())]]
        descriptor = _mean_descriptor(members)
        prototype_descriptors.append(descriptor)
        prototype_vectors.append(descriptor["embedding"])
    # Random hex suffix keeps ids unique even for identical display names.
    artist_id = f"{_slugify(display_name)}-{uuid.uuid4().hex[:10]}"
    relative_path = f"artists/{artist_id}.pt"
    payload = {
        "schema_version": 1,
        "artist_id": artist_id,
        "display_name": display_name,
        "created_at": _now(),
        "num_images": len(encoded_descriptors),
        "prototype_bank": torch.stack(prototype_vectors, dim=0),
        "prototype_descriptors": prototype_descriptors,
        "crop_statuses": statuses,
    }
    try:
        _ensure_store()
        torch.save(payload, STORE_ROOT / relative_path)
    except Exception as exc:
        raise gr.Error(f"Failed to save added artist. Check persistent storage: {exc}") from exc
    # The prototype file is written before the index entry, so a failed index
    # save leaves an unreferenced .pt file rather than a dangling entry.
    index = _load_index()
    index["artists"].append(
        {
            "artist_id": artist_id,
            "display_name": display_name,
            "path": relative_path,
            "created_at": payload["created_at"],
            "num_images": len(encoded_descriptors),
            "enabled": True,
        }
    )
    _save_index(index)
    # NOTE(review): clears web_ui's cached base runtime. Added artists are
    # re-read by load_custom_bank on every request anyway; confirm this is still needed.
    web_ui.load_runtime.cache_clear()
    return f"Added {display_name}. Refresh the Manage tab to see it."
def added_artist_choices():
    """Build a dropdown update listing the enabled added artists.

    Labels are ``"<display name> (<first 8 chars of id>)"`` and values are
    full artist ids. The first artist (if any) is preselected.
    """
    active = [entry for entry in _load_index().get("artists", []) if entry.get("enabled", True)]
    options = [(f"{entry['display_name']} ({entry['artist_id'][:8]})", entry["artist_id"]) for entry in active]
    default_value = options[0][1] if options else None
    return gr.update(choices=options, value=default_value)
def refresh_added_artists():
    """Reload the added-artist dropdown and report the storage status.

    The choices are built first. ``added_artist_choices`` reads the index,
    which records or clears ``STORE_ERROR``, so reading the status afterwards
    reports the outcome of this refresh rather than the previous one (for
    example, a storage failure on first page load).

    Returns:
        ``(dropdown_update, status_message)``.
    """
    choices = added_artist_choices()
    status = STORE_ERROR or "List refreshed."
    return choices, status
def delete_artist(artist_id: str):
    """Soft-delete an added artist by id.

    The first matching enabled index entry is marked ``enabled=False`` and
    stamped with ``deleted_at``, and the index is saved. The artist's ``.pt``
    file is left on disk.

    Returns:
        ``(dropdown_update, status_message)``.

    Raises:
        gr.Error: If no id is given, or no enabled entry matches it.
    """
    if not artist_id:
        raise gr.Error("Select an added artist.")
    index = _load_index()
    for entry in index.get("artists", []):
        if entry["artist_id"] != artist_id or not entry.get("enabled", True):
            continue
        entry["enabled"] = False
        entry["deleted_at"] = _now()
        deleted_name = entry["display_name"]
        break
    else:
        deleted_name = None
    if deleted_name is None:
        raise gr.Error("Artist not found or already deleted.")
    _save_index(index)
    return added_artist_choices(), f"Deleted {deleted_name}."
def build_app() -> gr.Blocks:
    """Assemble the Gradio UI: Search, Add Artist and Manage Added Artists tabs.

    Returns:
        The ``gr.Blocks`` app with its request queue enabled via ``demo.queue()``.
    """
    with gr.Blocks(title="Anime Artist Style Embedder", css=web_ui.APP_CSS) as demo:
        gr.Markdown("# Anime Artist Style Embedder")
        # Search tab: query image and settings on the left, results on the right.
        with gr.Tab("Search"):
            with gr.Row():
                with gr.Column(scale=4):
                    image = gr.Image(label="Image", type="pil", image_mode="RGB", height=360)
                    with gr.Row():
                        top_k = gr.Slider(1, 25, value=10, step=1, label="Top K")
                        use_tta = gr.Checkbox(value=False, label="TTA")
                        device_name = gr.Dropdown(["auto", "cpu", "cuda"], value="auto", label="Device")
                    search_button = gr.Button("Analyze", variant="primary")
                    with gr.Row():
                        face = gr.Image(label="Detected Face", type="pil", height=160)
                        eye = gr.Image(label="Detected Eye", type="pil", height=160)
                with gr.Column(scale=7):
                    overlay = gr.Image(label="Heatmap", type="pil", height=520)
                    top_match = gr.HTML()
                    results = gr.Dataframe(headers=["rank", "artist", "score"], datatype=["number", "str", "str"], label="Retrieval")
                    summary = gr.Textbox(label="Status", lines=2)
                    # Contribution bars rendered as HTML by web_ui.contribution_bars.
                    with gr.Row():
                        branch_bars = gr.HTML()
                        view_bars = gr.HTML()
            # The outputs order must match the tuple returned by analyze().
            # api_name=False keeps these handlers out of the exposed API.
            search_button.click(
                analyze,
                inputs=[image, top_k, use_tta, device_name],
                outputs=[face, eye, overlay, top_match, results, branch_bars, view_bars, summary],
                api_name=False,
            )
        with gr.Tab("Add Artist"):
            artist_name = gr.Textbox(label="Artist name")
            upload_files = gr.File(label="Images (minimum 5)", file_count="multiple", file_types=["image"])
            add_tta = gr.Checkbox(value=False, label="TTA")
            add_device = gr.Dropdown(["auto", "cpu", "cuda"], value="auto", label="Device")
            add_button = gr.Button("Add Artist", variant="primary")
            add_status = gr.Textbox(label="Status", lines=2)
            add_button.click(add_artist, inputs=[artist_name, upload_files, add_tta, add_device], outputs=[add_status], api_name=False)
        with gr.Tab("Manage Added Artists"):
            # Starts empty; populated on page load and by the Refresh/Delete handlers.
            artist_select = gr.Dropdown(label="Added artists", choices=[])
            refresh_button = gr.Button("Refresh")
            delete_button = gr.Button("Delete selected", variant="stop")
            manage_status = gr.Textbox(label="Status", lines=2)
            refresh_button.click(refresh_added_artists, outputs=[artist_select, manage_status], api_name=False)
            delete_button.click(delete_artist, inputs=[artist_select], outputs=[artist_select, manage_status], api_name=False)
        # Fill the Manage dropdown (and storage status) each time the page loads.
        demo.load(refresh_added_artists, outputs=[artist_select, manage_status], api_name=False)
    return demo.queue()
if __name__ == "__main__":
    # Serve on all interfaces at port 7860 and show handler errors in the UI.
    application = build_app()
    application.launch(server_name="0.0.0.0", server_port=7860, show_error=True)