from __future__ import annotations import json import os import re import uuid from datetime import datetime, timezone from pathlib import Path from typing import Dict, List, Optional, Tuple import gradio as gr import torch import torch.nn.functional as F from PIL import Image import web_ui from artist_style_dinov3.explain import _encode_query try: import spaces except Exception: class spaces: @staticmethod def GPU(*_args, **_kwargs): def decorator(fn): return fn return decorator ROOT = Path(__file__).resolve().parent STORE_ROOT = Path(os.environ.get("CUSTOM_STORE_DIR") or ("/data/custom_store" if Path("/data").exists() else ROOT / "custom_store")) INDEX_PATH = STORE_ROOT / "index.json" ARTISTS_DIR = STORE_ROOT / "artists" STORE_ERROR: Optional[str] = None def _env_path(name: str, default: Path) -> str: return str(Path(os.environ[name]).expanduser()) if name in os.environ and os.environ[name].strip() else str(default) def _ensure_store() -> None: global STORE_ERROR ARTISTS_DIR.mkdir(parents=True, exist_ok=True) if not INDEX_PATH.exists(): INDEX_PATH.write_text(json.dumps({"schema_version": 1, "artists": []}, indent=2), encoding="utf-8") STORE_ERROR = None def _load_index() -> dict: try: _ensure_store() except Exception as exc: global STORE_ERROR STORE_ERROR = f"Added-artist storage is not writable: {exc}" return {"schema_version": 1, "artists": []} try: payload = json.loads(INDEX_PATH.read_text(encoding="utf-8")) except Exception: payload = {"schema_version": 1, "artists": []} payload.setdefault("schema_version", 1) payload.setdefault("artists", []) return payload def _save_index(payload: dict) -> None: try: _ensure_store() except Exception as exc: raise gr.Error(f"Added-artist storage is not writable: {exc}") from exc tmp = INDEX_PATH.with_suffix(".tmp") tmp.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") tmp.replace(INDEX_PATH) def _slugify(name: str) -> str: slug = re.sub(r"[^a-zA-Z0-9._-]+", "-", name.strip()).strip("-._").lower() return slug or "artist" def _now() -> str: return datetime.now(timezone.utc).isoformat() def _mean_descriptor(descriptors: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: if not descriptors: raise ValueError("no descriptors") result = {} for key in [ "embedding", "branch_embeddings", "branch_projected_embeddings", "weighted_branch_projected_embeddings", "stacked_view_embeddings", "stacked_view_weights", "branch_weights", "view_mask", "effective_view_mask", "effective_branch_mask", ]: values = [item[key].detach().float().cpu() for item in descriptors if key in item] if not values: continue stacked = torch.stack(values, dim=0) mean = stacked.mean(dim=0) if key in {"embedding", "branch_embeddings", "branch_projected_embeddings", "stacked_view_embeddings"}: mean = F.normalize(mean.float(), dim=-1) elif key in {"stacked_view_weights", "branch_weights"}: mean = mean / mean.sum(dim=-1, keepdim=True).clamp_min(1e-6) elif key in {"view_mask", "effective_view_mask", "effective_branch_mask"}: mean = (mean > 0).float() result[key] = mean if "weighted_branch_projected_embeddings" in result: result["embedding"] = F.normalize(result["weighted_branch_projected_embeddings"].reshape(-1).float(), dim=0) return result def _spherical_kmeans(vectors: torch.Tensor, k: int, iterations: int = 25) -> Tuple[torch.Tensor, torch.Tensor]: vectors = F.normalize(vectors.float(), dim=-1) k = max(1, min(int(k), vectors.size(0))) centers = vectors[torch.linspace(0, vectors.size(0) - 1, steps=k).round().long()].clone() assignments = torch.zeros(vectors.size(0), dtype=torch.long) for _ in range(iterations): assignments = (vectors @ centers.t()).argmax(dim=1) next_centers = [] for idx in range(k): members = vectors[assignments == idx] if members.numel() == 0: next_centers.append(centers[idx]) else: next_centers.append(F.normalize(members.mean(dim=0), dim=0)) centers = torch.stack(next_centers, dim=0) return centers, assignments def _custom_display_key(display_name: str, artist_id: str, existing: set[str]) -> str: if display_name not in existing: return display_name return f"{display_name} ({artist_id[:8]})" def load_custom_bank(existing_names: Optional[set[str]] = None) -> Tuple[List[str], Dict[str, torch.Tensor], Dict[str, list], Dict[str, str]]: existing_names = existing_names or set() index = _load_index() names: List[str] = [] bank: Dict[str, torch.Tensor] = {} descriptors: Dict[str, list] = {} key_to_artist_id: Dict[str, str] = {} for item in index.get("artists", []): if not item.get("enabled", True): continue path = STORE_ROOT / item["path"] if not path.exists(): continue payload = torch.load(path, map_location="cpu") key = _custom_display_key(item["display_name"], item["artist_id"], existing_names | set(names)) names.append(key) bank[key] = payload["prototype_bank"] descriptors[key] = payload["prototype_descriptors"] key_to_artist_id[key] = item["artist_id"] return names, bank, descriptors, key_to_artist_id def load_combined_runtime(device_name: str) -> dict: runtime = web_ui.load_runtime( _env_path("ARTIST_STYLE_CHECKPOINT", ROOT / "artifacts" / "style_training_dinov3" / "best.pt"), _env_path("ARTIST_PROTOTYPES", ROOT / "artifacts" / "style_training_dinov3" / "artist_prototypes.pt"), _env_path("DINOV3_ROOT", ROOT / "third_party" / "dinov3"), _env_path( "DINOV3_WEIGHTS", ROOT / "artifacts" / "pretrained" / "dinov3" / "dinov3_vits16_pretrain_lvd1689m-08c60483.pth", ), device_name, ) base_artists = list(runtime["artists"]) base_bank = { artist: runtime["prototype_tensor"][idx].detach().cpu() for idx, artist in enumerate(base_artists) } base_descriptors = dict(runtime["prototype_descriptors"]) custom_names, custom_bank, custom_descriptors, key_to_artist_id = load_custom_bank(set(base_artists)) artists = base_artists + custom_names bank = {**base_bank, **custom_bank} descriptors = {**base_descriptors, **custom_descriptors} prototype_tensor = torch.stack([bank[artist] for artist in artists], dim=0).float() prototype_tensor = F.normalize(prototype_tensor, dim=-1).to(runtime["device"]) return { **runtime, "artists": artists, "prototype_tensor": prototype_tensor, "prototype_descriptors": descriptors, "custom_key_to_artist_id": key_to_artist_id, } def _auto_extract_views(image: Image.Image, device_name: str) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[Tuple[int, int, int, int]], Optional[Tuple[int, int, int, int]], str]: extractor = web_ui.load_extractor( _env_path("YOLO_ANIME_ROOT", ROOT / "yolov5_anime"), _env_path("YOLO_WEIGHTS", ROOT / "yolov5_anime" / "weights" / "yolov5s_anime.pt"), _env_path("EYE_CASCADE", ROOT / "anime-eyes-cascade.xml"), device_name, 0.5, 0.5, 640, 9, 0.6, ) try: extracted = extractor.extract(image) except Exception as exc: return None, None, None, None, f"Auto crop failed ({exc}); using whole image only." return extracted.face, extracted.eye, extracted.face_box, extracted.eye_box, extracted.status @spaces.GPU(duration=90) def analyze(image: Image.Image, top_k: int, use_tta: bool, device_name: str): if image is None: raise gr.Error("Image is required.") runtime = load_combined_runtime(device_name) face, eye, face_box, eye_box, status = _auto_extract_views(image, device_name) full, face_tensor, eye_tensor, view_mask = web_ui.prepare_inputs(image, face, eye, runtime["checkpoint_args"], runtime["device"]) with torch.no_grad(): query_outputs = _encode_query(runtime["model"], full, face_tensor, eye_tensor, view_mask=view_mask, use_tta=use_tta) query = F.normalize(query_outputs["embedding"][0], dim=0) rows, best_artist, best_proto_idx, best_score = web_ui.rank_artists(query, runtime["artists"], runtime["prototype_tensor"], int(top_k)) descriptor = runtime["prototype_descriptors"][best_artist][best_proto_idx] view_attention_boxes = {} face_attention_box = web_ui.normalized_box(face_box, image.size) eye_attention_box = web_ui.normalized_box(eye_box, image.size) if face_attention_box is not None: view_attention_boxes["face"] = face_attention_box if eye_attention_box is not None: view_attention_boxes["eye"] = eye_attention_box explanation = web_ui.explain_against_reference( runtime["model"], full, face_tensor, eye_tensor, view_mask, descriptor, view_attention_boxes=view_attention_boxes, use_tta=use_tta, ) overlay = web_ui.make_overlay(image, explanation["combined_view_heatmaps"], face_box=face_box, eye_box=eye_box) branch_html, view_html, summary = web_ui.contribution_bars(explanation) summary = f"{status} {summary}" return face, eye, overlay, web_ui.top_match_html(best_artist, best_score), rows, branch_html, view_html, summary def _load_uploaded_image(file_obj) -> Image.Image: path = getattr(file_obj, "name", file_obj) with Image.open(path) as image: return image.convert("RGB") @spaces.GPU(duration=180) def add_artist(display_name: str, files: List[object], use_tta: bool, device_name: str): display_name = (display_name or "").strip() if not display_name: raise gr.Error("Artist name is required.") if not files or len(files) < 5: raise gr.Error("Upload at least 5 images.") runtime = load_combined_runtime(device_name) encoded_descriptors: List[Dict[str, torch.Tensor]] = [] embeddings = [] statuses = [] for file_obj in files: image = _load_uploaded_image(file_obj) face, eye, _face_box, _eye_box, status = _auto_extract_views(image, device_name) statuses.append(status) full, face_tensor, eye_tensor, view_mask = web_ui.prepare_inputs(image, face, eye, runtime["checkpoint_args"], runtime["device"]) with torch.no_grad(): outputs = _encode_query(runtime["model"], full, face_tensor, eye_tensor, view_mask=view_mask, use_tta=use_tta) descriptor = {} for key in [ "embedding", "branch_embeddings", "branch_projected_embeddings", "weighted_branch_projected_embeddings", "stacked_view_embeddings", "stacked_view_weights", "branch_weights", "view_mask", "effective_view_mask", "effective_branch_mask", ]: value = outputs[key] descriptor[key] = value[0].detach().cpu() if torch.is_tensor(value) else value encoded_descriptors.append(descriptor) embeddings.append(descriptor["embedding"].float()) stack = torch.stack(embeddings, dim=0) prototype_count = min(4, len(encoded_descriptors)) centers, assignments = _spherical_kmeans(stack, prototype_count) prototype_descriptors = [] prototype_vectors = [] for idx in range(prototype_count): members = [desc for desc, cluster in zip(encoded_descriptors, assignments.tolist()) if cluster == idx] if not members: members = [encoded_descriptors[int((stack @ centers[idx]).argmax().item())]] descriptor = _mean_descriptor(members) prototype_descriptors.append(descriptor) prototype_vectors.append(descriptor["embedding"]) artist_id = f"{_slugify(display_name)}-{uuid.uuid4().hex[:10]}" relative_path = f"artists/{artist_id}.pt" payload = { "schema_version": 1, "artist_id": artist_id, "display_name": display_name, "created_at": _now(), "num_images": len(encoded_descriptors), "prototype_bank": torch.stack(prototype_vectors, dim=0), "prototype_descriptors": prototype_descriptors, "crop_statuses": statuses, } try: _ensure_store() torch.save(payload, STORE_ROOT / relative_path) except Exception as exc: raise gr.Error(f"Failed to save added artist. Check persistent storage: {exc}") from exc index = _load_index() index["artists"].append( { "artist_id": artist_id, "display_name": display_name, "path": relative_path, "created_at": payload["created_at"], "num_images": len(encoded_descriptors), "enabled": True, } ) _save_index(index) web_ui.load_runtime.cache_clear() return f"Added {display_name}. Refresh the Manage tab to see it." def added_artist_choices(): index = _load_index() choices = [ (f"{item['display_name']} ({item['artist_id'][:8]})", item["artist_id"]) for item in index.get("artists", []) if item.get("enabled", True) ] return gr.update(choices=choices, value=choices[0][1] if choices else None) def refresh_added_artists(): status = STORE_ERROR or "List refreshed." return added_artist_choices(), status def delete_artist(artist_id: str): if not artist_id: raise gr.Error("Select an added artist.") index = _load_index() deleted = None for item in index.get("artists", []): if item["artist_id"] == artist_id and item.get("enabled", True): item["enabled"] = False item["deleted_at"] = _now() deleted = item["display_name"] break if deleted is None: raise gr.Error("Artist not found or already deleted.") _save_index(index) return added_artist_choices(), f"Deleted {deleted}." def build_app() -> gr.Blocks: with gr.Blocks(title="Anime Artist Style Embedder", css=web_ui.APP_CSS) as demo: gr.Markdown("# Anime Artist Style Embedder") with gr.Tab("Search"): with gr.Row(): with gr.Column(scale=4): image = gr.Image(label="Image", type="pil", image_mode="RGB", height=360) with gr.Row(): top_k = gr.Slider(1, 25, value=10, step=1, label="Top K") use_tta = gr.Checkbox(value=False, label="TTA") device_name = gr.Dropdown(["auto", "cpu", "cuda"], value="auto", label="Device") search_button = gr.Button("Analyze", variant="primary") with gr.Row(): face = gr.Image(label="Detected Face", type="pil", height=160) eye = gr.Image(label="Detected Eye", type="pil", height=160) with gr.Column(scale=7): overlay = gr.Image(label="Heatmap", type="pil", height=520) top_match = gr.HTML() results = gr.Dataframe(headers=["rank", "artist", "score"], datatype=["number", "str", "str"], label="Retrieval") summary = gr.Textbox(label="Status", lines=2) with gr.Row(): branch_bars = gr.HTML() view_bars = gr.HTML() search_button.click( analyze, inputs=[image, top_k, use_tta, device_name], outputs=[face, eye, overlay, top_match, results, branch_bars, view_bars, summary], api_name=False, ) with gr.Tab("Add Artist"): artist_name = gr.Textbox(label="Artist name") upload_files = gr.File(label="Images (minimum 5)", file_count="multiple", file_types=["image"]) add_tta = gr.Checkbox(value=False, label="TTA") add_device = gr.Dropdown(["auto", "cpu", "cuda"], value="auto", label="Device") add_button = gr.Button("Add Artist", variant="primary") add_status = gr.Textbox(label="Status", lines=2) add_button.click(add_artist, inputs=[artist_name, upload_files, add_tta, add_device], outputs=[add_status], api_name=False) with gr.Tab("Manage Added Artists"): artist_select = gr.Dropdown(label="Added artists", choices=[]) refresh_button = gr.Button("Refresh") delete_button = gr.Button("Delete selected", variant="stop") manage_status = gr.Textbox(label="Status", lines=2) refresh_button.click(refresh_added_artists, outputs=[artist_select, manage_status], api_name=False) delete_button.click(delete_artist, inputs=[artist_select], outputs=[artist_select, manage_status], api_name=False) demo.load(refresh_added_artists, outputs=[artist_select, manage_status], api_name=False) return demo.queue() if __name__ == "__main__": build_app().launch(server_name="0.0.0.0", server_port=7860, show_error=True)