"""Gradio UI for artist style retrieval and attribution.

A query illustration (whole image plus optional face and eye crops) is embedded
with ``ArtistStyleModel``. Artists are ranked by their best-matching
L2-normalized prototype. ``explain_against_reference`` then produces per-view
heatmaps against the winning prototype, which are composited over a lightened
grayscale copy of the image. Face and eye crops can be extracted automatically
with a YOLOv5 anime-face detector and an OpenCV cascade classifier for eyes.
"""

from __future__ import annotations

import argparse
import html
import itertools
import os
import sys
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional, Tuple

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import functional as TF

from artist_style_dinov3 import ArtistStyleModel, explain_against_reference
from artist_style_dinov3.explain import _encode_query

ROOT = Path(__file__).resolve().parent

# Display names of the three input views; order matches the view_mask built in
# prepare_inputs (whole/full, face, eye).
VIEW_NAMES = ["whole", "face", "eye"]
# Keys used for the same views in model explanation outputs (see make_overlay).
MODEL_VIEW_NAMES = ["full", "face", "eye"]
BRANCH_NAMES = ["structure", "texture", "line", "color"]

# ColorBrewer YlOrRd-style sequential palette. Low values are transparent; high
# values move through yellow/orange into red so saliency is visible on grayscale.
HEAT_PALETTE = np.array(
    [
        [1.000000, 1.000000, 0.800000],
        [1.000000, 0.929412, 0.627451],
        [0.996078, 0.850980, 0.462745],
        [0.996078, 0.698039, 0.298039],
        [0.992157, 0.552941, 0.235294],
        [0.988235, 0.305882, 0.164706],
        [0.890196, 0.101961, 0.109804],
        [0.700000, 0.000000, 0.000000],
    ],
    dtype=np.float32,
)

APP_CSS = """
.app {max-width: 1500px; margin: 0 auto;}
.compact-table textarea {font-size: 12px;}
#summary-box textarea {font-size: 14px; font-weight: 600;}
.bars {display: grid; gap: 10px; margin: 4px 0 14px;}
.bar-row {display: grid; grid-template-columns: 82px 1fr 52px; gap: 8px; align-items: center;}
.bar-label {font-size: 13px; font-weight: 600;}
.bar-track {height: 13px; background: #e8e8e8; border-radius: 999px; overflow: hidden;}
.bar-fill {height: 100%;}
.bar-value {font-variant-numeric: tabular-nums; font-size: 12px; text-align: right;}
.bar-title {font-size: 15px; font-weight: 700; margin: 0 0 3px;}
.bar-note {font-size: 12px; color: #666; margin: 0 0 10px; line-height: 1.35;}
.top-match {padding: 12px 14px; border: 1px solid #dedede; border-radius: 8px; margin-bottom: 10px; background: #fafafa;}
.top-match-label {font-size: 12px; color: #666; margin-bottom: 2px;}
.top-match-artist {font-size: 24px; line-height: 1.15; font-weight: 800;}
.top-match-score {font-size: 18px; font-weight: 700; color: #b91c1c; margin-top: 4px;}
"""

BRANCH_COLORS = {
    "structure": "#4e79a7",
    "texture": "#f28e2b",
    "line": "#e15759",
    "color": "#59a14f",
}

VIEW_COLORS = {
    "whole": "#7b61ff",
    "face": "#ff7f0e",
    "eye": "#d62728",
}


def env_path(name: str, default: Path) -> str:
    """Return the path in environment variable *name*, or *default* as a string.

    Unset, empty, or whitespace-only values fall back to *default*. The value is
    stripped before use and ``~`` is expanded.
    """
    value = os.environ.get(name, "").strip()
    if value:
        return str(Path(value).expanduser())
    return str(default)


def parse_args() -> argparse.Namespace:
    """Parse command-line options.

    Path options default to environment variables (via ``env_path``), then to
    locations under the script directory.
    """
    parser = argparse.ArgumentParser(description="Gradio UI for artist style retrieval and attribution.")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=env_path("ARTIST_STYLE_CHECKPOINT", ROOT / "artifacts" / "style_training_dinov3" / "best.pt"),
    )
    parser.add_argument(
        "--prototype-bank",
        type=str,
        default=env_path("ARTIST_PROTOTYPES", ROOT / "artifacts" / "style_training_dinov3" / "artist_prototypes.pt"),
    )
    parser.add_argument("--dinov3-root", type=str, default=env_path("DINOV3_ROOT", ROOT / "third_party" / "dinov3"))
    parser.add_argument(
        "--dinov3-weights",
        type=str,
        default=env_path(
            "DINOV3_WEIGHTS",
            ROOT / "artifacts" / "pretrained" / "dinov3" / "dinov3_vits16_pretrain_lvd1689m-08c60483.pth",
        ),
    )
    parser.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto")
    parser.add_argument("--yolo-dir", type=str, default=env_path("YOLO_ANIME_ROOT", ROOT / "yolov5_anime"))
    parser.add_argument(
        "--yolo-weights",
        type=str,
        default=env_path("YOLO_WEIGHTS", ROOT / "yolov5_anime" / "weights" / "yolov5s_anime.pt"),
    )
    parser.add_argument("--eye-cascade", type=str, default=env_path("EYE_CASCADE", ROOT / "anime-eyes-cascade.xml"))
    parser.add_argument("--face-conf", type=float, default=0.5)
    parser.add_argument("--face-iou", type=float, default=0.5)
    parser.add_argument("--face-imgsz", type=int, default=640)
    parser.add_argument("--eye-neighbors", type=int, default=9)
    parser.add_argument("--eye-margin", type=float, default=0.6)
    parser.add_argument("--server-name", type=str, default="127.0.0.1")
    parser.add_argument("--server-port", type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    return parser.parse_args()


class ExtractedViews(NamedTuple):
    """Result of automatic face/eye extraction.

    Boxes are ``(x1, y1, x2, y2)`` pixel coordinates in the whole input image.
    Any crop/box is ``None`` when that stage did not succeed; ``status`` is a
    human-readable description shown in the UI summary.
    """

    face: Optional[Image.Image]
    eye: Optional[Image.Image]
    face_box: Optional[Tuple[int, int, int, int]]
    eye_box: Optional[Tuple[int, int, int, int]]
    status: str


def resolve_device(device_name: str) -> torch.device:
    """Map ``"auto"``/``"cpu"``/``"cuda"`` to a torch device.

    Raises:
        RuntimeError: ``"cuda"`` was requested but CUDA is unavailable.
    """
    if device_name == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA was requested but is not available.")
        return torch.device("cuda")
    if device_name == "cpu":
        return torch.device("cpu")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _patch_torch_load_for_yolov5() -> None:
    """Make ``torch.load`` default to ``weights_only=False`` so YOLOv5 weights load.

    Legacy YOLOv5 checkpoints pickle whole model objects, which recent PyTorch
    refuses to unpickle under its ``weights_only=True`` default. The patch is
    idempotent: it is applied at most once per process.

    SECURITY NOTE(review): the patch is global and permanent, so every later
    ``torch.load`` call in this process (including ``load_runtime``, whose paths
    are editable in the UI) performs full pickle deserialization unless it
    passes ``weights_only`` explicitly. With ``--share`` enabled, remote users
    can point those textboxes at arbitrary server files. Consider scoping this
    patch to the ``attempt_load`` call only.
    """
    import torch.serialization

    try:
        # numpy>=2 moved numpy.core to numpy._core; prefer it to avoid the
        # deprecation warning on access.
        numpy_core = getattr(np, "_core", None) or np.core
        torch.serialization.add_safe_globals([numpy_core.multiarray._reconstruct, np.ndarray])
    except Exception:
        # Best effort only: older torch lacks add_safe_globals, and the
        # weights_only=False default installed below makes this optional.
        pass

    original_load = torch.load
    if getattr(original_load, "_anime_webui_patched", False):
        return

    def patched_load(*args, **kwargs):
        kwargs.setdefault("weights_only", False)
        return original_load(*args, **kwargs)

    patched_load._anime_webui_patched = True
    torch.load = patched_load


def _expand_box(box: Tuple[int, int, int, int], margin: float, width: int, height: int) -> Tuple[int, int, int, int]:
    """Grow *box* about its center by ``(1 + margin)`` and clamp it to the image.

    Both side lengths are scaled, and the result is clipped to
    ``[0, width] x [0, height]``.
    """
    x1, y1, x2, y2 = box
    cx = (x1 + x2) / 2.0
    cy = (y1 + y2) / 2.0
    bw = (x2 - x1) * (1.0 + margin)
    bh = (y2 - y1) * (1.0 + margin)
    nx1 = max(0, min(width, int(round(cx - bw / 2.0))))
    ny1 = max(0, min(height, int(round(cy - bh / 2.0))))
    nx2 = max(0, min(width, int(round(cx + bw / 2.0))))
    ny2 = max(0, min(height, int(round(cy + bh / 2.0))))
    return nx1, ny1, nx2, ny2


class AnimeFaceEyeExtractor:
    """Detect the main anime face with YOLOv5 and one eye with an OpenCV cascade.

    Heavy resources (YOLOv5 model, cascade classifier) are loaded lazily on the
    first ``extract`` call by ``_ensure_ready``.
    """

    def __init__(
        self,
        yolo_dir: str,
        yolo_weights: str,
        eye_cascade: str,
        device_name: str,
        conf: float,
        iou: float,
        imgsz: int,
        eye_neighbors: int,
        eye_margin: float,
    ) -> None:
        """Store configuration only; no files are read here.

        Args:
            yolo_dir: Root of the YOLOv5 anime repository (added to ``sys.path``).
            yolo_weights: YOLOv5 face-detector weights file.
            eye_cascade: OpenCV cascade XML for eye detection.
            device_name: ``"auto"``, ``"cpu"`` or ``"cuda"``.
            conf: NMS confidence threshold for faces.
            iou: NMS IoU threshold for faces.
            imgsz: Detector input size; rounded up to a stride multiple on load.
            eye_neighbors: ``minNeighbors`` for the eye cascade.
            eye_margin: Relative enlargement applied to the detected eye box.
        """
        self.yolo_dir = Path(yolo_dir).resolve()
        self.yolo_weights = Path(yolo_weights).resolve()
        self.eye_cascade = Path(eye_cascade).resolve()
        self.device_name = device_name
        self.conf = float(conf)
        self.iou = float(iou)
        self.imgsz = int(imgsz)
        self.eye_neighbors = int(eye_neighbors)
        self.eye_margin = float(eye_margin)
        self.model = None
        self.device = None
        self.stride = 32
        self.use_half = False
        self.cascade = None

    def _ensure_ready(self) -> None:
        """Validate paths and lazily load the YOLOv5 model and eye cascade.

        Raises:
            gr.Error: A required file is missing, OpenCV is not installed, or
                the cascade fails to load.
        """
        missing = [
            str(path)
            for path in [self.yolo_dir, self.yolo_weights, self.eye_cascade]
            if not path.exists()
        ]
        if missing:
            raise gr.Error("Missing auto extraction files: " + ", ".join(missing))
        if self.model is not None and self.cascade is not None:
            return

        try:
            import cv2
        except ModuleNotFoundError as exc:
            raise gr.Error("opencv-python is required for auto face/eye extraction.") from exc

        # The YOLOv5 repo is not a package; its top-level `models`/`utils`
        # modules only import once the repo root is on sys.path.
        if str(self.yolo_dir) not in sys.path:
            sys.path.insert(0, str(self.yolo_dir))
        _patch_torch_load_for_yolov5()

        from models.experimental import attempt_load
        from utils.torch_utils import select_device

        yolo_device = "0" if self.device_name in {"auto", "cuda"} and torch.cuda.is_available() else "cpu"
        self.device = select_device(yolo_device)
        self.use_half = getattr(self.device, "type", str(self.device)) != "cpu"
        self.model = attempt_load(str(self.yolo_weights), map_location=self.device)
        if self.use_half:
            self.model.half()
        self.model.eval()
        self.stride = int(self.model.stride.max())
        # The detector input must be a multiple of the model stride.
        self.imgsz = int(np.ceil(self.imgsz / self.stride) * self.stride)

        self.cascade = cv2.CascadeClassifier(str(self.eye_cascade))
        if self.cascade.empty():
            raise gr.Error(f"Failed to load eye cascade: {self.eye_cascade}")

    def _letterbox(self, bgr: np.ndarray) -> np.ndarray:
        """Resize and pad *bgr* with YOLOv5's ``letterbox``.

        Successive fallbacks handle forks whose ``letterbox`` lacks the
        ``stride``/``auto`` keywords. The last fallback may pad to a
        non-square shape.
        """
        from utils.datasets import letterbox

        try:
            return letterbox(bgr, (self.imgsz, self.imgsz), stride=self.stride, auto=False)[0]
        except TypeError:
            try:
                return letterbox(bgr, (self.imgsz, self.imgsz), auto=False)[0]
            except TypeError:
                return letterbox(bgr, (self.imgsz, self.imgsz))[0]

    def _detect_faces(self, rgb: np.ndarray) -> List[Tuple[int, int, int, int, float]]:
        """Run YOLOv5 + NMS on an RGB image.

        Returns:
            ``(x1, y1, x2, y2, confidence)`` tuples in original-image pixels.
        """
        import cv2
        from utils.general import non_max_suppression, scale_coords

        bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        padded = self._letterbox(bgr)
        # BGR HWC -> RGB CHW, made contiguous for torch.from_numpy.
        chw = np.ascontiguousarray(padded[:, :, ::-1].transpose(2, 0, 1))
        tensor = torch.from_numpy(chw).to(self.device)
        tensor = tensor.half() if self.use_half else tensor.float()
        tensor = (tensor / 255.0).unsqueeze(0)

        with torch.inference_mode():
            pred = self.model(tensor)[0]
        pred = non_max_suppression(pred, conf_thres=self.conf, iou_thres=self.iou, classes=None, agnostic=False)[0]

        boxes: List[Tuple[int, int, int, int, float]] = []
        if pred is None or not len(pred):
            return boxes
        h0, w0 = rgb.shape[:2]
        # Map back from the *actual* letterboxed shape. The letterbox fallbacks
        # can pad to a non-square canvas, so (imgsz, imgsz) would be wrong there.
        pred[:, :4] = scale_coords(tensor.shape[2:], pred[:, :4], (h0, w0)).round()
        for *xyxy, conf, _cls in pred.tolist():
            x1, y1, x2, y2 = (int(v) for v in xyxy)
            boxes.append((x1, y1, x2, y2, float(conf)))
        return boxes

    def _detect_eyes(self, face_rgb: np.ndarray) -> List[Tuple[str, Tuple[int, int, int, int]]]:
        """Detect eyes in the upper 70% of a face crop.

        Returns:
            Up to two ``(label, box)`` pairs with boxes in face-crop pixels
            (see ``_best_eye_pair``).
        """
        import cv2

        height, width = face_rgb.shape[:2]
        roi_h = max(1, int(height * 0.70))
        roi = face_rgb[:roi_h, :]
        gray = cv2.cvtColor(roi, cv2.COLOR_RGB2GRAY)
        gray = cv2.GaussianBlur(gray, (3, 3), 0)
        # Local contrast equalization makes the cascade less sensitive to shading.
        gray = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray)
        min_size = max(8, int(0.07 * min(width, roi_h)))
        raw = self.cascade.detectMultiScale(
            gray,
            scaleFactor=1.05,
            minNeighbors=self.eye_neighbors,
            minSize=(min_size, min_size),
            flags=cv2.CASCADE_SCALE_IMAGE,
        )
        if raw is None:
            return []
        # OpenCV returns an empty tuple (not an array) when nothing is found.
        if isinstance(raw, tuple):
            if len(raw) == 0:
                return []
            raw = raw[0]
        arr = np.asarray(raw)
        if arr.size == 0:
            return []
        if arr.ndim == 1:
            arr = arr.reshape(1, -1)
        boxes = [
            (int(x), int(y), int(x + w), int(y + h))
            for x, y, w, h in arr[:, :4]
            if w > 0 and h > 0
        ]
        return self._best_eye_pair(boxes, width, roi_h)

    @staticmethod
    def _best_eye_pair(
        boxes: List[Tuple[int, int, int, int]],
        width: int,
        height: int,
    ) -> List[Tuple[str, Tuple[int, int, int, int]]]:
        """Pick the most plausible pair of eye boxes.

        Pairs are scored by vertical misalignment, area mismatch, and a penalty
        when the horizontal center gap is outside 5%–50% of *width*. Lower is
        better; ties keep the first pair in ``itertools.combinations`` order.

        Returns:
            ``[]`` for no boxes, ``[("left", box)]`` for a single box, otherwise
            ``[("left", ...), ("right", ...)]`` ordered by image x-coordinate.
        """
        if len(boxes) < 2:
            return [("left", boxes[0])] if len(boxes) == 1 else []

        def center(box: Tuple[int, int, int, int]) -> Tuple[float, float]:
            x1, y1, x2, y2 = box
            return (x1 + x2) / 2.0, (y1 + y2) / 2.0

        def area(box: Tuple[int, int, int, int]) -> int:
            x1, y1, x2, y2 = box
            return max(1, (x2 - x1) * (y2 - y1))

        def pair_score(pair: Tuple[Tuple[int, int, int, int], Tuple[int, int, int, int]]) -> float:
            first, second = pair
            c1x, c1y = center(first)
            c2x, c2y = center(second)
            a1, a2 = area(first), area(second)
            gap = abs(c2x - c1x) / max(1.0, width)
            gap_penalty = 0.0 if 0.05 <= gap <= 0.5 else 0.5
            return abs(c1y - c2y) / max(1.0, height) + abs(a1 - a2) / max(a1, a2) + gap_penalty

        # min() returns the first minimal pair, matching a strict `<` scan.
        best_pair = min(itertools.combinations(boxes, 2), key=pair_score)
        left, right = sorted(best_pair, key=lambda box: box[0] + box[2])
        return [("left", left), ("right", right)]

    def extract(self, image: Image.Image) -> ExtractedViews:
        """Extract the dominant face crop and one enlarged eye crop from *image*.

        The face with the largest ``area * confidence`` is used. Eye detection
        failures degrade to a face-only result with an explanatory status.

        Raises:
            gr.Error: Propagated from ``_ensure_ready`` (missing files/deps).
        """
        self._ensure_ready()
        rgb = np.asarray(image.convert("RGB"))
        height, width = rgb.shape[:2]
        faces = self._detect_faces(rgb)
        if not faces:
            return ExtractedViews(None, None, None, None, "No face detected.")

        best_face = max(faces, key=lambda item: (item[2] - item[0]) * (item[3] - item[1]) * max(item[4], 1e-6))
        # Clamp to the image: negative indices would wrap around in numpy slicing.
        x1 = max(0, min(width, int(best_face[0])))
        y1 = max(0, min(height, int(best_face[1])))
        x2 = max(0, min(width, int(best_face[2])))
        y2 = max(0, min(height, int(best_face[3])))
        face_box = (x1, y1, x2, y2)
        face_rgb = rgb[y1:y2, x1:x2]
        if face_rgb.size == 0:
            return ExtractedViews(None, None, None, None, "Detected face crop is empty.")
        face_image = Image.fromarray(face_rgb)

        try:
            eye_labels = self._detect_eyes(face_rgb)
        except Exception as exc:
            return ExtractedViews(face_image, None, face_box, None, f"Face detected; eye extraction failed ({exc}).")
        if not eye_labels:
            return ExtractedViews(face_image, None, face_box, None, "Face detected; no eye detected.")

        # The style model was trained with a single eye crop. When both eyes are
        # detected, match eval/inference behavior by using the left eye.
        selected_label, selected_box = next(
            ((label, box) for label, box in eye_labels if label == "left"),
            eye_labels[0],
        )
        ex1, ey1, ex2, ey2 = _expand_box(selected_box, self.eye_margin, face_rgb.shape[1], face_rgb.shape[0])
        eye_rgb = face_rgb[ey1:ey2, ex1:ex2]
        if eye_rgb.size == 0:
            return ExtractedViews(face_image, None, face_box, None, "Face detected; eye crop is empty.")
        # Translate the eye box from face-crop to whole-image coordinates.
        eye_box = (x1 + ex1, y1 + ey1, x1 + ex2, y1 + ey2)
        return ExtractedViews(
            face_image,
            Image.fromarray(eye_rgb),
            face_box,
            eye_box,
            f"Face and {selected_label} eye detected.",
        )


@lru_cache(maxsize=2)
def load_extractor(
    yolo_dir: str,
    yolo_weights: str,
    eye_cascade: str,
    device_name: str,
    conf: float,
    iou: float,
    imgsz: int,
    eye_neighbors: int,
    eye_margin: float,
) -> AnimeFaceEyeExtractor:
    """Return a cached extractor for this exact configuration.

    Up to two configurations are cached, so a loaded YOLOv5 model is reused
    across requests.
    """
    return AnimeFaceEyeExtractor(
        yolo_dir=yolo_dir,
        yolo_weights=yolo_weights,
        eye_cascade=eye_cascade,
        device_name=device_name,
        conf=conf,
        iou=iou,
        imgsz=imgsz,
        eye_neighbors=eye_neighbors,
        eye_margin=eye_margin,
    )


def build_model(checkpoint: Dict[str, object], dinov3_root: str, dinov3_weights: str, device: torch.device) -> ArtistStyleModel:
    """Instantiate ``ArtistStyleModel`` from checkpoint hyper-parameters and load weights.

    Hyper-parameters come from ``checkpoint["args"]`` when it is a dict. The
    literal defaults below are used for missing keys. Returns the model in
    eval mode on *device*.
    """
    checkpoint_args = checkpoint.get("args", {})
    if not isinstance(checkpoint_args, dict):
        checkpoint_args = {}
    model = ArtistStyleModel(
        num_classes=int(checkpoint_args.get("train_artist_count", 1000)),
        branch_hidden_dim=int(checkpoint_args.get("branch_hidden_dim", 384)),
        branch_dim=int(checkpoint_args.get("branch_dim", 192)),
        embedding_dim=int(checkpoint_args.get("embedding_dim", 256)),
        num_prototypes=int(checkpoint_args.get("num_prototypes", 4)),
        prototype_temperature=float(checkpoint_args.get("prototype_temperature", 0.1)),
        arcface_scale=float(checkpoint_args.get("arcface_scale", 30.0)),
        view_dropout_prob=float(checkpoint_args.get("view_dropout_prob", 0.15)),
        branch_dropout_prob=float(checkpoint_args.get("branch_dropout_prob", 0.1)),
        backbone_repo_dir=dinov3_root,
        backbone_entrypoint=str(checkpoint_args.get("backbone_entrypoint", "dinov3_vits16")),
        backbone_weights_path=dinov3_weights,
        freeze_backbone=bool(checkpoint_args.get("freeze_backbone", True)),
        backbone_unfreeze_last_n_blocks=int(checkpoint_args.get("backbone_unfreeze_last_n_blocks", 0)),
    ).to(device)
    model.load_state_dict(checkpoint["model"])
    model.eval()
    return model


@lru_cache(maxsize=1)
def load_runtime(
    checkpoint_path: str,
    prototype_bank_path: str,
    dinov3_root: str,
    dinov3_weights: str,
    device_name: str,
) -> Dict[str, object]:
    """Load the style model and prototype bank; the most recent configuration is cached.

    Returns:
        Dict with ``device``, ``model``, ``checkpoint_args``, ``artists``
        (sorted names), ``prototype_tensor`` (L2-normalized, stacked per artist
        in ``artists`` order) and ``prototype_descriptors`` (per-artist).

    Raises:
        gr.Error: The prototype bank contains no artists.
    """
    device = resolve_device(device_name)
    # Load on CPU: the checkpoint may carry more than model weights, and
    # load_state_dict copies tensors onto the model's device anyway.
    # SECURITY NOTE(review): these paths are editable in the UI and torch.load
    # unpickles; only point them at trusted files.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    checkpoint_args = checkpoint.get("args", {})
    if not isinstance(checkpoint_args, dict):
        checkpoint_args = {}
    model = build_model(checkpoint, dinov3_root, dinov3_weights, device)

    bank_payload = torch.load(prototype_bank_path, map_location="cpu")
    prototype_bank = bank_payload["prototype_bank"]
    prototype_descriptors = bank_payload.get("prototype_descriptors", {})
    artists = sorted(prototype_bank)
    if not artists:
        raise gr.Error(f"Prototype bank is empty: {prototype_bank_path}")
    prototype_tensor = torch.stack([prototype_bank[artist] for artist in artists], dim=0).float()
    prototype_tensor = F.normalize(prototype_tensor, dim=-1).to(device)
    return {
        "device": device,
        "model": model,
        "checkpoint_args": checkpoint_args,
        "artists": artists,
        "prototype_tensor": prototype_tensor,
        "prototype_descriptors": prototype_descriptors,
    }


def resize_rgb(image: Image.Image, size: int) -> torch.Tensor:
    """Resize *image* to ``size x size`` (bicubic, no aspect preservation) as a CHW float tensor in [0, 1]."""
    image = image.convert("RGB")
    image = TF.resize(image, [size, size], interpolation=Image.Resampling.BICUBIC)
    return TF.to_tensor(image)


def blank(size: int) -> torch.Tensor:
    """Return an all-zero ``(3, size, size)`` placeholder for a missing view."""
    return torch.zeros((3, size, size), dtype=torch.float32)


def prepare_inputs(
    whole_image: Image.Image,
    face_image: Optional[Image.Image],
    eye_image: Optional[Image.Image],
    checkpoint_args: Dict[str, object],
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build batched (batch size 1) model inputs for the three views.

    Missing face/eye crops become zero images and are masked out in the
    ``(1, 3)`` view mask ``[full, face, eye]``.
    """
    full_size = int(checkpoint_args.get("full_image_size", 320))
    face_size = int(checkpoint_args.get("face_image_size", 256))
    eye_size = int(checkpoint_args.get("eye_image_size", 224))
    full = resize_rgb(whole_image, full_size)
    has_face = face_image is not None
    has_eye = eye_image is not None
    face = resize_rgb(face_image, face_size) if has_face else blank(face_size)
    eye = resize_rgb(eye_image, eye_size) if has_eye else blank(eye_size)
    view_mask = torch.tensor([[1.0, float(has_face), float(has_eye)]], dtype=torch.float32)
    return (
        full.unsqueeze(0).to(device),
        face.unsqueeze(0).to(device),
        eye.unsqueeze(0).to(device),
        view_mask.to(device),
    )


def normalized_box(box: Optional[Tuple[int, int, int, int]], image_size: Tuple[int, int]) -> Optional[Tuple[float, float, float, float]]:
    """Convert a pixel box to ``[0, 1]``-clamped fractions of ``image_size`` (width, height).

    Returns ``None`` for a missing box or a degenerate image size.
    """
    if box is None:
        return None
    width, height = image_size
    if width <= 0 or height <= 0:
        return None
    x1, y1, x2, y2 = box
    return (
        max(0.0, min(1.0, x1 / width)),
        max(0.0, min(1.0, y1 / height)),
        max(0.0, min(1.0, x2 / width)),
        max(0.0, min(1.0, y2 / height)),
    )


def rank_artists(query: torch.Tensor, artists: List[str], prototype_tensor: torch.Tensor, top_k: int) -> Tuple[List[List[object]], str, int, float]:
    """Rank artists by the dot product between *query* and each artist's best prototype.

    Args:
        query: ``(D,)`` embedding (normalized by the caller, so this is cosine).
        artists: Artist names aligned with the first axis of *prototype_tensor*.
        prototype_tensor: ``(A, P, D)`` prototypes.
        top_k: Number of rows to return, clamped to ``[1, len(artists)]``.

    Returns:
        ``(rows, best_artist, best_prototype_index, best_score)`` where each row
        is ``[rank, artist, "score"]``.
    """
    sims = torch.einsum("d,apd->ap", query.float(), prototype_tensor)
    scores, proto_indices = sims.max(dim=1)
    k = min(max(1, int(top_k)), len(artists))
    values, artist_indices = torch.topk(scores, k=k)
    rows: List[List[object]] = [
        [rank, artists[artist_idx], f"{float(score):.3f}"]
        for rank, (score, artist_idx) in enumerate(zip(values.tolist(), artist_indices.tolist()), start=1)
    ]
    best_artist_idx = int(artist_indices[0].item())
    return rows, artists[best_artist_idx], int(proto_indices[best_artist_idx].item()), float(values[0].item())


def tensor_to_map(heatmap: torch.Tensor, size: Tuple[int, int]) -> np.ndarray:
    """Min-max normalize a heatmap and bicubically resize it to *size* (width, height).

    Non-finite values are zeroed first. Returns a float32 array of shape
    ``(height, width)`` in ``[0, 1]``. Assumes the heatmap squeezes to 2-D.
    """
    data = heatmap.detach().float().squeeze().cpu().numpy()
    data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)
    data = data - float(data.min())
    data = data / max(float(data.max()), 1e-6)
    # A 2-D uint8 array is inferred as mode "L"; the explicit `mode` argument
    # of Image.fromarray is deprecated in recent Pillow.
    image = Image.fromarray(np.uint8(np.clip(data, 0.0, 1.0) * 255))
    image = image.resize(size, Image.Resampling.BICUBIC)
    resized = np.asarray(image, dtype=np.float32) / 255.0
    return np.clip(resized, 0.0, 1.0)


def heat_colors(values: np.ndarray) -> np.ndarray:
    """Map values in ``[0, 1]`` to RGB by linear interpolation along ``HEAT_PALETTE``."""
    values = np.clip(values, 0.0, 1.0)
    scaled = values * (len(HEAT_PALETTE) - 1)
    left = np.floor(scaled).astype(np.int32)
    right = np.clip(left + 1, 0, len(HEAT_PALETTE) - 1)
    frac = scaled[..., None] - left[..., None]
    colors = HEAT_PALETTE[left] * (1.0 - frac) + HEAT_PALETTE[right] * frac
    return colors


def _paste_map(canvas: np.ndarray, heatmap: torch.Tensor, box: Tuple[int, int, int, int], size: Tuple[int, int]) -> None:
    """Resize *heatmap* into *box* and merge it into *canvas* in place by element-wise max.

    The box is clamped to *size* (width, height); degenerate boxes are ignored.
    """
    x1, y1, x2, y2 = box
    width, height = size
    x1 = max(0, min(width, int(x1)))
    x2 = max(0, min(width, int(x2)))
    y1 = max(0, min(height, int(y1)))
    y2 = max(0, min(height, int(y2)))
    if x2 <= x1 or y2 <= y1:
        return
    local = tensor_to_map(heatmap, (x2 - x1, y2 - y1))
    canvas[y1:y2, x1:x2] = np.maximum(canvas[y1:y2, x1:x2], local)


def _scale_box(box: Tuple[int, int, int, int], sx: float, sy: float) -> Tuple[int, int, int, int]:
    """Scale a pixel box by per-axis factors, rounding to ints."""
    return (
        int(round(box[0] * sx)),
        int(round(box[1] * sy)),
        int(round(box[2] * sx)),
        int(round(box[3] * sy)),
    )


def make_overlay(
    base_image: Image.Image,
    view_heatmaps: Dict[str, Optional[torch.Tensor]],
    face_box: Optional[Tuple[int, int, int, int]] = None,
    eye_box: Optional[Tuple[int, int, int, int]] = None,
) -> Image.Image:
    """Composite per-view heatmaps over a lightened grayscale thumbnail (max 720x720).

    The ``"full"`` heatmap covers the whole image. The ``"face"``/``"eye"``
    heatmaps are pasted into their boxes (original-image pixels, rescaled to
    the thumbnail), or stretched over the whole image when the box is unknown.
    Each map is normalized to ``[0, 1]`` before the maps are averaged.
    """
    base = base_image.convert("RGB")
    original_size = base.size
    display = base.copy()
    display.thumbnail((720, 720), Image.Resampling.LANCZOS)
    size = display.size
    sx = size[0] / max(1, original_size[0])
    sy = size[1] / max(1, original_size[1])
    # Lighten the background so the heat colors stand out.
    gray = np.asarray(display.convert("L"), dtype=np.float32) / 255.0
    gray = 0.78 + 0.22 * gray
    gray_rgb = np.repeat(gray[..., None], 3, axis=2)

    maps = []
    full_heatmap = view_heatmaps.get("full")
    if full_heatmap is not None:
        maps.append(tensor_to_map(full_heatmap, size))

    # Face and eye maps are localized and share one canvas (merged by max).
    spatial = np.zeros((size[1], size[0]), dtype=np.float32)
    spatial_used = False
    for key, box in (("face", face_box), ("eye", eye_box)):
        local_heatmap = view_heatmaps.get(key)
        if local_heatmap is None:
            continue
        if box is not None:
            _paste_map(spatial, local_heatmap, _scale_box(box, sx, sy), size)
            spatial_used = True
        else:
            maps.append(tensor_to_map(local_heatmap, size))
    if spatial_used:
        maps.append(spatial)

    if not maps:
        return Image.fromarray(np.uint8(gray_rgb * 255))

    # Equal per-view normalization and averaging keeps larger source crops or
    # local face/eye maps from dominating purely because of area or scale.
    combined = np.mean(np.stack(maps, axis=0), axis=0)
    combined = combined - float(combined.min())
    combined = combined / max(float(combined.max()), 1e-6)
    heat_rgb = heat_colors(combined)
    # Values below 0.05 stay fully transparent; opacity caps at 90%.
    alpha = np.clip((combined - 0.05) / 0.95, 0.0, 1.0)
    alpha = (0.90 * alpha)[..., None]
    blended = gray_rgb * (1.0 - alpha) + heat_rgb * alpha
    return Image.fromarray(np.uint8(np.clip(blended, 0.0, 1.0) * 255))


def scalar(value: torch.Tensor) -> float:
    """Return the first element of *value* as a Python float."""
    return float(value.detach().float().flatten()[0].cpu().item())


def _normalize_values(values: List[float]) -> List[float]:
    """Clip negatives to zero and rescale to sum to 1.

    An all-zero input becomes a uniform distribution; empty input returns ``[]``.
    """
    clipped = [max(0.0, float(value)) for value in values]
    if not clipped:
        return []
    total = sum(clipped)
    if total <= 1e-8:
        return [1.0 / len(clipped) for _ in clipped]
    return [value / total for value in clipped]


def _bars_html(title: str, note: str, labels: List[str], values: List[float], colors: Dict[str, str]) -> str:
    """Render a titled horizontal bar chart (values as fractions) styled by ``APP_CSS``.

    Text is HTML-escaped; colors come from *colors* keyed by label, with a red
    fallback.
    """
    rows = []
    for label, value in zip(labels, values):
        pct = max(0.0, min(1.0, float(value))) * 100.0
        color = colors.get(label, "#e34a33")
        rows.append(
            "<div class='bar-row'>"
            f"<div class='bar-label'>{html.escape(str(label))}</div>"
            "<div class='bar-track'>"
            f"<div class='bar-fill' style='width: {pct:.1f}%; background: {html.escape(color)};'></div>"
            "</div>"
            f"<div class='bar-value'>{pct:.1f}%</div>"
            "</div>"
        )
    return (
        "<div class='bars'>"
        f"<div class='bar-title'>{html.escape(title)}</div>"
        f"<div class='bar-note'>{html.escape(note)}</div>"
        f"{''.join(rows)}"
        "</div>"
    )


def top_match_html(artist: str, similarity: float) -> str:
    """Render the top-match card; *similarity* is shown as a percentage clamped to [0, 100]."""
    percent = max(0.0, min(100.0, float(similarity) * 100.0))
    return (
        "<div class='top-match'>"
        "<div class='top-match-label'>Top artist match</div>"
        f"<div class='top-match-artist'>{html.escape(str(artist))}</div>"
        f"<div class='top-match-score'>{percent:.1f}% similarity</div>"
        "</div>"
    )


def contribution_bars(explanation: Dict[str, object]) -> Tuple[str, str, str]:
    """Summarize branch and view contributions from ``explanation["query_outputs"]``.

    View contributions are the branch-weighted sum of ``stacked_view_weights``
    (presumably shaped ``(branches, views)`` per sample; TODO confirm).

    Returns:
        ``(branch_bars_html, view_bars_html, summary_text)``.
    """
    query_outputs = explanation["query_outputs"]
    branch_values = query_outputs["branch_weights"][0].detach().float().cpu().tolist()
    branch_values = _normalize_values(branch_values)

    view_weights = query_outputs["stacked_view_weights"][0].detach().float().cpu()
    branch_tensor = torch.tensor(branch_values, dtype=view_weights.dtype).unsqueeze(-1)
    view_values = (view_weights * branch_tensor).sum(dim=0).tolist()
    view_values = _normalize_values(view_values)

    best_branch = BRANCH_NAMES[int(np.argmax(branch_values))]
    best_view = VIEW_NAMES[int(np.argmax(view_values))]
    summary = f"Dominant branch: {best_branch}. Dominant view: {best_view}."
    return (
        _bars_html(
            "Influential style factors",
            "Relative weight of structure, texture, linework, and color in the query embedding.",
            BRANCH_NAMES,
            branch_values,
            BRANCH_COLORS,
        ),
        _bars_html(
            "Similar visual regions",
            "Relative weight of whole image, face crop, and eye crop used for the match.",
            VIEW_NAMES,
            view_values,
            VIEW_COLORS,
        ),
        summary,
    )


def analyze(
    whole_image: Optional[Image.Image],
    face_image: Optional[Image.Image],
    eye_image: Optional[Image.Image],
    auto_extract: bool,
    top_k: int,
    use_tta: bool,
    checkpoint_path: str,
    prototype_bank_path: str,
    dinov3_root: str,
    dinov3_weights: str,
    device_name: str,
    yolo_dir: str,
    yolo_weights: str,
    eye_cascade: str,
    face_conf: float,
    face_iou: float,
    face_imgsz: int,
    eye_neighbors: int,
    eye_margin: float,
) -> Tuple[Optional[Image.Image], Optional[Image.Image], Image.Image, str, List[List[object]], str, str, str]:
    """Gradio callback: retrieve the closest artists and explain the top match.

    With *auto_extract*, face/eye crops are detected from the whole image and
    replace any manual crops. Extraction errors degrade to whole-image-only
    analysis.

    Returns:
        ``(face_crop, eye_crop, overlay, top_match_html, retrieval_rows,
        branch_bars_html, view_bars_html, summary)`` in the order of the
        ``run.click`` outputs.

    Raises:
        gr.Error: No whole image was given, or no descriptor exists for the
            winning prototype.
    """
    if whole_image is None:
        raise gr.Error("Whole image is required.")

    runtime = load_runtime(checkpoint_path, prototype_bank_path, dinov3_root, dinov3_weights, device_name)
    device = runtime["device"]
    model = runtime["model"]
    checkpoint_args = runtime["checkpoint_args"]

    face_box = None
    eye_box = None
    extract_status = "Manual face/eye inputs."
    used_face = face_image
    used_eye = eye_image
    if auto_extract:
        extractor = load_extractor(
            yolo_dir,
            yolo_weights,
            eye_cascade,
            device_name,
            float(face_conf),
            float(face_iou),
            int(face_imgsz),
            int(eye_neighbors),
            float(eye_margin),
        )
        try:
            extracted = extractor.extract(whole_image)
        except Exception as exc:
            # Deliberate best effort: report the failure and continue with the whole image.
            extracted = ExtractedViews(None, None, None, None, f"Auto crop failed ({exc}); using whole image only.")
        used_face = extracted.face
        used_eye = extracted.eye
        face_box = extracted.face_box
        eye_box = extracted.eye_box
        extract_status = extracted.status

    full, face, eye, view_mask = prepare_inputs(whole_image, used_face, used_eye, checkpoint_args, device)
    with torch.no_grad():
        query_outputs = _encode_query(model, full, face, eye, view_mask=view_mask, use_tta=use_tta)
        query = F.normalize(query_outputs["embedding"][0], dim=0)

    result_rows, best_artist, best_proto_idx, best_score = rank_artists(
        query=query,
        artists=runtime["artists"],
        prototype_tensor=runtime["prototype_tensor"],
        top_k=top_k,
    )

    descriptors = runtime["prototype_descriptors"].get(best_artist)
    # Use len() instead of truthiness: descriptors may be a tensor, whose
    # bool() is ambiguous for more than one element.
    if descriptors is None or len(descriptors) == 0:
        raise gr.Error(f"No prototype descriptor found for {best_artist}.")
    try:
        reference_descriptor = descriptors[best_proto_idx]
    except (IndexError, KeyError) as exc:
        raise gr.Error(f"No prototype descriptor #{best_proto_idx} found for {best_artist}.") from exc

    view_attention_boxes = {}
    face_attention_box = normalized_box(face_box, whole_image.size)
    eye_attention_box = normalized_box(eye_box, whole_image.size)
    if face_attention_box is not None:
        view_attention_boxes["face"] = face_attention_box
    if eye_attention_box is not None:
        view_attention_boxes["eye"] = eye_attention_box

    explanation = explain_against_reference(
        model=model,
        query_full=full,
        query_face=face,
        query_eye=eye,
        query_view_mask=view_mask,
        reference_descriptor=reference_descriptor,
        view_attention_boxes=view_attention_boxes,
        use_tta=use_tta,
    )
    overlay = make_overlay(whole_image, explanation["combined_view_heatmaps"], face_box=face_box, eye_box=eye_box)
    branch_html, view_html, summary = contribution_bars(explanation)
    summary = f"{extract_status} {summary} Target prototype: {best_artist} #{best_proto_idx}."
    return used_face, used_eye, overlay, top_match_html(best_artist, best_score), result_rows, branch_html, view_html, summary


def build_ui(args: argparse.Namespace) -> gr.Blocks:
    """Build the Gradio Blocks layout, seeding path/crop controls from *args*."""
    with gr.Blocks(title="Artist Style DINOv3") as demo:
        gr.Markdown("# Artist Style DINOv3")
        with gr.Row(elem_classes=["app"]):
            with gr.Column(scale=4, min_width=320):
                whole = gr.Image(label="Whole", type="pil", image_mode="RGB", height=260)
                with gr.Row():
                    face = gr.Image(label="Face", type="pil", image_mode="RGB", height=160)
                    eye = gr.Image(label="Eye", type="pil", image_mode="RGB", height=160)
                with gr.Row():
                    auto_extract = gr.Checkbox(value=True, label="Auto crop")
                    top_k = gr.Slider(1, 25, value=10, step=1, label="Top K")
                    use_tta = gr.Checkbox(value=False, label="TTA")
                run = gr.Button("Analyze", variant="primary")
                with gr.Accordion("Paths", open=False):
                    checkpoint_path = gr.Textbox(value=args.checkpoint, label="Checkpoint")
                    prototype_bank_path = gr.Textbox(value=args.prototype_bank, label="Prototype bank")
                    dinov3_root = gr.Textbox(value=args.dinov3_root, label="DINOv3 root")
                    dinov3_weights = gr.Textbox(value=args.dinov3_weights, label="DINOv3 weights")
                    yolo_dir = gr.Textbox(value=args.yolo_dir, label="YOLOv5 anime root")
                    yolo_weights = gr.Textbox(value=args.yolo_weights, label="YOLOv5 anime weights")
                    eye_cascade = gr.Textbox(value=args.eye_cascade, label="Eye cascade")
                    device_name = gr.Dropdown(["auto", "cpu", "cuda"], value=args.device, label="Device")
                with gr.Accordion("Auto Crop", open=False):
                    face_conf = gr.Slider(0.05, 0.95, value=args.face_conf, step=0.05, label="Face confidence")
                    face_iou = gr.Slider(0.1, 0.9, value=args.face_iou, step=0.05, label="Face IoU")
                    face_imgsz = gr.Slider(320, 1280, value=args.face_imgsz, step=32, label="Face detector size")
                    eye_neighbors = gr.Slider(1, 20, value=args.eye_neighbors, step=1, label="Eye neighbors")
                    eye_margin = gr.Slider(0.0, 1.5, value=args.eye_margin, step=0.05, label="Eye margin")
            with gr.Column(scale=7, min_width=640):
                with gr.Row():
                    overlay = gr.Image(label="Composite heatmap", type="pil", height=520)
                    with gr.Column(scale=1):
                        summary = gr.Textbox(label="Strongest contribution", lines=3, elem_id="summary-box")
                        top_match = gr.HTML()
                        results = gr.Dataframe(
                            headers=["rank", "artist", "score"],
                            label="Retrieval",
                            datatype=["number", "str", "str"],
                            row_count=10,
                        )
                with gr.Row():
                    branch_bars = gr.HTML(label="Branch")
                    view_bars = gr.HTML(label="View")

        run.click(
            fn=analyze,
            inputs=[
                whole,
                face,
                eye,
                auto_extract,
                top_k,
                use_tta,
                checkpoint_path,
                prototype_bank_path,
                dinov3_root,
                dinov3_weights,
                device_name,
                yolo_dir,
                yolo_weights,
                eye_cascade,
                face_conf,
                face_iou,
                face_imgsz,
                eye_neighbors,
                eye_margin,
            ],
            outputs=[face, eye, overlay, top_match, results, branch_bars, view_bars, summary],
        )
    return demo


def main() -> None:
    """Parse arguments, build the UI and launch the Gradio server."""
    args = parse_args()
    demo = build_ui(args)
    demo.launch(server_name=args.server_name, server_port=args.server_port, share=args.share, css=APP_CSS)


if __name__ == "__main__":
    main()