Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| import uuid | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import web_ui | |
| from artist_style_dinov3.explain import _encode_query | |
# The `spaces` package is optional (presumably Hugging Face's ZeroGPU helper).
# When it cannot be imported, install a stand-in whose `GPU` decorator does
# nothing, so decorated functions still work on ordinary hardware.
try:
    import spaces
except Exception:

    class spaces:
        """Minimal stand-in that exposes a no-op ``GPU`` decorator."""

        @staticmethod
        def GPU(*_args, **_kwargs):
            """Return the decorated function unchanged.

            Works both bare (``@spaces.GPU``) and with arguments
            (``@spaces.GPU(...)``); all arguments are ignored.
            """
            # Bare usage: the function being decorated is the only argument.
            if len(_args) == 1 and not _kwargs and callable(_args[0]):
                return _args[0]

            def decorator(fn):
                return fn

            return decorator
# Directory that contains this file; anchor for all bundled default paths.
ROOT = Path(__file__).resolve().parent

# Where added artists are stored. Priority: the CUSTOM_STORE_DIR environment
# variable, then /data/custom_store when /data exists (presumably a persistent
# volume — confirm against the deployment), else a folder next to this file.
if os.environ.get("CUSTOM_STORE_DIR"):
    STORE_ROOT = Path(os.environ["CUSTOM_STORE_DIR"])
elif Path("/data").exists():
    STORE_ROOT = Path("/data/custom_store")
else:
    STORE_ROOT = Path(ROOT / "custom_store")

# JSON index of added artists and the directory holding their saved files.
INDEX_PATH = STORE_ROOT / "index.json"
ARTISTS_DIR = STORE_ROOT / "artists"

# Last storage problem seen while preparing the store; None when it is usable.
STORE_ERROR: Optional[str] = None
def _env_path(name: str, default: Path) -> str:
    """Resolve a path setting from the environment.

    Returns the value of environment variable ``name`` with ``~`` expanded when
    the variable is set to something other than whitespace; otherwise returns
    ``str(default)``.
    """
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return str(default)
    # Only emptiness is tested on the stripped value; the path itself is used as given.
    return str(Path(raw).expanduser())
def _ensure_store() -> None:
    """Create the on-disk store layout and clear any recorded storage error.

    Creates ``ARTISTS_DIR`` (with parents) and, when ``INDEX_PATH`` does not
    exist yet, writes an empty schema-version-1 index. Filesystem errors
    propagate to the caller.
    """
    global STORE_ERROR
    ARTISTS_DIR.mkdir(parents=True, exist_ok=True)
    index_missing = not INDEX_PATH.exists()
    if index_missing:
        empty_index = {"schema_version": 1, "artists": []}
        INDEX_PATH.write_text(json.dumps(empty_index, indent=2), encoding="utf-8")
    # Getting this far means the store could be prepared.
    STORE_ERROR = None
def _load_index() -> dict:
    """Load the added-artist index, tolerating a missing or damaged file.

    Returns:
        A dict that always contains ``"schema_version"`` and a list under
        ``"artists"``. When the store cannot be prepared, ``STORE_ERROR`` is
        set and an empty index is returned instead of raising.
    """
    global STORE_ERROR
    try:
        _ensure_store()
    except Exception as exc:
        STORE_ERROR = f"Added-artist storage is not writable: {exc}"
        return {"schema_version": 1, "artists": []}
    try:
        payload = json.loads(INDEX_PATH.read_text(encoding="utf-8"))
    except Exception:
        # Unreadable or invalid JSON: fall back to an empty index.
        payload = None
    # Valid JSON can still be the wrong shape (e.g. a list or null); callers
    # rely on a dict whose "artists" entry supports iteration and append.
    if not isinstance(payload, dict):
        payload = {}
    payload.setdefault("schema_version", 1)
    if not isinstance(payload.get("artists"), list):
        payload["artists"] = []
    return payload
def _save_index(payload: dict) -> None:
    """Write ``payload`` to the index file through a temporary sibling file.

    Raises:
        gr.Error: if the store cannot be prepared.
    """
    try:
        _ensure_store()
    except Exception as exc:
        raise gr.Error(f"Added-artist storage is not writable: {exc}") from exc
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    scratch = INDEX_PATH.with_suffix(".tmp")
    scratch.write_text(serialized, encoding="utf-8")
    # Swap the fully written file into place in a single rename.
    scratch.replace(INDEX_PATH)
def _slugify(name: str) -> str:
    """Turn ``name`` into a lowercase identifier fragment.

    Runs of characters outside ``a-z A-Z 0-9 . _ -`` become a single ``-``;
    leading/trailing ``-``, ``.`` and ``_`` are removed. Returns ``"artist"``
    when nothing is left.
    """
    cleaned = re.sub(r"[^a-zA-Z0-9._-]+", "-", name.strip())
    cleaned = cleaned.strip("-._").lower()
    if cleaned:
        return cleaned
    return "artist"
def _now() -> str:
    """Return the current UTC time as a timezone-aware ISO 8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
def _mean_descriptor(descriptors: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Average a group of per-image descriptors into one descriptor.

    For every known key present in at least one descriptor, the tensors are
    stacked and averaged on CPU in float32, then post-processed by kind:
    embedding keys are L2-normalised along the last dim, weight keys are
    rescaled to sum to 1 along the last dim, and mask keys become 1.0 wherever
    any input was positive. ``weighted_branch_projected_embeddings`` is left as
    the plain mean. When that key is present, ``"embedding"`` is replaced by
    its flattened, normalised value.

    Raises:
        ValueError: if ``descriptors`` is empty.
    """
    if not descriptors:
        raise ValueError("no descriptors")

    unit_keys = (
        "embedding",
        "branch_embeddings",
        "branch_projected_embeddings",
        "stacked_view_embeddings",
    )
    weight_keys = ("stacked_view_weights", "branch_weights")
    mask_keys = ("view_mask", "effective_view_mask", "effective_branch_mask")
    # Processing order also fixes the key order of the returned dict.
    ordered_keys = (
        "embedding",
        "branch_embeddings",
        "branch_projected_embeddings",
        "weighted_branch_projected_embeddings",
        "stacked_view_embeddings",
        "stacked_view_weights",
        "branch_weights",
        "view_mask",
        "effective_view_mask",
        "effective_branch_mask",
    )

    averaged: Dict[str, torch.Tensor] = {}
    for field in ordered_keys:
        samples = [entry[field].detach().float().cpu() for entry in descriptors if field in entry]
        if samples:
            avg = torch.stack(samples, dim=0).mean(dim=0)
            if field in unit_keys:
                avg = F.normalize(avg.float(), dim=-1)
            elif field in weight_keys:
                # clamp_min guards against dividing by a zero total weight.
                avg = avg / avg.sum(dim=-1, keepdim=True).clamp_min(1e-6)
            elif field in mask_keys:
                avg = (avg > 0).float()
            averaged[field] = avg

    weighted = averaged.get("weighted_branch_projected_embeddings")
    if weighted is not None:
        averaged["embedding"] = F.normalize(weighted.reshape(-1).float(), dim=0)
    return averaged
def _spherical_kmeans(vectors: torch.Tensor, k: int, iterations: int = 25) -> Tuple[torch.Tensor, torch.Tensor]:
    """Cluster row vectors by cosine similarity with a fixed number of passes.

    Rows are L2-normalised, ``k`` is clamped to ``[1, number of rows]``, and the
    initial centers are rows picked at evenly spaced indices, so the result is
    deterministic. Each pass assigns every row to its most similar center and
    then replaces each center with the normalised mean of its members; a
    center with no members is kept as is.

    Returns:
        ``(centers, assignments)``: the centers after the last pass, and the
        cluster index per row computed at the start of that pass.
    """
    points = F.normalize(vectors.float(), dim=-1)
    total = points.size(0)
    k = max(1, min(int(k), total))
    seed_rows = torch.linspace(0, total - 1, steps=k).round().long()
    centers = points[seed_rows].clone()
    labels = torch.zeros(total, dtype=torch.long)
    for _ in range(iterations):
        labels = torch.argmax(points @ centers.t(), dim=1)
        updated = []
        for cluster in range(k):
            members = points[labels == cluster]
            if members.numel() == 0:
                updated.append(centers[cluster])
                continue
            updated.append(F.normalize(members.mean(dim=0), dim=0))
        centers = torch.stack(updated, dim=0)
    return centers, labels
def _custom_display_key(display_name: str, artist_id: str, existing: set[str]) -> str:
    """Pick the name under which an added artist is listed.

    Returns ``display_name`` unchanged unless it is already in ``existing``, in
    which case the first 8 characters of ``artist_id`` are appended in
    parentheses.
    """
    clash = display_name in existing
    return f"{display_name} ({artist_id[:8]})" if clash else display_name
def load_custom_bank(existing_names: Optional[set[str]] = None) -> Tuple[List[str], Dict[str, torch.Tensor], Dict[str, list], Dict[str, str]]:
    """Load every enabled added artist from the store.

    Args:
        existing_names: Names already in use (e.g. built-in artists); added
            artists that clash with them get a disambiguating suffix. The set
            is not modified.

    Returns:
        ``(names, bank, descriptors, key_to_artist_id)``: display keys in index
        order, and three dicts keyed by display key holding the saved
        ``prototype_bank``, the saved ``prototype_descriptors``, and the
        ``artist_id``. Disabled entries and entries whose file is missing are
        skipped.
    """
    # Work on a copy so the caller's set is untouched while we track every
    # name handed out so far (avoids rebuilding the set on each iteration).
    taken = set(existing_names or ())
    index = _load_index()
    names: List[str] = []
    bank: Dict[str, torch.Tensor] = {}
    descriptors: Dict[str, list] = {}
    key_to_artist_id: Dict[str, str] = {}
    for item in index.get("artists", []):
        if not item.get("enabled", True):
            continue
        path = STORE_ROOT / item["path"]
        if not path.exists():
            continue
        # NOTE(review): torch.load unpickles the file; only files written by
        # this app should ever be placed under STORE_ROOT.
        payload = torch.load(path, map_location="cpu")
        key = _custom_display_key(item["display_name"], item["artist_id"], taken)
        if key in taken:
            # The 8-character id prefix can repeat for same-named artists, which
            # would overwrite an earlier entry; fall back to the full id.
            key = f"{item['display_name']} ({item['artist_id']})"
        taken.add(key)
        names.append(key)
        bank[key] = payload["prototype_bank"]
        descriptors[key] = payload["prototype_descriptors"]
        key_to_artist_id[key] = item["artist_id"]
    return names, bank, descriptors, key_to_artist_id
def load_combined_runtime(device_name: str) -> dict:
    """Load the base runtime and extend it with the added artists.

    Asset paths come from environment variables with defaults under ``ROOT``.
    The returned dict is the base runtime with ``"artists"``,
    ``"prototype_tensor"`` and ``"prototype_descriptors"`` replaced by the
    combined (base + added) versions, plus ``"custom_key_to_artist_id"``.
    The combined prototype tensor is L2-normalised along the last dim and moved
    to the runtime's device.
    """
    style_dir = ROOT / "artifacts" / "style_training_dinov3"
    checkpoint = _env_path("ARTIST_STYLE_CHECKPOINT", style_dir / "best.pt")
    prototypes = _env_path("ARTIST_PROTOTYPES", style_dir / "artist_prototypes.pt")
    dinov3_root = _env_path("DINOV3_ROOT", ROOT / "third_party" / "dinov3")
    dinov3_weights = _env_path(
        "DINOV3_WEIGHTS",
        ROOT / "artifacts" / "pretrained" / "dinov3" / "dinov3_vits16_pretrain_lvd1689m-08c60483.pth",
    )
    runtime = web_ui.load_runtime(checkpoint, prototypes, dinov3_root, dinov3_weights, device_name)

    # Start from the built-in artists: one CPU prototype entry per artist.
    base_artists = list(runtime["artists"])
    merged_bank = {}
    for position, artist in enumerate(base_artists):
        merged_bank[artist] = runtime["prototype_tensor"][position].detach().cpu()
    merged_descriptors = dict(runtime["prototype_descriptors"])

    # Layer the added artists on top; their keys are made distinct from base names.
    custom_names, custom_bank, custom_descriptors, key_to_artist_id = load_custom_bank(set(base_artists))
    merged_bank.update(custom_bank)
    merged_descriptors.update(custom_descriptors)
    all_artists = base_artists + custom_names

    stacked = torch.stack([merged_bank[artist] for artist in all_artists], dim=0).float()
    stacked = F.normalize(stacked, dim=-1).to(runtime["device"])

    combined = dict(runtime)
    combined["artists"] = all_artists
    combined["prototype_tensor"] = stacked
    combined["prototype_descriptors"] = merged_descriptors
    combined["custom_key_to_artist_id"] = key_to_artist_id
    return combined
def _auto_extract_views(image: Image.Image, device_name: str) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[Tuple[int, int, int, int]], Optional[Tuple[int, int, int, int]], str]:
    """Run the face/eye extractor on ``image``.

    Returns:
        ``(face, eye, face_box, eye_box, status)`` taken from the extractor's
        result. If extraction raises, the first four items are ``None`` and
        ``status`` explains that only the whole image will be used.
    """
    yolo_root = ROOT / "yolov5_anime"
    # The trailing numeric arguments are passed positionally to
    # web_ui.load_extractor; their meaning is defined there.
    extractor = web_ui.load_extractor(
        _env_path("YOLO_ANIME_ROOT", yolo_root),
        _env_path("YOLO_WEIGHTS", yolo_root / "weights" / "yolov5s_anime.pt"),
        _env_path("EYE_CASCADE", ROOT / "anime-eyes-cascade.xml"),
        device_name,
        0.5,
        0.5,
        640,
        9,
        0.6,
    )
    try:
        result = extractor.extract(image)
    except Exception as exc:
        # Cropping is best-effort: report the failure and carry on without crops.
        failure = f"Auto crop failed ({exc}); using whole image only."
        return None, None, None, None, failure
    return result.face, result.eye, result.face_box, result.eye_box, result.status
def analyze(image: Image.Image, top_k: int, use_tta: bool, device_name: str):
    """Rank artists for ``image`` and explain the best match.

    Returns:
        An 8-tuple matching the Search tab outputs: detected face, detected
        eye, heatmap overlay, top-match HTML, ranking rows, branch-contribution
        HTML, view-contribution HTML, and a status string (crop status followed
        by the contribution summary).

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Image is required.")

    runtime = load_combined_runtime(device_name)
    model = runtime["model"]
    face, eye, face_box, eye_box, status = _auto_extract_views(image, device_name)
    full, face_tensor, eye_tensor, view_mask = web_ui.prepare_inputs(
        image, face, eye, runtime["checkpoint_args"], runtime["device"]
    )

    with torch.no_grad():
        query_outputs = _encode_query(model, full, face_tensor, eye_tensor, view_mask=view_mask, use_tta=use_tta)
    query = F.normalize(query_outputs["embedding"][0], dim=0)

    rows, best_artist, best_proto_idx, best_score = web_ui.rank_artists(
        query, runtime["artists"], runtime["prototype_tensor"], int(top_k)
    )
    # The explanation compares the query with the single best-matching prototype.
    reference = runtime["prototype_descriptors"][best_artist][best_proto_idx]

    # Keep only the views for which a box is available.
    candidate_boxes = {
        "face": web_ui.normalized_box(face_box, image.size),
        "eye": web_ui.normalized_box(eye_box, image.size),
    }
    view_attention_boxes = {view: box for view, box in candidate_boxes.items() if box is not None}

    explanation = web_ui.explain_against_reference(
        model,
        full,
        face_tensor,
        eye_tensor,
        view_mask,
        reference,
        view_attention_boxes=view_attention_boxes,
        use_tta=use_tta,
    )
    overlay = web_ui.make_overlay(image, explanation["combined_view_heatmaps"], face_box=face_box, eye_box=eye_box)
    branch_html, view_html, summary = web_ui.contribution_bars(explanation)
    summary = f"{status} {summary}"
    match_html = web_ui.top_match_html(best_artist, best_score)
    return face, eye, overlay, match_html, rows, branch_html, view_html, summary
def _load_uploaded_image(file_obj) -> Image.Image:
    """Open an upload and return it as an RGB image.

    ``file_obj`` may be an object with a ``name`` attribute (used as the path)
    or anything ``Image.open`` accepts directly.
    """
    source = file_obj.name if hasattr(file_obj, "name") else file_obj
    with Image.open(source) as handle:
        rgb = handle.convert("RGB")
    return rgb
def add_artist(display_name: str, files: List[object], use_tta: bool, device_name: str):
    """Encode uploaded images into prototypes and register them as a new artist.

    Every image is auto-cropped and encoded with the runtime model. The image
    embeddings are clustered with spherical k-means (at most 4 clusters) and
    each cluster is averaged into one prototype descriptor. The prototypes are
    saved under ``STORE_ROOT / "artists"`` and an entry is appended to the index.

    Args:
        display_name: Name shown for the artist; surrounding whitespace is removed.
        files: Uploaded image files; at least 5 are required.
        use_tta: Forwarded to the encoder.
        device_name: Device selection forwarded to the runtime and extractor loaders.

    Returns:
        A status message for the UI.

    Raises:
        gr.Error: on a missing name, too few images, an unreadable image, or a
            failure to write the artist file.
    """
    display_name = (display_name or "").strip()
    if not display_name:
        raise gr.Error("Artist name is required.")
    if not files or len(files) < 5:
        raise gr.Error("Upload at least 5 images.")

    # Encoder outputs that are kept for each image.
    descriptor_keys = (
        "embedding",
        "branch_embeddings",
        "branch_projected_embeddings",
        "weighted_branch_projected_embeddings",
        "stacked_view_embeddings",
        "stacked_view_weights",
        "branch_weights",
        "view_mask",
        "effective_view_mask",
        "effective_branch_mask",
    )

    runtime = load_combined_runtime(device_name)
    encoded_descriptors: List[Dict[str, torch.Tensor]] = []
    embeddings = []
    statuses = []
    for file_obj in files:
        try:
            image = _load_uploaded_image(file_obj)
        except Exception as exc:
            # Name the offending file so the user knows which upload to replace.
            label = Path(str(getattr(file_obj, "name", file_obj))).name
            raise gr.Error(f"Could not read uploaded image {label}: {exc}") from exc
        face, eye, _face_box, _eye_box, status = _auto_extract_views(image, device_name)
        statuses.append(status)
        full, face_tensor, eye_tensor, view_mask = web_ui.prepare_inputs(image, face, eye, runtime["checkpoint_args"], runtime["device"])
        with torch.no_grad():
            outputs = _encode_query(runtime["model"], full, face_tensor, eye_tensor, view_mask=view_mask, use_tta=use_tta)
        descriptor = {}
        for key in descriptor_keys:
            value = outputs[key]
            # Drop the leading batch dimension of tensor outputs and move to CPU.
            descriptor[key] = value[0].detach().cpu() if torch.is_tensor(value) else value
        encoded_descriptors.append(descriptor)
        embeddings.append(descriptor["embedding"].float())

    stack = torch.stack(embeddings, dim=0)
    prototype_count = min(4, len(encoded_descriptors))
    centers, assignments = _spherical_kmeans(stack, prototype_count)
    cluster_ids = assignments.tolist()
    prototype_descriptors = []
    prototype_vectors = []
    for idx in range(prototype_count):
        members = [desc for desc, cluster in zip(encoded_descriptors, cluster_ids) if cluster == idx]
        if not members:
            # Empty cluster: represent it by the image most similar to its center.
            members = [encoded_descriptors[int((stack @ centers[idx]).argmax().item())]]
        descriptor = _mean_descriptor(members)
        prototype_descriptors.append(descriptor)
        prototype_vectors.append(descriptor["embedding"])

    artist_id = f"{_slugify(display_name)}-{uuid.uuid4().hex[:10]}"
    relative_path = f"artists/{artist_id}.pt"
    artist_path = STORE_ROOT / relative_path
    payload = {
        "schema_version": 1,
        "artist_id": artist_id,
        "display_name": display_name,
        "created_at": _now(),
        "num_images": len(encoded_descriptors),
        "prototype_bank": torch.stack(prototype_vectors, dim=0),
        "prototype_descriptors": prototype_descriptors,
        "crop_statuses": statuses,
    }

    def _discard_artist_file() -> None:
        # Best-effort cleanup so a failed add never leaves an unreferenced file.
        try:
            artist_path.unlink()
        except OSError:
            pass

    try:
        _ensure_store()
        torch.save(payload, artist_path)
    except Exception as exc:
        _discard_artist_file()
        raise gr.Error(f"Failed to save added artist. Check persistent storage: {exc}") from exc

    try:
        index = _load_index()
        index["artists"].append(
            {
                "artist_id": artist_id,
                "display_name": display_name,
                "path": relative_path,
                "created_at": payload["created_at"],
                "num_images": len(encoded_descriptors),
                "enabled": True,
            }
        )
        _save_index(index)
    except Exception:
        # The file is useless without its index entry; remove it and re-raise.
        _discard_artist_file()
        raise
    web_ui.load_runtime.cache_clear()
    return f"Added {display_name}. Refresh the Manage tab to see it."
def added_artist_choices():
    """Build a dropdown update listing every enabled added artist.

    Each choice is labelled with the display name plus the first 8 characters
    of the artist id, and carries the full artist id as its value. The first
    choice is preselected; the value is ``None`` when there are no choices.
    """
    options = []
    for entry in _load_index().get("artists", []):
        if not entry.get("enabled", True):
            continue
        label = f"{entry['display_name']} ({entry['artist_id'][:8]})"
        options.append((label, entry["artist_id"]))
    selected = options[0][1] if options else None
    return gr.update(choices=options, value=selected)
def refresh_added_artists():
    """Reload the added-artist dropdown and report the storage status.

    Returns:
        ``(dropdown_update, status)`` where ``status`` is the current storage
        error if there is one, otherwise ``"List refreshed."``.
    """
    # Build the choices first: loading the index is what refreshes STORE_ERROR,
    # so reading it beforehand would report the previous load's state.
    choices = added_artist_choices()
    status = STORE_ERROR or "List refreshed."
    return choices, status
def delete_artist(artist_id: str):
    """Soft-delete an added artist by disabling its index entry.

    The first enabled entry with a matching id is marked ``enabled: False`` and
    stamped with ``deleted_at``; the saved artist file is left in place.

    Returns:
        ``(dropdown_update, status)`` for the Manage tab.

    Raises:
        gr.Error: if no id was selected or no enabled entry matches it.
    """
    if not artist_id:
        raise gr.Error("Select an added artist.")
    index = _load_index()
    removed_name = None
    for entry in index.get("artists", []):
        if entry["artist_id"] != artist_id or not entry.get("enabled", True):
            continue
        entry["enabled"] = False
        entry["deleted_at"] = _now()
        removed_name = entry["display_name"]
        break
    if removed_name is None:
        raise gr.Error("Artist not found or already deleted.")
    _save_index(index)
    return added_artist_choices(), f"Deleted {removed_name}."
# build_app: construct the Gradio Blocks UI and return it with queuing enabled
# (the return value is demo.queue()).
# Three tabs are created:
#   - "Search": image + Top K / TTA / Device inputs; the Analyze button calls
#     analyze() and fills the face, eye, heatmap, top-match, ranking,
#     contribution-bar and status outputs.
#   - "Add Artist": name + multi-file image upload; the button calls add_artist().
#   - "Manage Added Artists": dropdown with Refresh (refresh_added_artists) and
#     Delete selected (delete_artist) buttons.
# NOTE(review): indentation was lost in this rendering of the file, so the exact
# nesting of rows/columns inside each tab cannot be confirmed from here.
# NOTE(review): every event passes api_name=False — presumably to keep these
# handlers out of the generated API; confirm against the Gradio docs.
| def build_app() -> gr.Blocks: | |
| with gr.Blocks(title="Anime Artist Style Embedder", css=web_ui.APP_CSS) as demo: | |
| gr.Markdown("# Anime Artist Style Embedder") | |
# --- Search tab ---
| with gr.Tab("Search"): | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| image = gr.Image(label="Image", type="pil", image_mode="RGB", height=360) | |
| with gr.Row(): | |
| top_k = gr.Slider(1, 25, value=10, step=1, label="Top K") | |
| use_tta = gr.Checkbox(value=False, label="TTA") | |
| device_name = gr.Dropdown(["auto", "cpu", "cuda"], value="auto", label="Device") | |
| search_button = gr.Button("Analyze", variant="primary") | |
| with gr.Row(): | |
| face = gr.Image(label="Detected Face", type="pil", height=160) | |
| eye = gr.Image(label="Detected Eye", type="pil", height=160) | |
| with gr.Column(scale=7): | |
| overlay = gr.Image(label="Heatmap", type="pil", height=520) | |
| top_match = gr.HTML() | |
| results = gr.Dataframe(headers=["rank", "artist", "score"], datatype=["number", "str", "str"], label="Retrieval") | |
| summary = gr.Textbox(label="Status", lines=2) | |
| with gr.Row(): | |
| branch_bars = gr.HTML() | |
| view_bars = gr.HTML() | |
# The outputs list must stay in the same order as analyze()'s return tuple.
| search_button.click( | |
| analyze, | |
| inputs=[image, top_k, use_tta, device_name], | |
| outputs=[face, eye, overlay, top_match, results, branch_bars, view_bars, summary], | |
| api_name=False, | |
| ) | |
# --- Add Artist tab ---
| with gr.Tab("Add Artist"): | |
| artist_name = gr.Textbox(label="Artist name") | |
| upload_files = gr.File(label="Images (minimum 5)", file_count="multiple", file_types=["image"]) | |
| add_tta = gr.Checkbox(value=False, label="TTA") | |
| add_device = gr.Dropdown(["auto", "cpu", "cuda"], value="auto", label="Device") | |
| add_button = gr.Button("Add Artist", variant="primary") | |
| add_status = gr.Textbox(label="Status", lines=2) | |
| add_button.click(add_artist, inputs=[artist_name, upload_files, add_tta, add_device], outputs=[add_status], api_name=False) | |
# --- Manage Added Artists tab ---
| with gr.Tab("Manage Added Artists"): | |
| artist_select = gr.Dropdown(label="Added artists", choices=[]) | |
| refresh_button = gr.Button("Refresh") | |
| delete_button = gr.Button("Delete selected", variant="stop") | |
| manage_status = gr.Textbox(label="Status", lines=2) | |
| refresh_button.click(refresh_added_artists, outputs=[artist_select, manage_status], api_name=False) | |
| delete_button.click(delete_artist, inputs=[artist_select], outputs=[artist_select, manage_status], api_name=False) | |
# Populate the added-artist dropdown and status when the page loads.
| demo.load(refresh_added_artists, outputs=[artist_select, manage_status], api_name=False) | |
| return demo.queue() | |
# Script entry point: build the UI and serve it on all interfaces, port 7860,
# with error messages shown in the browser.
if __name__ == "__main__":
    app = build_app()
    app.launch(server_name="0.0.0.0", server_port=7860, show_error=True)