Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import itertools | |
| import os | |
| import sys | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import Dict, List, NamedTuple, Optional, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from torchvision.transforms import functional as TF | |
| from artist_style_dinov3 import ArtistStyleModel, explain_against_reference | |
| from artist_style_dinov3.explain import _encode_query | |
# Directory containing this script; default asset paths are resolved against it.
ROOT = Path(__file__).resolve().parent
# View labels as shown in the UI bars and summary text.
VIEW_NAMES = ["whole", "face", "eye"]
# View keys as used for per-view heatmaps ("full" is the UI's "whole").
# NOTE(review): presumably mirrors the model's own view naming -- confirm.
MODEL_VIEW_NAMES = ["full", "face", "eye"]
# Style branch labels, indexed in the same order as the branch weights that
# contribution_bars() reads from the query outputs.
BRANCH_NAMES = ["structure", "texture", "line", "color"]
# ColorBrewer YlOrRd-style sequential palette (RGB in 0..1). Values move
# through yellow/orange into red so saliency stays visible on the grayscale
# backdrop; transparency of low values is applied separately by the overlay's
# alpha ramp, not by this palette.
HEAT_PALETTE = np.array(
    [
        [1.000000, 1.000000, 0.800000],
        [1.000000, 0.929412, 0.627451],
        [0.996078, 0.850980, 0.462745],
        [0.996078, 0.698039, 0.298039],
        [0.992157, 0.552941, 0.235294],
        [0.988235, 0.305882, 0.164706],
        [0.890196, 0.101961, 0.109804],
        [0.700000, 0.000000, 0.000000],
    ],
    dtype=np.float32,
)
# Stylesheet handed to Gradio at launch; class names match the HTML emitted by
# _bars_html() and top_match_html().
APP_CSS = """
.app {max-width: 1500px; margin: 0 auto;}
.compact-table textarea {font-size: 12px;}
#summary-box textarea {font-size: 14px; font-weight: 600;}
.bars {display: grid; gap: 10px; margin: 4px 0 14px;}
.bar-row {display: grid; grid-template-columns: 82px 1fr 52px; gap: 8px; align-items: center;}
.bar-label {font-size: 13px; font-weight: 600;}
.bar-track {height: 13px; background: #e8e8e8; border-radius: 999px; overflow: hidden;}
.bar-fill {height: 100%;}
.bar-value {font-variant-numeric: tabular-nums; font-size: 12px; text-align: right;}
.bar-title {font-size: 15px; font-weight: 700; margin: 0 0 3px;}
.bar-note {font-size: 12px; color: #666; margin: 0 0 10px; line-height: 1.35;}
.top-match {padding: 12px 14px; border: 1px solid #dedede; border-radius: 8px; margin-bottom: 10px; background: #fafafa;}
.top-match-label {font-size: 12px; color: #666; margin-bottom: 2px;}
.top-match-artist {font-size: 24px; line-height: 1.15; font-weight: 800;}
.top-match-score {font-size: 18px; font-weight: 700; color: #b91c1c; margin-top: 4px;}
"""
# Bar fill colour per style branch (keys match BRANCH_NAMES).
BRANCH_COLORS = {
    "structure": "#4e79a7",
    "texture": "#f28e2b",
    "line": "#e15759",
    "color": "#59a14f",
}
# Bar fill colour per view (keys match VIEW_NAMES).
VIEW_COLORS = {
    "whole": "#7b61ff",
    "face": "#ff7f0e",
    "eye": "#d62728",
}
def env_path(name: str, default: Path) -> str:
    """Return a path string taken from environment variable *name*.

    When the variable is unset, empty, or whitespace-only, ``default`` is
    returned (stringified). Otherwise the variable's value is returned with a
    leading ``~`` expanded to the user's home directory.
    """
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return str(default)
    return str(Path(raw).expanduser())
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the web UI.

    Every path option falls back to an environment variable (via
    ``env_path``) and then to a location underneath ``ROOT``.
    """
    training_dir = ROOT / "artifacts" / "style_training_dinov3"
    yolo_root = ROOT / "yolov5_anime"
    dinov3_checkpoint = (
        ROOT / "artifacts" / "pretrained" / "dinov3" / "dinov3_vits16_pretrain_lvd1689m-08c60483.pth"
    )

    cli = argparse.ArgumentParser(description="Gradio UI for artist style retrieval and attribution.")

    # Style model and prototype bank.
    cli.add_argument(
        "--checkpoint",
        type=str,
        default=env_path("ARTIST_STYLE_CHECKPOINT", training_dir / "best.pt"),
    )
    cli.add_argument(
        "--prototype-bank",
        type=str,
        default=env_path("ARTIST_PROTOTYPES", training_dir / "artist_prototypes.pt"),
    )

    # DINOv3 backbone.
    cli.add_argument("--dinov3-root", type=str, default=env_path("DINOV3_ROOT", ROOT / "third_party" / "dinov3"))
    cli.add_argument("--dinov3-weights", type=str, default=env_path("DINOV3_WEIGHTS", dinov3_checkpoint))
    cli.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto")

    # Automatic face/eye cropping.
    cli.add_argument("--yolo-dir", type=str, default=env_path("YOLO_ANIME_ROOT", yolo_root))
    cli.add_argument(
        "--yolo-weights",
        type=str,
        default=env_path("YOLO_WEIGHTS", yolo_root / "weights" / "yolov5s_anime.pt"),
    )
    cli.add_argument("--eye-cascade", type=str, default=env_path("EYE_CASCADE", ROOT / "anime-eyes-cascade.xml"))
    cli.add_argument("--face-conf", type=float, default=0.5)
    cli.add_argument("--face-iou", type=float, default=0.5)
    cli.add_argument("--face-imgsz", type=int, default=640)
    cli.add_argument("--eye-neighbors", type=int, default=9)
    cli.add_argument("--eye-margin", type=float, default=0.6)

    # Gradio server.
    cli.add_argument("--server-name", type=str, default="127.0.0.1")
    cli.add_argument("--server-port", type=int, default=7860)
    cli.add_argument("--share", action="store_true")
    return cli.parse_args()
class ExtractedViews(NamedTuple):
    """Outcome of automatic face/eye cropping for one whole image."""

    # Cropped face image, or None when no usable face crop was produced.
    face: Optional[Image.Image]
    # Cropped eye image, or None when no usable eye crop was produced.
    eye: Optional[Image.Image]
    # Face box as (x1, y1, x2, y2) pixel coordinates in the whole image.
    face_box: Optional[Tuple[int, int, int, int]]
    # Eye box as (x1, y1, x2, y2) pixel coordinates in the whole image.
    eye_box: Optional[Tuple[int, int, int, int]]
    # Human-readable description of what was (or was not) detected.
    status: str
def resolve_device(device_name: str) -> torch.device:
    """Translate a device choice into a ``torch.device``.

    ``"cpu"`` always yields the CPU. ``"cuda"`` yields CUDA or raises
    ``RuntimeError`` when CUDA is unavailable. Any other value (e.g.
    ``"auto"``) picks CUDA when available and the CPU otherwise.
    """
    if device_name == "cpu":
        return torch.device("cpu")

    has_cuda = torch.cuda.is_available()
    if device_name == "cuda" and not has_cuda:
        raise RuntimeError("CUDA was requested but is not available.")
    return torch.device("cuda" if has_cuda else "cpu")
def _patch_torch_load_for_yolov5() -> None:
    """Make ``torch.load`` able to read legacy YOLOv5 checkpoints.

    The global ``torch.load`` is replaced by a wrapper that defaults
    ``weights_only`` to ``False`` unless the caller passes it explicitly.
    The patch is applied at most once per process; later calls are no-ops.

    SECURITY NOTE: ``weights_only=False`` lets the checkpoint's pickle stream
    execute arbitrary code, and because the patch is process-wide it applies
    to every later ``torch.load`` call. Only load checkpoints you trust.
    """
    import torch.serialization

    # Check first so repeated calls neither re-wrap nor re-register globals.
    if getattr(torch.load, "_anime_webui_patched", False):
        return

    # Best-effort: allow-list the NumPy reconstruction helpers for callers that
    # still pass weights_only=True. Older torch builds have no
    # add_safe_globals, and the private NumPy module moved from numpy.core to
    # numpy._core in NumPy 2, so both are probed instead of assumed.
    add_safe_globals = getattr(torch.serialization, "add_safe_globals", None)
    if add_safe_globals is not None:
        try:
            import numpy as _np

            np_core = getattr(_np, "_core", None) or _np.core
            add_safe_globals([np_core.multiarray._reconstruct, _np.ndarray])
        except (AttributeError, ImportError):
            # The allow-list is optional; the wrapper below is what matters.
            pass

    original_load = torch.load

    def patched_load(*args, **kwargs):
        kwargs.setdefault("weights_only", False)
        return original_load(*args, **kwargs)

    patched_load._anime_webui_patched = True
    torch.load = patched_load
def _expand_box(box: Tuple[int, int, int, int], margin: float, width: int, height: int) -> Tuple[int, int, int, int]:
    """Grow ``box`` about its centre by ``margin`` and clamp it to the image.

    Args:
        box: ``(x1, y1, x2, y2)`` corners.
        margin: Fractional growth; each side length is scaled by ``1 + margin``.
        width: Upper bound for x coordinates.
        height: Upper bound for y coordinates.

    Returns:
        The enlarged box with integer corners, x within ``[0, width]`` and y
        within ``[0, height]``.
    """

    def clamp(value: float, upper: int) -> int:
        return max(0, min(upper, int(round(value))))

    left, top, right, bottom = box
    center_x = (left + right) / 2.0
    center_y = (top + bottom) / 2.0
    grown_w = (right - left) * (1.0 + margin)
    grown_h = (bottom - top) * (1.0 + margin)
    return (
        clamp(center_x - grown_w / 2.0, width),
        clamp(center_y - grown_h / 2.0, height),
        clamp(center_x + grown_w / 2.0, width),
        clamp(center_y + grown_h / 2.0, height),
    )
class AnimeFaceEyeExtractor:
    """Crop a face (YOLOv5) and one eye (OpenCV cascade) from an image.

    The YOLOv5 model and the cascade classifier are loaded lazily on the first
    :meth:`extract` call, so constructing an instance does not load any model.
    """

    def __init__(
        self,
        yolo_dir: str,
        yolo_weights: str,
        eye_cascade: str,
        device_name: str,
        conf: float,
        iou: float,
        imgsz: int,
        eye_neighbors: int,
        eye_margin: float,
    ) -> None:
        """Store the configuration; no model is loaded here.

        Args:
            yolo_dir: Root of the YOLOv5 source tree (added to ``sys.path``).
            yolo_weights: YOLOv5 face-detector checkpoint.
            eye_cascade: OpenCV cascade XML used for eye detection.
            device_name: ``"auto"``, ``"cpu"`` or ``"cuda"``.
            conf: Confidence threshold passed to non-max suppression.
            iou: IoU threshold passed to non-max suppression.
            imgsz: Detector input size (rounded up to the model stride later).
            eye_neighbors: ``minNeighbors`` for the eye cascade.
            eye_margin: Fractional margin added around the chosen eye box.
        """
        self.yolo_dir = Path(yolo_dir).resolve()
        self.yolo_weights = Path(yolo_weights).resolve()
        self.eye_cascade = Path(eye_cascade).resolve()
        self.device_name = device_name
        self.conf = float(conf)
        self.iou = float(iou)
        self.imgsz = int(imgsz)
        self.eye_neighbors = int(eye_neighbors)
        self.eye_margin = float(eye_margin)
        # Filled in by _ensure_ready().
        self.model = None
        self.device = None
        self.stride = 32
        self.use_half = False
        self.cascade = None

    def _ensure_ready(self) -> None:
        """Load the face model and eye cascade if they are not loaded yet.

        Raises:
            gr.Error: If a required file is missing, OpenCV is not installed,
                or the cascade file cannot be parsed.
        """
        missing = [
            str(path)
            for path in (self.yolo_dir, self.yolo_weights, self.eye_cascade)
            if not path.exists()
        ]
        if missing:
            raise gr.Error("Missing auto extraction files: " + ", ".join(missing))
        if self.model is not None and self.cascade is not None:
            return
        try:
            import cv2
        except ModuleNotFoundError as exc:
            raise gr.Error("opencv-python is required for auto face/eye extraction.") from exc

        if self.model is None:
            # The YOLOv5 repo is imported as top-level ``models``/``utils``.
            if str(self.yolo_dir) not in sys.path:
                sys.path.insert(0, str(self.yolo_dir))
            _patch_torch_load_for_yolov5()
            from models.experimental import attempt_load
            from utils.torch_utils import select_device

            yolo_device = "0" if self.device_name in {"auto", "cuda"} and torch.cuda.is_available() else "cpu"
            device = select_device(yolo_device)
            use_half = getattr(device, "type", str(device)) != "cpu"
            model = attempt_load(str(self.yolo_weights), map_location=device)
            if use_half:
                model.half()
            model.eval()
            stride = int(model.stride.max())

            self.device = device
            self.use_half = use_half
            self.stride = stride
            # The detector input size must be a multiple of the model stride.
            self.imgsz = int(np.ceil(self.imgsz / stride) * stride)
            # Publish the model last so a failure above leaves it unset.
            self.model = model

        # Validate before storing: an empty classifier must not be mistaken
        # for a ready one by the early-return check on the next call.
        cascade = cv2.CascadeClassifier(str(self.eye_cascade))
        if cascade.empty():
            raise gr.Error(f"Failed to load eye cascade: {self.eye_cascade}")
        self.cascade = cascade

    def _letterbox(self, bgr: np.ndarray) -> np.ndarray:
        """Resize ``bgr`` with YOLOv5's ``letterbox`` helper.

        YOLOv5 revisions differ in which keyword arguments ``letterbox``
        accepts, so progressively simpler call signatures are tried.
        """
        from utils.datasets import letterbox

        try:
            return letterbox(bgr, (self.imgsz, self.imgsz), stride=self.stride, auto=False)[0]
        except TypeError:
            try:
                return letterbox(bgr, (self.imgsz, self.imgsz), auto=False)[0]
            except TypeError:
                return letterbox(bgr, (self.imgsz, self.imgsz))[0]

    def _detect_faces(self, rgb: np.ndarray) -> List[Tuple[int, int, int, int, float]]:
        """Run the face detector on an RGB array.

        Returns:
            ``(x1, y1, x2, y2, confidence)`` tuples, with coordinates rescaled
            by YOLOv5's ``scale_coords`` to the size of ``rgb``.
        """
        import cv2
        from utils.general import non_max_suppression, scale_coords

        bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        image = self._letterbox(bgr)
        # BGR HWC -> RGB CHW, as a contiguous array for torch.from_numpy.
        image = image[:, :, ::-1].transpose(2, 0, 1)
        image = np.ascontiguousarray(image)
        tensor = torch.from_numpy(image).to(self.device)
        tensor = tensor.half() if self.use_half else tensor.float()
        tensor = tensor / 255.0
        tensor = tensor.unsqueeze(0)
        with torch.inference_mode():
            pred = self.model(tensor)[0]
            pred = non_max_suppression(pred, conf_thres=self.conf, iou_thres=self.iou, classes=None, agnostic=False)[0]
        boxes: List[Tuple[int, int, int, int, float]] = []
        if pred is not None and len(pred):
            h0, w0 = rgb.shape[:2]
            pred[:, :4] = scale_coords((self.imgsz, self.imgsz), pred[:, :4], (h0, w0)).round()
            for *xyxy, conf, _cls in pred.tolist():
                x1, y1, x2, y2 = [int(v) for v in xyxy]
                boxes.append((x1, y1, x2, y2, float(conf)))
        return boxes

    def _detect_eyes(self, face_rgb: np.ndarray) -> List[Tuple[str, Tuple[int, int, int, int]]]:
        """Detect eyes in the top 70% of a face crop.

        Returns:
            Labelled boxes as produced by :meth:`_best_eye_pair`, in face-crop
            pixel coordinates.
        """
        import cv2

        height, width = face_rgb.shape[:2]
        # Only the upper part of the face is searched.
        roi_h = max(1, int(height * 0.70))
        roi = face_rgb[:roi_h, :]
        gray = cv2.cvtColor(roi, cv2.COLOR_RGB2GRAY)
        gray = cv2.GaussianBlur(gray, (3, 3), 0)
        gray = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray)
        min_size = max(8, int(0.07 * min(width, roi_h)))
        raw = self.cascade.detectMultiScale(
            gray,
            scaleFactor=1.05,
            minNeighbors=self.eye_neighbors,
            minSize=(min_size, min_size),
            flags=cv2.CASCADE_SCALE_IMAGE,
        )
        # The result may be None, an (empty) tuple, or an array of boxes.
        if raw is None:
            return []
        if isinstance(raw, tuple):
            if len(raw) == 0:
                return []
            raw = raw[0]
        arr = np.asarray(raw)
        if arr.size == 0:
            return []
        if arr.ndim == 1:
            arr = arr.reshape(1, -1)
        boxes = [
            (int(x), int(y), int(x + w), int(y + h))
            for x, y, w, h in arr[:, :4]
            if w > 0 and h > 0
        ]
        return self._best_eye_pair(boxes, width, roi_h)

    @staticmethod
    def _best_eye_pair(
        boxes: List[Tuple[int, int, int, int]],
        width: int,
        height: int,
    ) -> List[Tuple[str, Tuple[int, int, int, int]]]:
        """Pick the most plausible left/right eye pair from candidate boxes.

        A pair scores better (lower) when the two boxes sit at a similar
        height, have a similar area, and are separated horizontally by
        5%-50% of ``width``.

        Returns:
            ``[]`` for no boxes, ``[("left", box)]`` for a single box,
            otherwise ``[("left", a), ("right", b)]`` with ``a`` being the box
            whose centre has the smaller x coordinate.
        """
        if len(boxes) < 2:
            return [("left", boxes[0])] if len(boxes) == 1 else []

        def center(box: Tuple[int, int, int, int]) -> Tuple[float, float]:
            x1, y1, x2, y2 = box
            return (x1 + x2) / 2.0, (y1 + y2) / 2.0

        def area(box: Tuple[int, int, int, int]) -> int:
            x1, y1, x2, y2 = box
            return max(1, (x2 - x1) * (y2 - y1))

        def pair_score(pair: Tuple[Tuple[int, int, int, int], Tuple[int, int, int, int]]) -> float:
            first, second = pair
            c1x, c1y = center(first)
            c2x, c2y = center(second)
            a1, a2 = area(first), area(second)
            gap = abs(c2x - c1x) / max(1.0, width)
            gap_penalty = 0.0 if 0.05 <= gap <= 0.5 else 0.5
            return abs(c1y - c2y) / max(1.0, height) + abs(a1 - a2) / max(a1, a2) + gap_penalty

        # min() keeps the first pair on ties, like a strict "<" scan would.
        best_pair = min(itertools.combinations(boxes, 2), key=pair_score)
        left, right = sorted(best_pair, key=lambda box: box[0] + box[2])
        return [("left", left), ("right", right)]

    def extract(self, image: Image.Image) -> ExtractedViews:
        """Crop the dominant face and one eye from ``image``.

        Detection failures are reported through ``ExtractedViews.status``
        rather than raised; setup failures from :meth:`_ensure_ready` do raise.
        """
        self._ensure_ready()
        rgb = np.asarray(image.convert("RGB"))
        faces = self._detect_faces(rgb)
        if not faces:
            return ExtractedViews(None, None, None, None, "No face detected.")
        # Prefer large, confident detections: rank by area * confidence.
        face_box_raw = max(faces, key=lambda item: (item[2] - item[0]) * (item[3] - item[1]) * max(item[4], 1e-6))
        face_box = tuple(face_box_raw[:4])
        x1, y1, x2, y2 = face_box
        face_rgb = rgb[y1:y2, x1:x2]
        if face_rgb.size == 0:
            return ExtractedViews(None, None, None, None, "Detected face crop is empty.")
        face_image = Image.fromarray(face_rgb)
        try:
            eye_labels = self._detect_eyes(face_rgb)
        except Exception as exc:
            # Eye detection is optional; keep the face crop and report why.
            return ExtractedViews(face_image, None, face_box, None, f"Face detected; eye extraction failed ({exc}).")
        if not eye_labels:
            return ExtractedViews(face_image, None, face_box, None, "Face detected; no eye detected.")
        # The style model was trained with a single eye crop. When both eyes are
        # detected, match eval/inference behavior by using the left eye.
        selected_label, selected_box = eye_labels[0]
        for label, box in eye_labels:
            if label == "left":
                selected_label, selected_box = label, box
                break
        ex1, ey1, ex2, ey2 = _expand_box(selected_box, self.eye_margin, face_rgb.shape[1], face_rgb.shape[0])
        eye_rgb = face_rgb[ey1:ey2, ex1:ex2]
        if eye_rgb.size == 0:
            return ExtractedViews(face_image, None, face_box, None, "Face detected; eye crop is empty.")
        # Translate the eye box from face-crop to whole-image coordinates.
        eye_box = (x1 + ex1, y1 + ey1, x1 + ex2, y1 + ey2)
        return ExtractedViews(face_image, Image.fromarray(eye_rgb), face_box, eye_box, f"Face and {selected_label} eye detected.")
@lru_cache(maxsize=4)
def load_extractor(
    yolo_dir: str,
    yolo_weights: str,
    eye_cascade: str,
    device_name: str,
    conf: float,
    iou: float,
    imgsz: int,
    eye_neighbors: int,
    eye_margin: float,
) -> AnimeFaceEyeExtractor:
    """Return a (cached) extractor for the given configuration.

    The extractor loads its detector lazily and keeps it on the instance, so
    the instance has to be reused between requests for that to pay off.
    Results are memoised per argument combination with a small bound, so that
    moving the UI sliders cannot grow memory without limit. All arguments
    must be hashable (plain strings and numbers).
    """
    return AnimeFaceEyeExtractor(
        yolo_dir=yolo_dir,
        yolo_weights=yolo_weights,
        eye_cascade=eye_cascade,
        device_name=device_name,
        conf=conf,
        iou=iou,
        imgsz=imgsz,
        eye_neighbors=eye_neighbors,
        eye_margin=eye_margin,
    )
def build_model(checkpoint: Dict[str, object], dinov3_root: str, dinov3_weights: str, device: torch.device) -> ArtistStyleModel:
    """Instantiate ``ArtistStyleModel`` from a checkpoint and load its weights.

    Hyper-parameters are read from ``checkpoint["args"]`` when that entry is a
    dict; missing keys fall back to the defaults below. The model is moved to
    ``device``, populated from ``checkpoint["model"]`` and put in eval mode.
    """
    stored = checkpoint.get("args", {})
    cfg = stored if isinstance(stored, dict) else {}

    def as_int(key: str, fallback: int) -> int:
        return int(cfg.get(key, fallback))

    def as_float(key: str, fallback: float) -> float:
        return float(cfg.get(key, fallback))

    hyper_params = dict(
        num_classes=as_int("train_artist_count", 1000),
        branch_hidden_dim=as_int("branch_hidden_dim", 384),
        branch_dim=as_int("branch_dim", 192),
        embedding_dim=as_int("embedding_dim", 256),
        num_prototypes=as_int("num_prototypes", 4),
        prototype_temperature=as_float("prototype_temperature", 0.1),
        arcface_scale=as_float("arcface_scale", 30.0),
        view_dropout_prob=as_float("view_dropout_prob", 0.15),
        branch_dropout_prob=as_float("branch_dropout_prob", 0.1),
    )
    backbone_params = dict(
        backbone_repo_dir=dinov3_root,
        backbone_entrypoint=str(cfg.get("backbone_entrypoint", "dinov3_vits16")),
        backbone_weights_path=dinov3_weights,
        freeze_backbone=bool(cfg.get("freeze_backbone", True)),
        backbone_unfreeze_last_n_blocks=as_int("backbone_unfreeze_last_n_blocks", 0),
    )

    net = ArtistStyleModel(**hyper_params, **backbone_params).to(device)
    net.load_state_dict(checkpoint["model"])
    net.eval()
    return net
@lru_cache(maxsize=2)
def load_runtime(
    checkpoint_path: str,
    prototype_bank_path: str,
    dinov3_root: str,
    dinov3_weights: str,
    device_name: str,
) -> Dict[str, object]:
    """Load (once per configuration) everything needed to answer a query.

    Loading the checkpoint, building the backbone and stacking the prototype
    bank is expensive, so the result is memoised per argument combination.
    The returned dict is shared between calls and must be treated as
    read-only.

    Returns:
        A dict with keys ``device``, ``model``, ``checkpoint_args``,
        ``artists`` (sorted names), ``prototype_tensor`` (L2-normalised along
        the last axis, on ``device``) and ``prototype_descriptors``.

    NOTE(review): both files are read with ``torch.load``, i.e. unpickled;
    only point this at trusted files.
    """
    device = resolve_device(device_name)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    checkpoint_args = checkpoint.get("args", {})
    if not isinstance(checkpoint_args, dict):
        checkpoint_args = {}
    model = build_model(checkpoint, dinov3_root, dinov3_weights, device)

    bank_payload = torch.load(prototype_bank_path, map_location="cpu")
    prototype_bank = bank_payload["prototype_bank"]
    prototype_descriptors = bank_payload.get("prototype_descriptors", {})
    # Sorting fixes the row order of the stacked tensor.
    artists = sorted(prototype_bank)
    prototype_tensor = torch.stack([prototype_bank[artist] for artist in artists], dim=0).float()
    prototype_tensor = F.normalize(prototype_tensor, dim=-1).to(device)
    return {
        "device": device,
        "model": model,
        "checkpoint_args": checkpoint_args,
        "artists": artists,
        "prototype_tensor": prototype_tensor,
        "prototype_descriptors": prototype_descriptors,
    }
def resize_rgb(image: Image.Image, size: int) -> torch.Tensor:
    """Convert ``image`` to RGB, resize it to ``size`` x ``size`` with bicubic
    resampling (aspect ratio is not preserved) and return it as a tensor."""
    square = TF.resize(
        image.convert("RGB"),
        [size, size],
        interpolation=Image.Resampling.BICUBIC,
    )
    return TF.to_tensor(square)
def blank(size: int) -> torch.Tensor:
    """Return an all-zero float32 tensor of shape ``(3, size, size)``, used as
    a placeholder for a view that is not available."""
    shape = (3, size, size)
    return torch.zeros(shape, dtype=torch.float32)
def prepare_inputs(
    whole_image: Image.Image,
    face_image: Optional[Image.Image],
    eye_image: Optional[Image.Image],
    checkpoint_args: Dict[str, object],
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the batched model inputs for one query.

    Each view is resized to the square size stored in ``checkpoint_args``
    (defaults: 320 / 256 / 224). A missing face or eye is replaced by a zero
    tensor and flagged with ``0.0`` in the view mask.

    Returns:
        ``(full, face, eye, view_mask)``, each with a leading batch dimension
        of one and moved to ``device``.
    """
    full_size = int(checkpoint_args.get("full_image_size", 320))
    face_size = int(checkpoint_args.get("face_image_size", 256))
    eye_size = int(checkpoint_args.get("eye_image_size", 224))

    def view_or_blank(picture: Optional[Image.Image], side: int) -> torch.Tensor:
        return blank(side) if picture is None else resize_rgb(picture, side)

    views = (
        resize_rgb(whole_image, full_size),
        view_or_blank(face_image, face_size),
        view_or_blank(eye_image, eye_size),
    )
    presence = [1.0, float(face_image is not None), float(eye_image is not None)]
    view_mask = torch.tensor([presence], dtype=torch.float32)

    full, face, eye = (view.unsqueeze(0).to(device) for view in views)
    return full, face, eye, view_mask.to(device)
def normalized_box(box: Optional[Tuple[int, int, int, int]], image_size: Tuple[int, int]) -> Optional[Tuple[float, float, float, float]]:
    """Express a pixel box as fractions of the image size.

    Args:
        box: ``(x1, y1, x2, y2)`` in pixels, or ``None``.
        image_size: ``(width, height)`` of the image.

    Returns:
        The box divided by the image size with every value clamped to
        ``[0, 1]``, or ``None`` when ``box`` is ``None`` or either dimension is
        not positive.
    """
    if box is None:
        return None
    width, height = image_size
    if width <= 0 or height <= 0:
        return None

    def unit(fraction: float) -> float:
        return max(0.0, min(1.0, fraction))

    left, top, right, bottom = box
    return (unit(left / width), unit(top / height), unit(right / width), unit(bottom / height))
def rank_artists(query: torch.Tensor, artists: List[str], prototype_tensor: torch.Tensor, top_k: int) -> Tuple[List[List[object]], str, int, float]:
    """Rank artists by their best-matching prototype.

    Args:
        query: Query embedding of shape ``(d,)``.
        artists: Artist names, aligned with the first axis of
            ``prototype_tensor``.
        prototype_tensor: Prototypes of shape ``(artists, prototypes, d)``.
        top_k: Number of rows wanted; clamped to ``[1, len(artists)]``.

    Returns:
        ``(rows, best_artist, best_prototype_index, best_score)`` where each
        row is ``[rank, artist, "score with 3 decimals"]``.

    Raises:
        ValueError: If ``artists`` is empty.
    """
    if not artists:
        raise ValueError("Cannot rank against an empty artist list.")

    # Dot product of the query with every prototype of every artist.
    sims = torch.einsum("d,apd->ap", query.float(), prototype_tensor)
    # An artist's score is that of its closest prototype.
    scores, proto_indices = sims.max(dim=1)
    k = min(max(1, int(top_k)), len(artists))
    values, artist_indices = torch.topk(scores, k=k)

    top_scores = values.tolist()
    top_artists = artist_indices.tolist()
    rows: List[List[object]] = [
        [rank, artists[artist_idx], f"{float(score):.3f}"]
        for rank, (score, artist_idx) in enumerate(zip(top_scores, top_artists), start=1)
    ]
    best_artist_idx = int(top_artists[0])
    return rows, artists[best_artist_idx], int(proto_indices[best_artist_idx].item()), float(top_scores[0])
def tensor_to_map(heatmap: torch.Tensor, size: Tuple[int, int]) -> np.ndarray:
    """Turn a heatmap tensor into a min-max normalised float image.

    Non-finite values are zeroed, the data is rescaled to ``[0, 1]``,
    quantised to 8 bits, resized to ``size`` (as passed to PIL's ``resize``)
    with bicubic resampling, and returned as float32 clipped to ``[0, 1]``.
    """
    raw = heatmap.detach().float().squeeze().cpu().numpy()
    finite = np.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0)

    shifted = finite - float(finite.min())
    # The floor avoids dividing by zero for a constant map.
    unit = shifted / max(float(shifted.max()), 1e-6)

    eight_bit = np.uint8(np.clip(unit, 0.0, 1.0) * 255)
    picture = Image.fromarray(eight_bit, mode="L").resize(size, Image.Resampling.BICUBIC)
    # Bicubic resampling can overshoot, hence the final clip.
    return np.clip(np.asarray(picture, dtype=np.float32) / 255.0, 0.0, 1.0)
def heat_colors(values: np.ndarray) -> np.ndarray:
    """Map values in ``[0, 1]`` to RGB by linear interpolation between
    neighbouring ``HEAT_PALETTE`` entries.

    Inputs are clipped to ``[0, 1]``; the result has one extra trailing axis
    of length three.
    """
    last = len(HEAT_PALETTE) - 1
    position = np.clip(values, 0.0, 1.0) * last
    lower = np.floor(position).astype(np.int32)
    # At position == last the upper neighbour is the last entry itself.
    upper = np.clip(lower + 1, 0, last)
    weight = position[..., None] - lower[..., None]
    return HEAT_PALETTE[lower] * (1.0 - weight) + HEAT_PALETTE[upper] * weight
def _paste_map(canvas: np.ndarray, heatmap: torch.Tensor, box: Tuple[int, int, int, int], size: Tuple[int, int]) -> None:
    """Merge ``heatmap`` into ``canvas`` inside ``box`` (in place).

    ``box`` is clamped to ``size`` (``(width, height)``). The heatmap is
    resized to the clamped box and combined with the existing canvas values by
    element-wise maximum. A box that is empty after clamping is ignored.
    """
    max_x, max_y = size

    def clamp(value: int, upper: int) -> int:
        return max(0, min(upper, int(value)))

    left, right = clamp(box[0], max_x), clamp(box[2], max_x)
    top, bottom = clamp(box[1], max_y), clamp(box[3], max_y)
    if right <= left or bottom <= top:
        return

    patch = tensor_to_map(heatmap, (right - left, bottom - top))
    region = canvas[top:bottom, left:right]
    canvas[top:bottom, left:right] = np.maximum(region, patch)
def make_overlay(
    base_image: Image.Image,
    view_heatmaps: Dict[str, Optional[torch.Tensor]],
    face_box: Optional[Tuple[int, int, int, int]] = None,
    eye_box: Optional[Tuple[int, int, int, int]] = None,
) -> Image.Image:
    """Render the per-view heatmaps over a washed-out grayscale copy of the image.

    Args:
        base_image: The whole query image.
        view_heatmaps: Heatmaps keyed by ``"full"``, ``"face"`` and ``"eye"``;
            missing or ``None`` entries are skipped.
        face_box: Face box in ``base_image`` pixels. When given, the face
            heatmap is pasted into that region; otherwise it is stretched over
            the whole image.
        eye_box: Same as ``face_box`` for the eye heatmap.

    Returns:
        An RGB image no larger than 720x720. Without any heatmap it is just
        the grayscale backdrop.
    """
    base = base_image.convert("RGB")
    original_size = base.size
    display = base.copy()
    display.thumbnail((720, 720), Image.Resampling.LANCZOS)
    size = display.size
    # Factors mapping original pixel coordinates to display coordinates.
    sx = size[0] / max(1, original_size[0])
    sy = size[1] / max(1, original_size[1])

    def to_display(box: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
        return (
            int(round(box[0] * sx)),
            int(round(box[1] * sy)),
            int(round(box[2] * sx)),
            int(round(box[3] * sy)),
        )

    # Lift the backdrop into 0.78..1.0 so the heat colours stand out.
    gray = np.asarray(display.convert("L"), dtype=np.float32) / 255.0
    gray = 0.78 + 0.22 * gray
    gray_rgb = np.repeat(gray[..., None], 3, axis=2)

    maps = []
    full_heatmap = view_heatmaps.get("full")
    if full_heatmap is not None:
        maps.append(tensor_to_map(full_heatmap, size))

    # Local (boxed) heatmaps share one canvas, which counts as a single map.
    spatial = np.zeros((size[1], size[0]), dtype=np.float32)
    spatial_used = False
    for view_name, view_box in (("face", face_box), ("eye", eye_box)):
        view_heatmap = view_heatmaps.get(view_name)
        if view_heatmap is None:
            continue
        if view_box is not None:
            _paste_map(spatial, view_heatmap, to_display(view_box), size)
            spatial_used = True
        else:
            maps.append(tensor_to_map(view_heatmap, size))
    if spatial_used:
        maps.append(spatial)

    if not maps:
        return Image.fromarray(np.uint8(gray_rgb * 255))

    # Equal per-view normalization and averaging keeps larger source crops or
    # local face/eye maps from dominating purely because of area or scale.
    combined = np.mean(np.stack(maps, axis=0), axis=0)
    combined = combined - float(combined.min())
    combined = combined / max(float(combined.max()), 1e-6)
    heat_rgb = heat_colors(combined)
    # Values below 0.05 are fully transparent; opacity is capped at 90%.
    alpha = np.clip((combined - 0.05) / 0.95, 0.0, 1.0)
    alpha = (0.90 * alpha)[..., None]
    blended = gray_rgb * (1.0 - alpha) + heat_rgb * alpha
    return Image.fromarray(np.uint8(np.clip(blended, 0.0, 1.0) * 255))
def scalar(value: torch.Tensor) -> float:
    """Return the first element of ``value`` (in flattened order) as a Python float."""
    first = value.detach().float().reshape(-1)[0]
    return float(first.cpu().item())
def _normalize_values(values: List[float]) -> List[float]:
    """Clip negatives to zero and rescale the values so they sum to one.

    When the clipped values sum to (almost) nothing, a uniform distribution of
    the same length is returned instead. An empty input yields an empty list.
    """
    non_negative = [max(0.0, float(item)) for item in values]
    total = sum(non_negative)
    if total > 1e-8:
        return [item / total for item in non_negative]
    return [1.0 / len(non_negative) for _ in non_negative]
def _bars_html(title: str, note: str, labels: List[str], values: List[float], colors: Dict[str, str]) -> str:
    """Render a titled set of horizontal percentage bars as an HTML fragment.

    Each value is clamped to ``[0, 1]`` and shown as a percentage with one
    decimal. Labels without an entry in ``colors`` use ``#e34a33``. The text
    arguments are inserted verbatim (no HTML escaping).
    """

    def render_row(label: str, value: float) -> str:
        pct = max(0.0, min(1.0, float(value))) * 100.0
        fill = colors.get(label, "#e34a33")
        pieces = [
            "<div class='bar-row' style='display:grid; grid-template-columns:82px minmax(120px,1fr) 52px; ",
            "gap:8px; align-items:center; margin:7px 0;'>",
            f"<div class='bar-label' style='font-size:13px; font-weight:600;'>{label}</div>",
            "<div class='bar-track' style='height:14px; background:#e8e8e8; border-radius:999px; overflow:hidden;'>",
            f"<div class='bar-fill' style='width:{pct:.1f}%; height:14px; background:{fill}; border-radius:999px;'></div>",
            "</div>",
            f"<div class='bar-value' style='font-variant-numeric:tabular-nums; font-size:12px; text-align:right;'>{pct:.1f}%</div>",
            "</div>",
        ]
        return "".join(pieces)

    body = "".join(render_row(label, value) for label, value in zip(labels, values))
    header = (
        "<div class='bars' style='display:grid; gap:4px; margin:4px 0 14px;'>"
        f"<div class='bar-title' style='font-size:15px; font-weight:700; margin:0 0 3px;'>{title}</div>"
        f"<div class='bar-note' style='font-size:12px; color:#666; margin:0 0 8px; line-height:1.35;'>{note}</div>"
    )
    return f"{header}{body}</div>"
def top_match_html(artist: str, similarity: float) -> str:
    """Render the "top artist match" card as an HTML fragment.

    Args:
        artist: Artist name. It comes from the prototype bank file, so it is
            HTML-escaped before being embedded; names containing ``<``, ``&``
            or quotes would otherwise corrupt the markup.
        similarity: Similarity score; shown as a percentage clamped to
            ``[0, 100]`` with one decimal.
    """
    from html import escape

    percent = max(0.0, min(100.0, float(similarity) * 100.0))
    safe_artist = escape(str(artist))
    return (
        "<div class='top-match'>"
        "<div class='top-match-label'>Top artist match</div>"
        f"<div class='top-match-artist'>{safe_artist}</div>"
        f"<div class='top-match-score'>{percent:.1f}% similarity</div>"
        "</div>"
    )
def contribution_bars(explanation: Dict[str, object]) -> Tuple[str, str, str]:
    """Summarise branch and view contributions of a query as HTML bars.

    Reads ``branch_weights`` and ``stacked_view_weights`` (first batch item)
    from ``explanation["query_outputs"]``. Branch weights are normalised to
    sum to one; view weights are combined across branches using those
    normalised branch weights and then normalised as well.

    Returns:
        ``(branch_bars_html, view_bars_html, summary_sentence)``.
    """
    outputs = explanation["query_outputs"]

    branch_values = _normalize_values(outputs["branch_weights"][0].detach().float().cpu().tolist())

    view_weights = outputs["stacked_view_weights"][0].detach().float().cpu()
    # Column vector so each branch's view weights are scaled by that branch.
    branch_column = torch.tensor(branch_values, dtype=view_weights.dtype).unsqueeze(-1)
    view_values = _normalize_values((view_weights * branch_column).sum(dim=0).tolist())

    dominant_branch = BRANCH_NAMES[int(np.argmax(branch_values))]
    dominant_view = VIEW_NAMES[int(np.argmax(view_values))]

    branch_html = _bars_html(
        "Influential style factors",
        "Relative weight of structure, texture, linework, and color in the query embedding.",
        BRANCH_NAMES,
        branch_values,
        BRANCH_COLORS,
    )
    view_html = _bars_html(
        "Similar visual regions",
        "Relative weight of whole image, face crop, and eye crop used for the match.",
        VIEW_NAMES,
        view_values,
        VIEW_COLORS,
    )
    return branch_html, view_html, f"Dominant branch: {dominant_branch}. Dominant view: {dominant_view}."
def analyze(
    whole_image: Optional[Image.Image],
    face_image: Optional[Image.Image],
    eye_image: Optional[Image.Image],
    auto_extract: bool,
    top_k: int,
    use_tta: bool,
    checkpoint_path: str,
    prototype_bank_path: str,
    dinov3_root: str,
    dinov3_weights: str,
    device_name: str,
    yolo_dir: str,
    yolo_weights: str,
    eye_cascade: str,
    face_conf: float,
    face_iou: float,
    face_imgsz: int,
    eye_neighbors: int,
    eye_margin: float,
) -> Tuple[Optional[Image.Image], Optional[Image.Image], Image.Image, str, List[List[object]], str, str, str]:
    """Gradio callback: retrieve the closest artists and explain the top match.

    With ``auto_extract`` the face/eye crops come from the extractor (any
    manually supplied crops are ignored); otherwise the supplied crops are
    used as they are. A failing auto crop degrades to a whole-image query.

    Returns:
        ``(face, eye, overlay, top_match_html, result_rows, branch_html,
        view_html, summary)`` in the order of the UI outputs.

    Raises:
        gr.Error: If no whole image is given or the best artist has no
            prototype descriptor.
    """
    if whole_image is None:
        raise gr.Error("Whole image is required.")

    runtime = load_runtime(checkpoint_path, prototype_bank_path, dinov3_root, dinov3_weights, device_name)
    device, model, checkpoint_args = runtime["device"], runtime["model"], runtime["checkpoint_args"]

    if auto_extract:
        extractor = load_extractor(
            yolo_dir,
            yolo_weights,
            eye_cascade,
            device_name,
            float(face_conf),
            float(face_iou),
            int(face_imgsz),
            int(eye_neighbors),
            float(eye_margin),
        )
        try:
            extracted = extractor.extract(whole_image)
        except Exception as exc:
            # Cropping is optional; fall back to the whole image alone.
            extracted = ExtractedViews(None, None, None, None, f"Auto crop failed ({exc}); using whole image only.")
        used_face, used_eye, face_box, eye_box, extract_status = extracted
    else:
        used_face, used_eye = face_image, eye_image
        face_box = eye_box = None
        extract_status = "Manual face/eye inputs."

    full, face, eye, view_mask = prepare_inputs(whole_image, used_face, used_eye, checkpoint_args, device)
    with torch.no_grad():
        query_outputs = _encode_query(model, full, face, eye, view_mask=view_mask, use_tta=use_tta)
        query = F.normalize(query_outputs["embedding"][0], dim=0)

    result_rows, best_artist, best_proto_idx, best_score = rank_artists(
        query=query,
        artists=runtime["artists"],
        prototype_tensor=runtime["prototype_tensor"],
        top_k=top_k,
    )

    descriptors = runtime["prototype_descriptors"].get(best_artist)
    if not descriptors:
        raise gr.Error(f"No prototype descriptor found for {best_artist}.")

    # Only views with a known box get an attention box.
    candidate_boxes = {
        "face": normalized_box(face_box, whole_image.size),
        "eye": normalized_box(eye_box, whole_image.size),
    }
    view_attention_boxes = {view: area for view, area in candidate_boxes.items() if area is not None}

    explanation = explain_against_reference(
        model=model,
        query_full=full,
        query_face=face,
        query_eye=eye,
        query_view_mask=view_mask,
        reference_descriptor=descriptors[best_proto_idx],
        view_attention_boxes=view_attention_boxes,
        use_tta=use_tta,
    )

    overlay = make_overlay(whole_image, explanation["combined_view_heatmaps"], face_box=face_box, eye_box=eye_box)
    branch_html, view_html, contribution_summary = contribution_bars(explanation)
    summary = f"{extract_status} {contribution_summary} Target prototype: {best_artist} #{best_proto_idx}."
    return (
        used_face,
        used_eye,
        overlay,
        top_match_html(best_artist, best_score),
        result_rows,
        branch_html,
        view_html,
        summary,
    )
def build_ui(args: argparse.Namespace) -> gr.Blocks:
    """Assemble the Gradio interface and wire the Analyze button to analyze().

    ``args`` supplies the initial values for the path, device and auto-crop
    controls. The returned ``gr.Blocks`` is not launched here.
    """
    # NOTE(review): the container nesting below was reconstructed from a copy
    # of the file that had lost its indentation -- confirm the layout.
    with gr.Blocks(title="Artist Style DINOv3") as demo:
        gr.Markdown("# Artist Style DINOv3")
        with gr.Row(elem_classes=["app"]):
            # Left column: inputs and settings.
            with gr.Column(scale=4, min_width=320):
                whole = gr.Image(label="Whole", type="pil", image_mode="RGB", height=260)
                with gr.Row():
                    # These two are also outputs: analyze() returns the crops it used.
                    face = gr.Image(label="Face", type="pil", image_mode="RGB", height=160)
                    eye = gr.Image(label="Eye", type="pil", image_mode="RGB", height=160)
                with gr.Row():
                    auto_extract = gr.Checkbox(value=True, label="Auto crop")
                    top_k = gr.Slider(1, 25, value=10, step=1, label="Top K")
                    use_tta = gr.Checkbox(value=False, label="TTA")
                run = gr.Button("Analyze", variant="primary")
                with gr.Accordion("Paths", open=False):
                    checkpoint_path = gr.Textbox(value=args.checkpoint, label="Checkpoint")
                    prototype_bank_path = gr.Textbox(value=args.prototype_bank, label="Prototype bank")
                    dinov3_root = gr.Textbox(value=args.dinov3_root, label="DINOv3 root")
                    dinov3_weights = gr.Textbox(value=args.dinov3_weights, label="DINOv3 weights")
                    yolo_dir = gr.Textbox(value=args.yolo_dir, label="YOLOv5 anime root")
                    yolo_weights = gr.Textbox(value=args.yolo_weights, label="YOLOv5 anime weights")
                    eye_cascade = gr.Textbox(value=args.eye_cascade, label="Eye cascade")
                    device_name = gr.Dropdown(["auto", "cpu", "cuda"], value=args.device, label="Device")
                with gr.Accordion("Auto Crop", open=False):
                    face_conf = gr.Slider(0.05, 0.95, value=args.face_conf, step=0.05, label="Face confidence")
                    face_iou = gr.Slider(0.1, 0.9, value=args.face_iou, step=0.05, label="Face IoU")
                    face_imgsz = gr.Slider(320, 1280, value=args.face_imgsz, step=32, label="Face detector size")
                    eye_neighbors = gr.Slider(1, 20, value=args.eye_neighbors, step=1, label="Eye neighbors")
                    eye_margin = gr.Slider(0.0, 1.5, value=args.eye_margin, step=0.05, label="Eye margin")
            # Right column: results.
            with gr.Column(scale=7, min_width=640):
                with gr.Row():
                    overlay = gr.Image(label="Composite heatmap", type="pil", height=520)
                    with gr.Column(scale=1):
                        summary = gr.Textbox(label="Strongest contribution", lines=3, elem_id="summary-box")
                        top_match = gr.HTML()
                        results = gr.Dataframe(
                            headers=["rank", "artist", "score"],
                            label="Retrieval",
                            datatype=["number", "str", "str"],
                            row_count=10,
                        )
                with gr.Row():
                    branch_bars = gr.HTML(label="Branch")
                    view_bars = gr.HTML(label="View")
        # The inputs list follows analyze()'s parameter order and the outputs
        # list follows the order of its returned tuple.
        run.click(
            fn=analyze,
            inputs=[
                whole,
                face,
                eye,
                auto_extract,
                top_k,
                use_tta,
                checkpoint_path,
                prototype_bank_path,
                dinov3_root,
                dinov3_weights,
                device_name,
                yolo_dir,
                yolo_weights,
                eye_cascade,
                face_conf,
                face_iou,
                face_imgsz,
                eye_neighbors,
                eye_margin,
            ],
            outputs=[face, eye, overlay, top_match, results, branch_bars, view_bars, summary],
        )
    return demo
def main() -> None:
    """Entry point: parse the CLI options, build the UI and start the server."""
    options = parse_args()
    app = build_ui(options)
    launch_settings = {
        "server_name": options.server_name,
        "server_port": options.server_port,
        "share": options.share,
        "css": APP_CSS,
    }
    app.launch(**launch_settings)


if __name__ == "__main__":
    main()