# iljung1106
# Fix contribution bar rendering
# 0ca0b8b
from __future__ import annotations
import argparse
import itertools
import os
import sys
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional, Tuple
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import functional as TF
from artist_style_dinov3 import ArtistStyleModel, explain_against_reference
from artist_style_dinov3.explain import _encode_query
# Directory containing this file; every default path in parse_args() is built
# relative to it.
ROOT = Path(__file__).resolve().parent
# Display labels for the three views, in the order contribution_bars() indexes
# the per-view weights.
VIEW_NAMES = ["whole", "face", "eye"]
# NOTE(review): presumably the model-side names of the same three views (cf. the
# "full"/"face"/"eye" keys read in make_overlay); not referenced elsewhere in
# the visible part of this file -- confirm before removing.
MODEL_VIEW_NAMES = ["full", "face", "eye"]
# Display labels for the style branches, in the order contribution_bars()
# indexes the branch weights.
BRANCH_NAMES = ["structure", "texture", "line", "color"]
# ColorBrewer YlOrRd-style sequential palette. Low values are transparent; high
# values move through yellow/orange into red so saliency is visible on grayscale.
# Rows are RGB triples in [0, 1]; heat_colors() interpolates linearly between
# neighbouring rows.
HEAT_PALETTE = np.array(
[
[1.000000, 1.000000, 0.800000],
[1.000000, 0.929412, 0.627451],
[0.996078, 0.850980, 0.462745],
[0.996078, 0.698039, 0.298039],
[0.992157, 0.552941, 0.235294],
[0.988235, 0.305882, 0.164706],
[0.890196, 0.101961, 0.109804],
[0.700000, 0.000000, 0.000000],
],
dtype=np.float32,
)
# Stylesheet handed to demo.launch() in main(). The class names match the ones
# emitted by _bars_html() and top_match_html().
APP_CSS = """
.app {max-width: 1500px; margin: 0 auto;}
.compact-table textarea {font-size: 12px;}
#summary-box textarea {font-size: 14px; font-weight: 600;}
.bars {display: grid; gap: 10px; margin: 4px 0 14px;}
.bar-row {display: grid; grid-template-columns: 82px 1fr 52px; gap: 8px; align-items: center;}
.bar-label {font-size: 13px; font-weight: 600;}
.bar-track {height: 13px; background: #e8e8e8; border-radius: 999px; overflow: hidden;}
.bar-fill {height: 100%;}
.bar-value {font-variant-numeric: tabular-nums; font-size: 12px; text-align: right;}
.bar-title {font-size: 15px; font-weight: 700; margin: 0 0 3px;}
.bar-note {font-size: 12px; color: #666; margin: 0 0 10px; line-height: 1.35;}
.top-match {padding: 12px 14px; border: 1px solid #dedede; border-radius: 8px; margin-bottom: 10px; background: #fafafa;}
.top-match-label {font-size: 12px; color: #666; margin-bottom: 2px;}
.top-match-artist {font-size: 24px; line-height: 1.15; font-weight: 800;}
.top-match-score {font-size: 18px; font-weight: 700; color: #b91c1c; margin-top: 4px;}
"""
# Bar fill colour per branch label; looked up by label in _bars_html().
BRANCH_COLORS = {
"structure": "#4e79a7",
"texture": "#f28e2b",
"line": "#e15759",
"color": "#59a14f",
}
# Bar fill colour per view label; looked up by label in _bars_html().
VIEW_COLORS = {
"whole": "#7b61ff",
"face": "#ff7f0e",
"eye": "#d62728",
}
def env_path(name: str, default: Path) -> str:
    """Return a path setting from the environment, or *default* as a string.

    The environment variable *name* is used only when it is set to a value
    that is not blank; a leading ``~`` in it is expanded. In every other case
    ``str(default)`` is returned.
    """
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return str(default)
    return str(Path(raw).expanduser())
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the web UI.

    Every path option defaults to its environment variable (resolved through
    ``env_path``) and otherwise to a location under ``ROOT``.
    """
    training_dir = ROOT / "artifacts" / "style_training_dinov3"
    yolo_root = ROOT / "yolov5_anime"
    backbone_weights = (
        ROOT / "artifacts" / "pretrained" / "dinov3" / "dinov3_vits16_pretrain_lvd1689m-08c60483.pth"
    )

    parser = argparse.ArgumentParser(description="Gradio UI for artist style retrieval and attribution.")
    add = parser.add_argument

    # Style model and its DINOv3 backbone.
    add("--checkpoint", type=str, default=env_path("ARTIST_STYLE_CHECKPOINT", training_dir / "best.pt"))
    add("--prototype-bank", type=str, default=env_path("ARTIST_PROTOTYPES", training_dir / "artist_prototypes.pt"))
    add("--dinov3-root", type=str, default=env_path("DINOV3_ROOT", ROOT / "third_party" / "dinov3"))
    add("--dinov3-weights", type=str, default=env_path("DINOV3_WEIGHTS", backbone_weights))
    add("--device", choices=["auto", "cpu", "cuda"], default="auto")

    # Automatic face / eye cropping.
    add("--yolo-dir", type=str, default=env_path("YOLO_ANIME_ROOT", yolo_root))
    add("--yolo-weights", type=str, default=env_path("YOLO_WEIGHTS", yolo_root / "weights" / "yolov5s_anime.pt"))
    add("--eye-cascade", type=str, default=env_path("EYE_CASCADE", ROOT / "anime-eyes-cascade.xml"))
    add("--face-conf", type=float, default=0.5)
    add("--face-iou", type=float, default=0.5)
    add("--face-imgsz", type=int, default=640)
    add("--eye-neighbors", type=int, default=9)
    add("--eye-margin", type=float, default=0.6)

    # Gradio server.
    add("--server-name", type=str, default="127.0.0.1")
    add("--server-port", type=int, default=7860)
    add("--share", action="store_true")

    return parser.parse_args()
class ExtractedViews(NamedTuple):
    """Outcome of automatic face/eye cropping for a single input image."""

    # Cropped face image, or None when no usable face crop was produced.
    face: Optional[Image.Image]
    # Cropped eye image, or None when no usable eye crop was produced.
    eye: Optional[Image.Image]
    # Face box as (x1, y1, x2, y2) in pixel coordinates of the whole image.
    face_box: Optional[Tuple[int, int, int, int]]
    # Eye box as (x1, y1, x2, y2) in pixel coordinates of the whole image
    # (the extractor offsets the face-relative box by the face origin).
    eye_box: Optional[Tuple[int, int, int, int]]
    # Human-readable description of what was (or was not) detected.
    status: str
def resolve_device(device_name: str) -> torch.device:
    """Translate a device option into a ``torch.device``.

    ``"cpu"`` and ``"cuda"`` are taken literally; any other value (the UI
    uses ``"auto"``) selects CUDA when it is available and the CPU otherwise.

    Raises:
        RuntimeError: ``"cuda"`` was requested but CUDA is not available.
    """
    if device_name == "cpu":
        return torch.device("cpu")

    cuda_ok = torch.cuda.is_available()
    if device_name == "cuda" and not cuda_ok:
        raise RuntimeError("CUDA was requested but is not available.")
    return torch.device("cuda" if cuda_ok else "cpu")
def _patch_torch_load_for_yolov5() -> None:
import torch.serialization
try:
import numpy as _np
torch.serialization.add_safe_globals(
[
_np.core.multiarray._reconstruct,
_np.ndarray,
]
)
except Exception:
pass
original_load = torch.load
if getattr(original_load, "_anime_webui_patched", False):
return
def patched_load(*args, **kwargs):
kwargs.setdefault("weights_only", False)
return original_load(*args, **kwargs)
patched_load._anime_webui_patched = True
torch.load = patched_load
def _expand_box(box: Tuple[int, int, int, int], margin: float, width: int, height: int) -> Tuple[int, int, int, int]:
x1, y1, x2, y2 = box
cx = (x1 + x2) / 2.0
cy = (y1 + y2) / 2.0
bw = (x2 - x1) * (1.0 + margin)
bh = (y2 - y1) * (1.0 + margin)
nx1 = max(0, min(width, int(round(cx - bw / 2.0))))
ny1 = max(0, min(height, int(round(cy - bh / 2.0))))
nx2 = max(0, min(width, int(round(cx + bw / 2.0))))
ny2 = max(0, min(height, int(round(cy + bh / 2.0))))
return nx1, ny1, nx2, ny2
class AnimeFaceEyeExtractor:
    """Lazily initialised face and eye cropper.

    Faces are detected with a YOLOv5 checkpoint loaded through the modules of
    a local YOLOv5 checkout (``models.experimental.attempt_load``); eyes are
    detected inside the chosen face crop with an OpenCV cascade classifier.
    Nothing is loaded until the first :meth:`extract` call.
    """

    def __init__(
        self,
        yolo_dir: str,
        yolo_weights: str,
        eye_cascade: str,
        device_name: str,
        conf: float,
        iou: float,
        imgsz: int,
        eye_neighbors: int,
        eye_margin: float,
    ) -> None:
        """Store the configuration; no model or cascade is loaded here.

        Args:
            yolo_dir: Directory of the YOLOv5 checkout (added to ``sys.path``
                on first use).
            yolo_weights: Path of the face detector weights.
            eye_cascade: Path of the OpenCV cascade XML used for eyes.
            device_name: ``"auto"``, ``"cpu"`` or ``"cuda"``.
            conf: Confidence threshold passed to non-max suppression.
            iou: IoU threshold passed to non-max suppression.
            imgsz: Detector input size; rounded up to a multiple of the model
                stride once the model is loaded.
            eye_neighbors: ``minNeighbors`` for the eye cascade.
            eye_margin: Fractional margin added around the selected eye box.
        """
        self.yolo_dir = Path(yolo_dir).resolve()
        self.yolo_weights = Path(yolo_weights).resolve()
        self.eye_cascade = Path(eye_cascade).resolve()
        self.device_name = device_name
        self.conf = float(conf)
        self.iou = float(iou)
        self.imgsz = int(imgsz)
        self.eye_neighbors = int(eye_neighbors)
        self.eye_margin = float(eye_margin)
        # Populated by _ensure_ready().
        self.model = None
        self.device = None
        self.stride = 32
        self.use_half = False
        self.cascade = None

    def _ensure_ready(self) -> None:
        """Load the face detector and the eye cascade if not loaded yet.

        Raises:
            gr.Error: A required file is missing, OpenCV is not installed, or
                the cascade file could not be loaded.
        """
        missing = [
            str(path)
            for path in [self.yolo_dir, self.yolo_weights, self.eye_cascade]
            if not path.exists()
        ]
        if missing:
            raise gr.Error("Missing auto extraction files: " + ", ".join(missing))
        if self.model is not None and self.cascade is not None:
            return
        try:
            import cv2
        except ModuleNotFoundError as exc:
            raise gr.Error("opencv-python is required for auto face/eye extraction.") from exc

        if self.model is None:
            if str(self.yolo_dir) not in sys.path:
                sys.path.insert(0, str(self.yolo_dir))
            _patch_torch_load_for_yolov5()
            from models.experimental import attempt_load
            from utils.torch_utils import select_device

            yolo_device = "0" if self.device_name in {"auto", "cuda"} and torch.cuda.is_available() else "cpu"
            device = select_device(yolo_device)
            use_half = getattr(device, "type", str(device)) != "cpu"
            model = attempt_load(str(self.yolo_weights), map_location=device)
            if use_half:
                model.half()
            model.eval()
            stride = int(model.stride.max())
            # Publish the detector state only after everything succeeded, so
            # a failed load never leaves a half-configured instance behind.
            self.device = device
            self.use_half = use_half
            self.stride = stride
            self.imgsz = int(np.ceil(self.imgsz / stride) * stride)
            self.model = model

        if self.cascade is None:
            cascade = cv2.CascadeClassifier(str(self.eye_cascade))
            if cascade.empty():
                # Do not keep the empty classifier: storing it would make the
                # next call believe the extractor is ready.
                raise gr.Error(f"Failed to load eye cascade: {self.eye_cascade}")
            self.cascade = cascade

    def _letterbox(self, bgr: np.ndarray) -> np.ndarray:
        """Letterbox *bgr* to the detector size with YOLOv5's own helper.

        The call is retried with progressively fewer keyword arguments when
        the helper rejects them with ``TypeError``, so that YOLOv5 checkouts
        with different ``letterbox`` signatures are accepted.
        """
        from utils.datasets import letterbox

        try:
            return letterbox(bgr, (self.imgsz, self.imgsz), stride=self.stride, auto=False)[0]
        except TypeError:
            try:
                return letterbox(bgr, (self.imgsz, self.imgsz), auto=False)[0]
            except TypeError:
                return letterbox(bgr, (self.imgsz, self.imgsz))[0]

    def _detect_faces(self, rgb: np.ndarray) -> List[Tuple[int, int, int, int, float]]:
        """Run the face detector on an RGB array.

        Returns:
            ``(x1, y1, x2, y2, confidence)`` tuples in pixel coordinates of
            *rgb*; empty when nothing passes non-max suppression.
        """
        import cv2
        from utils.general import non_max_suppression, scale_coords

        bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        image = self._letterbox(bgr)
        image = image[:, :, ::-1].transpose(2, 0, 1)
        image = np.ascontiguousarray(image)
        tensor = torch.from_numpy(image).to(self.device)
        tensor = tensor.half() if self.use_half else tensor.float()
        tensor = tensor / 255.0
        tensor = tensor.unsqueeze(0)

        h0, w0 = rgb.shape[:2]
        # The in-place coordinate rescale stays inside inference mode, where
        # modifying tensors created by the forward pass is permitted.
        with torch.inference_mode():
            pred = self.model(tensor)[0]
            pred = non_max_suppression(pred, conf_thres=self.conf, iou_thres=self.iou, classes=None, agnostic=False)[0]
            if pred is not None and len(pred):
                # Map back from the shape that was actually fed to the model;
                # it is not guaranteed to be (imgsz, imgsz) when the plainest
                # letterbox fallback was used.
                pred[:, :4] = scale_coords(tuple(tensor.shape[2:]), pred[:, :4], (h0, w0)).round()

        boxes: List[Tuple[int, int, int, int, float]] = []
        if pred is not None and len(pred):
            for *xyxy, conf, _cls in pred.tolist():
                x1, y1, x2, y2 = [int(v) for v in xyxy]
                boxes.append((x1, y1, x2, y2, float(conf)))
        return boxes

    def _detect_eyes(self, face_rgb: np.ndarray) -> List[Tuple[str, Tuple[int, int, int, int]]]:
        """Detect eyes in the upper 70% of a face crop.

        Returns:
            Labelled boxes as produced by :meth:`_best_eye_pair`, in pixel
            coordinates of *face_rgb* (the search region starts at its
            top-left corner).
        """
        import cv2

        height, width = face_rgb.shape[:2]
        roi_h = max(1, int(height * 0.70))
        roi = face_rgb[:roi_h, :]
        gray = cv2.cvtColor(roi, cv2.COLOR_RGB2GRAY)
        gray = cv2.GaussianBlur(gray, (3, 3), 0)
        gray = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray)
        min_size = max(8, int(0.07 * min(width, roi_h)))
        raw = self.cascade.detectMultiScale(
            gray,
            scaleFactor=1.05,
            minNeighbors=self.eye_neighbors,
            minSize=(min_size, min_size),
            flags=cv2.CASCADE_SCALE_IMAGE,
        )
        # The result may be None, an empty tuple, a tuple whose first item
        # holds the rectangles, or an array of rectangles.
        if raw is None:
            return []
        if isinstance(raw, tuple):
            if len(raw) == 0:
                return []
            raw = raw[0]
        arr = np.asarray(raw)
        if arr.size == 0:
            return []
        if arr.ndim == 1:
            arr = arr.reshape(1, -1)
        boxes = [
            (int(x), int(y), int(x + w), int(y + h))
            for x, y, w, h in arr[:, :4]
            if w > 0 and h > 0
        ]
        return self._best_eye_pair(boxes, width, roi_h)

    @staticmethod
    def _best_eye_pair(
        boxes: List[Tuple[int, int, int, int]],
        width: int,
        height: int,
    ) -> List[Tuple[str, Tuple[int, int, int, int]]]:
        """Choose the most eye-pair-like two boxes and label them.

        A pair scores better (lower) when the two centres are vertically
        aligned, the areas are similar, and the horizontal gap between the
        centres is between 5% and 50% of *width*.

        Returns:
            ``[("left", box), ("right", box)]`` ordered by horizontal
            position, ``[("left", box)]`` when only one box was given, or an
            empty list when none was given.
        """
        if len(boxes) < 2:
            return [("left", boxes[0])] if len(boxes) == 1 else []

        def center(box: Tuple[int, int, int, int]) -> Tuple[float, float]:
            x1, y1, x2, y2 = box
            return (x1 + x2) / 2.0, (y1 + y2) / 2.0

        def area(box: Tuple[int, int, int, int]) -> int:
            x1, y1, x2, y2 = box
            return max(1, (x2 - x1) * (y2 - y1))

        best_pair = None
        best_score = float("inf")
        for first, second in itertools.combinations(boxes, 2):
            c1x, c1y = center(first)
            c2x, c2y = center(second)
            a1, a2 = area(first), area(second)
            gap = abs(c2x - c1x) / max(1.0, width)
            gap_penalty = 0.0 if 0.05 <= gap <= 0.5 else 0.5
            score = abs(c1y - c2y) / max(1.0, height) + abs(a1 - a2) / max(a1, a2) + gap_penalty
            if score < best_score:
                best_score = score
                best_pair = (first, second)
        assert best_pair is not None
        left, right = sorted(best_pair, key=lambda box: box[0] + box[2])
        return [("left", left), ("right", right)]

    def extract(self, image: Image.Image) -> ExtractedViews:
        """Crop the most prominent face and one eye from *image*.

        The face with the largest ``area * confidence`` is used. Failures of
        the eye stage are reported through the status text while the face
        crop is still returned.

        Raises:
            gr.Error: Propagated from :meth:`_ensure_ready`.
        """
        self._ensure_ready()
        rgb = np.asarray(image.convert("RGB"))
        height, width = rgb.shape[:2]
        faces = self._detect_faces(rgb)
        if not faces:
            return ExtractedViews(None, None, None, None, "No face detected.")
        face_box_raw = max(faces, key=lambda item: (item[2] - item[0]) * (item[3] - item[1]) * max(item[4], 1e-6))
        face_box = tuple(face_box_raw[:4])
        x1, y1, x2, y2 = face_box
        face_rgb = rgb[y1:y2, x1:x2]
        if face_rgb.size == 0:
            return ExtractedViews(None, None, None, None, "Detected face crop is empty.")
        face_image = Image.fromarray(face_rgb)
        try:
            eye_labels = self._detect_eyes(face_rgb)
        except Exception as exc:
            return ExtractedViews(face_image, None, face_box, None, f"Face detected; eye extraction failed ({exc}).")
        if not eye_labels:
            return ExtractedViews(face_image, None, face_box, None, "Face detected; no eye detected.")
        # The style model was trained with a single eye crop. When both eyes are
        # detected, match eval/inference behavior by using the left eye.
        selected_label, selected_box = eye_labels[0]
        for label, box in eye_labels:
            if label == "left":
                selected_label, selected_box = label, box
                break
        ex1, ey1, ex2, ey2 = _expand_box(selected_box, self.eye_margin, face_rgb.shape[1], face_rgb.shape[0])
        eye_rgb = face_rgb[ey1:ey2, ex1:ex2]
        if eye_rgb.size == 0:
            return ExtractedViews(face_image, None, face_box, None, "Face detected; eye crop is empty.")
        # Shift the face-relative eye box into whole-image coordinates.
        eye_box = (x1 + ex1, y1 + ey1, x1 + ex2, y1 + ey2)
        return ExtractedViews(face_image, Image.fromarray(eye_rgb), face_box, eye_box, f"Face and {selected_label} eye detected.")
@lru_cache(maxsize=2)
def load_extractor(
    yolo_dir: str,
    yolo_weights: str,
    eye_cascade: str,
    device_name: str,
    conf: float,
    iou: float,
    imgsz: int,
    eye_neighbors: int,
    eye_margin: float,
) -> AnimeFaceEyeExtractor:
    """Return a cached extractor for this exact set of settings.

    The two most recently used setting combinations are kept, so an
    extractor is reused across requests as long as its settings do not
    change.
    """
    settings = dict(
        yolo_dir=yolo_dir,
        yolo_weights=yolo_weights,
        eye_cascade=eye_cascade,
        device_name=device_name,
        conf=conf,
        iou=iou,
        imgsz=imgsz,
        eye_neighbors=eye_neighbors,
        eye_margin=eye_margin,
    )
    return AnimeFaceEyeExtractor(**settings)
def build_model(checkpoint: Dict[str, object], dinov3_root: str, dinov3_weights: str, device: torch.device) -> ArtistStyleModel:
    """Rebuild the style model from a checkpoint and put it in eval mode.

    Hyperparameters are read from ``checkpoint["args"]`` when that entry is a
    dict; any missing one falls back to the default written below. The
    weights are taken from ``checkpoint["model"]``.
    """
    stored = checkpoint.get("args", {})
    hparams = stored if isinstance(stored, dict) else {}

    def opt(key: str, cast, fallback):
        """Read one hyperparameter and coerce it with *cast*."""
        return cast(hparams.get(key, fallback))

    model = ArtistStyleModel(
        num_classes=opt("train_artist_count", int, 1000),
        branch_hidden_dim=opt("branch_hidden_dim", int, 384),
        branch_dim=opt("branch_dim", int, 192),
        embedding_dim=opt("embedding_dim", int, 256),
        num_prototypes=opt("num_prototypes", int, 4),
        prototype_temperature=opt("prototype_temperature", float, 0.1),
        arcface_scale=opt("arcface_scale", float, 30.0),
        view_dropout_prob=opt("view_dropout_prob", float, 0.15),
        branch_dropout_prob=opt("branch_dropout_prob", float, 0.1),
        backbone_repo_dir=dinov3_root,
        backbone_entrypoint=opt("backbone_entrypoint", str, "dinov3_vits16"),
        backbone_weights_path=dinov3_weights,
        freeze_backbone=opt("freeze_backbone", bool, True),
        backbone_unfreeze_last_n_blocks=opt("backbone_unfreeze_last_n_blocks", int, 0),
    )
    model = model.to(device)
    model.load_state_dict(checkpoint["model"])
    model.eval()
    return model
@lru_cache(maxsize=1)
def load_runtime(
    checkpoint_path: str,
    prototype_bank_path: str,
    dinov3_root: str,
    dinov3_weights: str,
    device_name: str,
) -> Dict[str, object]:
    """Load the model and the prototype bank, cached for the last argument set.

    Returns:
        A dict with the resolved ``device``, the ``model`` in eval mode, the
        checkpoint's ``checkpoint_args`` (``{}`` when absent or not a dict),
        the sorted ``artists`` names, the L2-normalised ``prototype_tensor``
        stacked in that artist order and moved to the device, and the
        ``prototype_descriptors`` mapping from the bank file (``{}`` when
        absent).
    """
    device = resolve_device(device_name)

    checkpoint = torch.load(checkpoint_path, map_location=device)
    stored_args = checkpoint.get("args", {})
    checkpoint_args = stored_args if isinstance(stored_args, dict) else {}
    model = build_model(checkpoint, dinov3_root, dinov3_weights, device)

    bank_payload = torch.load(prototype_bank_path, map_location="cpu")
    prototype_bank = bank_payload["prototype_bank"]
    artists = sorted(prototype_bank)
    stacked = torch.stack([prototype_bank[name] for name in artists], dim=0).float()
    prototype_tensor = F.normalize(stacked, dim=-1).to(device)

    runtime: Dict[str, object] = {}
    runtime["device"] = device
    runtime["model"] = model
    runtime["checkpoint_args"] = checkpoint_args
    runtime["artists"] = artists
    runtime["prototype_tensor"] = prototype_tensor
    runtime["prototype_descriptors"] = bank_payload.get("prototype_descriptors", {})
    return runtime
def resize_rgb(image: Image.Image, size: int) -> torch.Tensor:
    """Convert *image* to RGB, resize it to ``size x size`` and return a tensor.

    Resizing is bicubic and ignores the aspect ratio; the tensor comes from
    torchvision's ``to_tensor``.
    """
    rgb = image.convert("RGB")
    square = TF.resize(rgb, [size, size], interpolation=Image.Resampling.BICUBIC)
    return TF.to_tensor(square)
def blank(size: int) -> torch.Tensor:
    """Return an all-zero float32 tensor of shape ``(3, size, size)``.

    Used as the placeholder input for a view that is not available.
    """
    shape = (3, size, size)
    return torch.zeros(shape, dtype=torch.float32)
def prepare_inputs(
    whole_image: Image.Image,
    face_image: Optional[Image.Image],
    eye_image: Optional[Image.Image],
    checkpoint_args: Dict[str, object],
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the batched model inputs for one query.

    Each view is resized to the size stored in *checkpoint_args*
    (``full_image_size``, ``face_image_size``, ``eye_image_size``; defaults
    320, 256, 224). A missing face or eye is replaced by a zero tensor and
    flagged with ``0`` in the view mask.

    Returns:
        ``(full, face, eye, view_mask)`` on *device*, each with a leading
        batch dimension of one; ``view_mask`` has shape ``(1, 3)``.
    """
    full_size = int(checkpoint_args.get("full_image_size", 320))
    face_size = int(checkpoint_args.get("face_image_size", 256))
    eye_size = int(checkpoint_args.get("eye_image_size", 224))

    def view_tensor(image: Optional[Image.Image], size: int) -> torch.Tensor:
        return blank(size) if image is None else resize_rgb(image, size)

    full = resize_rgb(whole_image, full_size)
    face = view_tensor(face_image, face_size)
    eye = view_tensor(eye_image, eye_size)
    presence = [1.0, float(face_image is not None), float(eye_image is not None)]
    view_mask = torch.tensor([presence], dtype=torch.float32)

    batched = (full.unsqueeze(0), face.unsqueeze(0), eye.unsqueeze(0), view_mask)
    full_b, face_b, eye_b, mask_b = (tensor.to(device) for tensor in batched)
    return full_b, face_b, eye_b, mask_b
def normalized_box(box: Optional[Tuple[int, int, int, int]], image_size: Tuple[int, int]) -> Optional[Tuple[float, float, float, float]]:
    """Express a pixel box as fractions of the image size, clamped to [0, 1].

    Args:
        box: ``(x1, y1, x2, y2)`` in pixels, or ``None``.
        image_size: ``(width, height)`` of the image.

    Returns:
        The normalised box, or ``None`` when *box* is ``None`` or either
        image dimension is not positive.
    """
    if box is None:
        return None
    width, height = image_size
    if width <= 0 or height <= 0:
        return None

    def unit(value: float, extent: int) -> float:
        return max(0.0, min(1.0, value / extent))

    left, top, right, bottom = box
    return unit(left, width), unit(top, height), unit(right, width), unit(bottom, height)
def rank_artists(query: torch.Tensor, artists: List[str], prototype_tensor: torch.Tensor, top_k: int) -> Tuple[List[List[object]], str, int, float]:
    """Rank artists by their best-matching prototype.

    The einsum contracts *query* (indexed ``d``) with *prototype_tensor*
    (indexed ``artist, prototype, d``); each artist is scored by the maximum
    over its prototypes.

    Args:
        query: Query embedding.
        artists: Artist names, aligned with the first axis of
            *prototype_tensor*.
        prototype_tensor: Prototype embeddings per artist.
        top_k: Number of rows wanted; clamped to ``[1, len(artists)]``.

    Returns:
        ``(rows, best_artist, best_prototype_index, best_score)`` where each
        row is ``[rank, artist, score]`` with the score formatted to three
        decimals.
    """
    sims = torch.einsum("d,apd->ap", query.float(), prototype_tensor)
    scores, proto_indices = sims.max(dim=1)
    k = min(max(1, int(top_k)), len(artists))
    values, artist_indices = torch.topk(scores, k=k)

    # Convert once; the table only needs plain Python numbers.
    top_scores = values.tolist()
    top_artists = artist_indices.tolist()
    rows: List[List[object]] = [
        [rank, artists[artist_idx], f"{float(score):.3f}"]
        for rank, (score, artist_idx) in enumerate(zip(top_scores, top_artists), start=1)
    ]

    best_artist_idx = int(top_artists[0])
    best_proto_idx = int(proto_indices[best_artist_idx].item())
    return rows, artists[best_artist_idx], best_proto_idx, float(top_scores[0])
def tensor_to_map(heatmap: torch.Tensor, size: Tuple[int, int]) -> np.ndarray:
    """Turn a heatmap tensor into a min-max normalised map resized to *size*.

    Non-finite values are replaced by zero, the values are rescaled to
    [0, 1], quantised to 8 bits, resized bicubically through PIL (*size* is
    passed to ``Image.resize`` as given) and returned as float32 in [0, 1].
    """
    values = heatmap.detach().float().squeeze().cpu().numpy()
    values = np.nan_to_num(values, nan=0.0, posinf=0.0, neginf=0.0)
    values = values - float(values.min())
    values = values / max(float(values.max()), 1e-6)

    as_bytes = np.uint8(np.clip(values, 0.0, 1.0) * 255)
    resized_image = Image.fromarray(as_bytes, mode="L").resize(size, Image.Resampling.BICUBIC)
    as_float = np.asarray(resized_image, dtype=np.float32) / 255.0
    return np.clip(as_float, 0.0, 1.0)
def heat_colors(values: np.ndarray) -> np.ndarray:
    """Map values in [0, 1] to RGB by linear interpolation in ``HEAT_PALETTE``.

    Inputs are clipped to [0, 1] first. The result has the input's shape plus
    a trailing axis of length three.
    """
    last = len(HEAT_PALETTE) - 1
    position = np.clip(values, 0.0, 1.0) * last
    lower = np.floor(position).astype(np.int32)
    upper = np.clip(lower + 1, 0, last)
    # Fractional distance from the lower palette entry to the upper one.
    weight = position[..., None] - lower[..., None]
    return HEAT_PALETTE[lower] * (1.0 - weight) + HEAT_PALETTE[upper] * weight
def _paste_map(canvas: np.ndarray, heatmap: torch.Tensor, box: Tuple[int, int, int, int], size: Tuple[int, int]) -> None:
    """Merge *heatmap* into *canvas* inside *box*, keeping the larger value.

    The box is clamped to ``size`` (``(width, height)``); nothing happens
    when the clamped box is empty. *canvas* is modified in place.
    """
    canvas_w, canvas_h = size

    def clamp(value, upper: int) -> int:
        return max(0, min(upper, int(value)))

    left = clamp(box[0], canvas_w)
    right = clamp(box[2], canvas_w)
    top = clamp(box[1], canvas_h)
    bottom = clamp(box[3], canvas_h)
    if right <= left or bottom <= top:
        return

    patch = tensor_to_map(heatmap, (right - left, bottom - top))
    region = canvas[top:bottom, left:right]
    canvas[top:bottom, left:right] = np.maximum(region, patch)
def make_overlay(
    base_image: Image.Image,
    view_heatmaps: Dict[str, Optional[torch.Tensor]],
    face_box: Optional[Tuple[int, int, int, int]] = None,
    eye_box: Optional[Tuple[int, int, int, int]] = None,
) -> Image.Image:
    """Render the per-view heatmaps over a washed-out grayscale of the image.

    The image is first shrunk to fit within 720x720. The ``"full"`` heatmap
    covers the whole frame. The ``"face"`` and ``"eye"`` heatmaps are pasted
    into their boxes when a box is given, and stretched over the whole frame
    otherwise. All resulting maps are averaged, renormalised and blended
    with the palette from ``heat_colors``.

    Args:
        base_image: Image to draw on.
        view_heatmaps: Heatmaps keyed by ``"full"``, ``"face"``, ``"eye"``;
            missing or ``None`` entries are skipped.
        face_box: Face box in pixels of *base_image*, if known.
        eye_box: Eye box in pixels of *base_image*, if known.

    Returns:
        The blended RGB image, or the plain grayscale backdrop when no
        heatmap is available.
    """
    base = base_image.convert("RGB")
    original_size = base.size
    display = base.copy()
    display.thumbnail((720, 720), Image.Resampling.LANCZOS)
    size = display.size
    sx = size[0] / max(1, original_size[0])
    sy = size[1] / max(1, original_size[1])

    # Light grayscale backdrop (compressed into [0.78, 1.0]) taken from the
    # downscaled image, so the heat colours stay readable.
    gray = np.asarray(display.convert("L"), dtype=np.float32) / 255.0
    gray = 0.78 + 0.22 * gray
    gray_rgb = np.repeat(gray[..., None], 3, axis=2)

    maps = []
    full_heatmap = view_heatmaps.get("full")
    if full_heatmap is not None:
        maps.append(tensor_to_map(full_heatmap, size))

    # Localised views share one canvas; views without a box fall back to a
    # full-frame map of their own.
    spatial = np.zeros((size[1], size[0]), dtype=np.float32)
    spatial_used = False
    for view_name, box in (("face", face_box), ("eye", eye_box)):
        heatmap = view_heatmaps.get(view_name)
        if heatmap is None:
            continue
        if box is None:
            maps.append(tensor_to_map(heatmap, size))
            continue
        scaled_box = (
            int(round(box[0] * sx)),
            int(round(box[1] * sy)),
            int(round(box[2] * sx)),
            int(round(box[3] * sy)),
        )
        _paste_map(spatial, heatmap, scaled_box, size)
        spatial_used = True
    if spatial_used:
        maps.append(spatial)
    if not maps:
        return Image.fromarray(np.uint8(gray_rgb * 255))

    # Equal per-view normalization and averaging keeps larger source crops or
    # local face/eye maps from dominating purely because of area or scale.
    combined = np.mean(np.stack(maps, axis=0), axis=0)
    combined = combined - float(combined.min())
    combined = combined / max(float(combined.max()), 1e-6)
    heat_rgb = heat_colors(combined)
    # Values up to 0.05 stay fully transparent; opacity tops out at 0.90.
    alpha = np.clip((combined - 0.05) / 0.95, 0.0, 1.0)
    alpha = (0.90 * alpha)[..., None]
    blended = gray_rgb * (1.0 - alpha) + heat_rgb * alpha
    return Image.fromarray(np.uint8(np.clip(blended, 0.0, 1.0) * 255))
def scalar(value: torch.Tensor) -> float:
    """Return the first element of *value* (in flattened order) as a float."""
    first = value.detach().float().flatten()[0]
    return float(first.cpu().item())
def _normalize_values(values: List[float]) -> List[float]:
clipped = [max(0.0, float(value)) for value in values]
total = sum(clipped)
if total <= 1e-8:
return [1.0 / len(clipped) for _ in clipped]
return [value / total for value in clipped]
def _bars_html(title: str, note: str, labels: List[str], values: List[float], colors: Dict[str, str]) -> str:
rows = []
for label, value in zip(labels, values):
pct = max(0.0, min(1.0, float(value))) * 100.0
color = colors.get(label, "#e34a33")
rows.append(
"<div class='bar-row' style='display:grid; grid-template-columns:82px minmax(120px,1fr) 52px; "
"gap:8px; align-items:center; margin:7px 0;'>"
f"<div class='bar-label' style='font-size:13px; font-weight:600;'>{label}</div>"
"<div class='bar-track' style='height:14px; background:#e8e8e8; border-radius:999px; overflow:hidden;'>"
f"<div class='bar-fill' style='width:{pct:.1f}%; height:14px; background:{color}; border-radius:999px;'></div>"
"</div>"
f"<div class='bar-value' style='font-variant-numeric:tabular-nums; font-size:12px; text-align:right;'>{pct:.1f}%</div>"
"</div>"
)
return (
"<div class='bars' style='display:grid; gap:4px; margin:4px 0 14px;'>"
f"<div class='bar-title' style='font-size:15px; font-weight:700; margin:0 0 3px;'>{title}</div>"
f"<div class='bar-note' style='font-size:12px; color:#666; margin:0 0 8px; line-height:1.35;'>{note}</div>"
f"{''.join(rows)}</div>"
)
def top_match_html(artist: str, similarity: float) -> str:
    """Render the "top artist match" card as an HTML fragment.

    Args:
        artist: Artist name. It comes from the prototype bank file, so it is
            HTML-escaped before being embedded in the markup.
        similarity: Similarity score; shown as a percentage clamped to
            [0, 100].
    """
    from html import escape

    percent = max(0.0, min(100.0, float(similarity) * 100.0))
    safe_artist = escape(str(artist))
    return (
        "<div class='top-match'>"
        "<div class='top-match-label'>Top artist match</div>"
        f"<div class='top-match-artist'>{safe_artist}</div>"
        f"<div class='top-match-score'>{percent:.1f}% similarity</div>"
        "</div>"
    )
def contribution_bars(explanation: Dict[str, object]) -> Tuple[str, str, str]:
    """Build the branch and view contribution bars for an explanation.

    Branch shares are the normalised ``branch_weights`` of the query. View
    shares are the ``stacked_view_weights`` multiplied by the branch shares,
    summed over the first axis and normalised again.

    Returns:
        ``(branch_html, view_html, summary)`` where *summary* names the
        branch and the view with the largest share.
    """
    outputs = explanation["query_outputs"]

    raw_branch = outputs["branch_weights"][0].detach().float().cpu().tolist()
    branch_shares = _normalize_values(raw_branch)

    # NOTE(review): assumes stacked_view_weights[0] is (branches, views) so the
    # branch shares broadcast along the last axis -- confirm against the model.
    per_branch_views = outputs["stacked_view_weights"][0].detach().float().cpu()
    branch_column = torch.tensor(branch_shares, dtype=per_branch_views.dtype).unsqueeze(-1)
    weighted_views = (per_branch_views * branch_column).sum(dim=0).tolist()
    view_shares = _normalize_values(weighted_views)

    top_branch = BRANCH_NAMES[int(np.argmax(branch_shares))]
    top_view = VIEW_NAMES[int(np.argmax(view_shares))]

    branch_html = _bars_html(
        "Influential style factors",
        "Relative weight of structure, texture, linework, and color in the query embedding.",
        BRANCH_NAMES,
        branch_shares,
        BRANCH_COLORS,
    )
    view_html = _bars_html(
        "Similar visual regions",
        "Relative weight of whole image, face crop, and eye crop used for the match.",
        VIEW_NAMES,
        view_shares,
        VIEW_COLORS,
    )
    return branch_html, view_html, f"Dominant branch: {top_branch}. Dominant view: {top_view}."
def analyze(
    whole_image: Optional[Image.Image],
    face_image: Optional[Image.Image],
    eye_image: Optional[Image.Image],
    auto_extract: bool,
    top_k: int,
    use_tta: bool,
    checkpoint_path: str,
    prototype_bank_path: str,
    dinov3_root: str,
    dinov3_weights: str,
    device_name: str,
    yolo_dir: str,
    yolo_weights: str,
    eye_cascade: str,
    face_conf: float,
    face_iou: float,
    face_imgsz: int,
    eye_neighbors: int,
    eye_margin: float,
) -> Tuple[Optional[Image.Image], Optional[Image.Image], Image.Image, str, List[List[object]], str, str, str]:
    """Run retrieval and attribution for one query; the "Analyze" handler.

    Steps: obtain face/eye crops (manual inputs, or auto-detected when
    *auto_extract* is set), embed the query, rank the artists, explain the
    query against the best-matching prototype and render the results.

    Returns:
        ``(face, eye, overlay, top_match_html, result_rows, branch_html,
        view_html, summary)`` in the order of the UI outputs.

    Raises:
        gr.Error: No whole image was supplied, or the best artist has no
            prototype descriptor.
    """
    if whole_image is None:
        raise gr.Error("Whole image is required.")

    runtime = load_runtime(checkpoint_path, prototype_bank_path, dinov3_root, dinov3_weights, device_name)
    model = runtime["model"]

    # Manual crops are the starting point; auto extraction replaces them.
    views = ExtractedViews(face_image, eye_image, None, None, "Manual face/eye inputs.")
    if auto_extract:
        extractor = load_extractor(
            yolo_dir,
            yolo_weights,
            eye_cascade,
            device_name,
            float(face_conf),
            float(face_iou),
            int(face_imgsz),
            int(eye_neighbors),
            float(eye_margin),
        )
        try:
            views = extractor.extract(whole_image)
        except Exception as exc:
            # Cropping is optional: fall back to the whole image alone.
            views = ExtractedViews(None, None, None, None, f"Auto crop failed ({exc}); using whole image only.")
    used_face, used_eye, face_box, eye_box, extract_status = views

    full, face, eye, view_mask = prepare_inputs(
        whole_image, used_face, used_eye, runtime["checkpoint_args"], runtime["device"]
    )
    with torch.no_grad():
        query_outputs = _encode_query(model, full, face, eye, view_mask=view_mask, use_tta=use_tta)
        query = F.normalize(query_outputs["embedding"][0], dim=0)

    result_rows, best_artist, best_proto_idx, best_score = rank_artists(
        query=query,
        artists=runtime["artists"],
        prototype_tensor=runtime["prototype_tensor"],
        top_k=top_k,
    )

    descriptors = runtime["prototype_descriptors"].get(best_artist)
    if not descriptors:
        raise gr.Error(f"No prototype descriptor found for {best_artist}.")
    reference_descriptor = descriptors[best_proto_idx]

    # Only views with a known box get an attention box.
    candidate_boxes = {
        "face": normalized_box(face_box, whole_image.size),
        "eye": normalized_box(eye_box, whole_image.size),
    }
    view_attention_boxes = {name: box for name, box in candidate_boxes.items() if box is not None}

    explanation = explain_against_reference(
        model=model,
        query_full=full,
        query_face=face,
        query_eye=eye,
        query_view_mask=view_mask,
        reference_descriptor=reference_descriptor,
        view_attention_boxes=view_attention_boxes,
        use_tta=use_tta,
    )

    overlay = make_overlay(whole_image, explanation["combined_view_heatmaps"], face_box=face_box, eye_box=eye_box)
    branch_html, view_html, contribution_summary = contribution_bars(explanation)
    summary = f"{extract_status} {contribution_summary} Target prototype: {best_artist} #{best_proto_idx}."
    match_card = top_match_html(best_artist, best_score)
    return used_face, used_eye, overlay, match_card, result_rows, branch_html, view_html, summary
def build_ui(args: argparse.Namespace) -> gr.Blocks:
    """Assemble the Gradio interface and wire the "Analyze" button.

    The left column holds the image inputs, the run controls and two
    collapsed accordions whose initial values come from *args*. The right
    column holds the heatmap, the retrieval results and the contribution
    bars.

    NOTE(review): the nesting of the layout containers below was
    reconstructed from the component order; confirm it against the intended
    layout.
    """

    def path_box(value: str, label: str) -> gr.Textbox:
        return gr.Textbox(value=value, label=label)

    with gr.Blocks(title="Artist Style DINOv3") as demo:
        gr.Markdown("# Artist Style DINOv3")
        with gr.Row(elem_classes=["app"]):
            # Inputs and settings.
            with gr.Column(scale=4, min_width=320):
                whole = gr.Image(label="Whole", type="pil", image_mode="RGB", height=260)
                with gr.Row():
                    face = gr.Image(label="Face", type="pil", image_mode="RGB", height=160)
                    eye = gr.Image(label="Eye", type="pil", image_mode="RGB", height=160)
                with gr.Row():
                    auto_extract = gr.Checkbox(value=True, label="Auto crop")
                    top_k = gr.Slider(1, 25, value=10, step=1, label="Top K")
                    use_tta = gr.Checkbox(value=False, label="TTA")
                run = gr.Button("Analyze", variant="primary")
                with gr.Accordion("Paths", open=False):
                    checkpoint_path = path_box(args.checkpoint, "Checkpoint")
                    prototype_bank_path = path_box(args.prototype_bank, "Prototype bank")
                    dinov3_root = path_box(args.dinov3_root, "DINOv3 root")
                    dinov3_weights = path_box(args.dinov3_weights, "DINOv3 weights")
                    yolo_dir = path_box(args.yolo_dir, "YOLOv5 anime root")
                    yolo_weights = path_box(args.yolo_weights, "YOLOv5 anime weights")
                    eye_cascade = path_box(args.eye_cascade, "Eye cascade")
                    device_name = gr.Dropdown(["auto", "cpu", "cuda"], value=args.device, label="Device")
                with gr.Accordion("Auto Crop", open=False):
                    face_conf = gr.Slider(0.05, 0.95, value=args.face_conf, step=0.05, label="Face confidence")
                    face_iou = gr.Slider(0.1, 0.9, value=args.face_iou, step=0.05, label="Face IoU")
                    face_imgsz = gr.Slider(320, 1280, value=args.face_imgsz, step=32, label="Face detector size")
                    eye_neighbors = gr.Slider(1, 20, value=args.eye_neighbors, step=1, label="Eye neighbors")
                    eye_margin = gr.Slider(0.0, 1.5, value=args.eye_margin, step=0.05, label="Eye margin")
            # Results.
            with gr.Column(scale=7, min_width=640):
                with gr.Row():
                    overlay = gr.Image(label="Composite heatmap", type="pil", height=520)
                    with gr.Column(scale=1):
                        summary = gr.Textbox(label="Strongest contribution", lines=3, elem_id="summary-box")
                        top_match = gr.HTML()
                        results = gr.Dataframe(
                            headers=["rank", "artist", "score"],
                            label="Retrieval",
                            datatype=["number", "str", "str"],
                            row_count=10,
                        )
                with gr.Row():
                    branch_bars = gr.HTML(label="Branch")
                    view_bars = gr.HTML(label="View")

        # Order must match the parameter order of analyze().
        analyze_inputs = [
            whole,
            face,
            eye,
            auto_extract,
            top_k,
            use_tta,
            checkpoint_path,
            prototype_bank_path,
            dinov3_root,
            dinov3_weights,
            device_name,
            yolo_dir,
            yolo_weights,
            eye_cascade,
            face_conf,
            face_iou,
            face_imgsz,
            eye_neighbors,
            eye_margin,
        ]
        # Order must match the tuple returned by analyze().
        analyze_outputs = [face, eye, overlay, top_match, results, branch_bars, view_bars, summary]
        run.click(fn=analyze, inputs=analyze_inputs, outputs=analyze_outputs)
    return demo
def main() -> None:
    """Parse the command line, build the UI and start the Gradio server."""
    cli = parse_args()
    app = build_ui(cli)
    app.launch(
        server_name=cli.server_name,
        server_port=cli.server_port,
        share=cli.share,
        css=APP_CSS,
    )
# Start the web UI only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()