# NOTE: the lines below were Hugging Face Spaces page chrome captured by the
# scrape ("Spaces: Running on Zero") — kept as a comment so the module parses.
from __future__ import annotations

import itertools
import sys
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import torch
| def _patch_torch_load_for_old_ckpt() -> None: | |
| """ | |
| Matches `anime_face_eye_extract._patch_torch_load_for_old_ckpt()` to load older YOLOv5 checkpoints | |
| on newer torch versions. | |
| """ | |
| import numpy as _np | |
| try: | |
| torch.serialization.add_safe_globals([_np.core.multiarray._reconstruct, _np.ndarray]) | |
| except Exception: | |
| pass | |
| _orig_load = torch.load | |
| def _patched_load(*args, **kwargs): # noqa: ANN001 | |
| kwargs.setdefault("weights_only", False) | |
| return _orig_load(*args, **kwargs) | |
| torch.load = _patched_load | |
def _pre(gray: np.ndarray) -> np.ndarray:
    """Prepare a grayscale image for Haar detection: light 3x3 blur, then CLAHE."""
    import cv2

    blurred = cv2.GaussianBlur(gray, (3, 3), 0)
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return equalizer.apply(blurred)
| def _expand(box, margin: float, W: int, H: int): | |
| x1, y1, x2, y2 = box | |
| cx = (x1 + x2) / 2.0 | |
| cy = (y1 + y2) / 2.0 | |
| w = (x2 - x1) * (1 + margin) | |
| h = (y2 - y1) * (1 + margin) | |
| nx1 = int(round(cx - w / 2)) | |
| ny1 = int(round(cy - h / 2)) | |
| nx2 = int(round(cx + w / 2)) | |
| ny2 = int(round(cy + h / 2)) | |
| nx1 = max(0, min(W, nx1)) | |
| ny1 = max(0, min(H, ny1)) | |
| nx2 = max(0, min(W, nx2)) | |
| ny2 = max(0, min(H, ny2)) | |
| return nx1, ny1, nx2, ny2 | |
def _shrink(img: np.ndarray, limit: int):
    """Downscale ``img`` so its longer side is at most ``limit``; return (image, scale)."""
    import cv2

    h, w = img.shape[:2]
    longest = max(h, w)
    if longest <= limit:
        # Already small enough — return untouched with unit scale.
        return img, 1.0
    scale = limit / float(longest)
    resized = cv2.resize(
        img,
        (int(w * scale), int(h * scale)),
        interpolation=cv2.INTER_AREA,
    )
    return resized, scale
| def _best_pair(boxes, W: int, H: int): | |
| clean = [(int(b[0]), int(b[1]), int(b[2]), int(b[3])) for b in boxes] | |
| if len(clean) < 2: | |
| return [] | |
| def cxcy(b): | |
| x1, y1, x2, y2 = b | |
| return (x1 + x2) / 2.0, (y1 + y2) / 2.0 | |
| def area(b): | |
| x1, y1, x2, y2 = b | |
| return max(1, (x2 - x1) * (y2 - y1)) | |
| best = None | |
| best_s = 1e9 | |
| for b1, b2 in itertools.combinations(clean, 2): | |
| c1x, c1y = cxcy(b1) | |
| c2x, c2y = cxcy(b2) | |
| a1, a2 = area(b1), area(b2) | |
| horiz = 0.0 if c1x < c2x else 0.5 | |
| y_aln = abs(c1y - c2y) / max(1.0, H) | |
| szsim = abs(a1 - a2) / float(max(a1, a2)) | |
| gap = abs(c2x - c1x) / max(1.0, W) | |
| if 0.05 <= gap <= 0.5: | |
| gap_pen = 0.0 | |
| else: | |
| gap_pen = 0.5 * ((0.5 + abs(gap - 0.05) * 10) if gap < 0.05 else (gap - 0.5) * 2.0) | |
| mean_y = (c1y + c2y) / 2.0 / max(1.0, H) | |
| upper = 0.3 * max(0.0, (mean_y - 0.67) * 2.0) | |
| s = y_aln + szsim + gap_pen + upper + horiz | |
| if s < best_s: | |
| best_s = s | |
| best = (b1, b2) | |
| if best is None: | |
| return [] | |
| b1, b2 = best | |
| left, right = (b1, b2) if (b1[0] + b1[2]) <= (b2[0] + b2[2]) else (b2, b1) | |
| return [("left", left), ("right", right)] | |
@dataclass
class ExtractorCfg:
    """
    Configuration for :class:`AnimeFaceEyeExtractor`.

    BUG FIX: the ``@dataclass`` decorator was missing, so the class had no
    generated ``__init__`` and these annotations were not real fields.
    """
    yolo_dir: Path    # yolov5_anime repo directory (added to sys.path)
    weights: Path     # YOLO checkpoint path
    cascade: Path     # Haar eye-cascade XML path
    imgsz: int = 640  # inference size; rounded up to a stride multiple on load
    conf: float = 0.5  # detection confidence threshold
    iou: float = 0.5   # NMS IoU threshold
    yolo_device: str = "cpu"  # "cpu" or "0"
    eye_roi_frac: float = 0.70  # fraction of face height searched for eyes
    eye_min_size: int = 12      # minimum eye/crop side in pixels
    eye_margin: float = 0.60    # expansion margin around the eye union box
    neighbors: int = 9          # Haar detectMultiScale minNeighbors
    eye_downscale_limit_roi: int = 512   # max long side of the eye ROI before detection
    eye_downscale_limit_face: int = 768  # max long side of the face before detection
    eye_fallback_to_face: bool = True    # retry eye detection on the whole face
class AnimeFaceEyeExtractor:
    """
    Single-image view extractor (whole -> face crop, eyes crop) based on `anime_face_eye_extract.py`.
    Designed for use in the Gradio UI: caches YOLO model + Haar cascade.
    """

    def __init__(self, cfg: "ExtractorCfg"):
        self.cfg = cfg
        self._model = None    # YOLO model, loaded lazily by _init_detector()
        self._device = None   # torch device, set alongside the model
        self._stride = 32     # default stride; replaced with the model's real stride on load
        self._tl = threading.local()  # per-thread cascade storage

    def _init_detector(self) -> None:
        """Load the YOLOv5-anime model once; subsequent calls are no-ops."""
        if self._model is not None:
            return
        ydir = self.cfg.yolo_dir.resolve()
        if not ydir.exists():
            raise RuntimeError(f"yolov5_anime dir not found: {ydir}")
        # The yolov5 repo ships top-level packages (models, utils); make them importable.
        if str(ydir) not in sys.path:
            sys.path.insert(0, str(ydir))
        _patch_torch_load_for_old_ckpt()
        from models.experimental import attempt_load
        from utils.torch_utils import select_device
        self._device = select_device(self.cfg.yolo_device)
        self._model = attempt_load(str(self.cfg.weights), map_location=self._device)
        self._model.eval()
        self._stride = int(self._model.stride.max())
        # Round the configured inference size up to a multiple of the model stride.
        s = int(self.cfg.imgsz)
        s = int(np.ceil(s / self._stride) * self._stride)
        self.cfg.imgsz = s

    def _letterbox_compat(self, img0, new_shape, stride):
        """Call yolov5's letterbox() across repo revisions whose kwargs differ."""
        from utils.datasets import letterbox
        try:
            lb = letterbox(img0, new_shape, stride=stride, auto=False)
        except TypeError:
            try:
                lb = letterbox(img0, new_shape, auto=False)
            except TypeError:
                lb = letterbox(img0, new_shape)
        return lb[0]

    def _detect_faces(self, rgb: np.ndarray):
        """
        Run YOLO face detection on an RGB image.

        Returns a list of (x1, y1, x2, y2) int boxes in original-image coordinates.
        """
        import cv2
        self._init_detector()
        from utils.general import non_max_suppression, scale_coords
        img0 = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        h0, w0, _ = img0.shape
        img = self._letterbox_compat(img0, self.cfg.imgsz, self._stride)
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR -> RGB, HWC -> CHW
        img = np.ascontiguousarray(img)
        im = torch.from_numpy(img).to(self._device)
        im = im.float() / 255.0
        if im.ndim == 3:
            im = im[None]  # add batch dimension
        with torch.no_grad():
            pred = self._model(im)[0]
        pred = non_max_suppression(pred, conf_thres=self.cfg.conf, iou_thres=self.cfg.iou, classes=None, agnostic=False)
        boxes = []
        det = pred[0]
        if det is not None and len(det):
            # Map boxes from letterboxed coordinates back to the original image.
            det[:, :4] = scale_coords((self.cfg.imgsz, self.cfg.imgsz), det[:, :4], (h0, w0)).round()
            for *xyxy, _conf, _cls in det.tolist():
                x1, y1, x2, y2 = [int(v) for v in xyxy]
                boxes.append((x1, y1, x2, y2))
        return boxes

    def _get_cascade(self):
        """Return a per-thread Haar cascade instance (cached on threading.local)."""
        import cv2
        c = getattr(self._tl, "cascade", None)
        if c is None:
            c = cv2.CascadeClassifier(str(self.cfg.cascade))
            if c.empty():
                raise RuntimeError(f"cascade load fail: {self.cfg.cascade}")
            self._tl.cascade = c
        return c

    def _detect_eyes_in_roi(self, rgb_roi: np.ndarray):
        """
        Haar-detect eye candidates in an RGB region.

        Returns a list of (x1, y1, x2, y2) boxes; [] when nothing is found.
        """
        import cv2
        gray = cv2.cvtColor(rgb_roi, cv2.COLOR_RGB2GRAY)
        proc = _pre(gray)
        H, W = proc.shape[:2]
        # Minimum eye side: at least 8 px, the configured floor, and 7% of the short side.
        min_side = max(1, min(W, H))
        dyn_min = int(0.07 * min_side)
        min_sz = max(8, int(self.cfg.eye_min_size), dyn_min)
        cascade = self._get_cascade()
        raw = cascade.detectMultiScale(
            proc,
            scaleFactor=1.15,
            minNeighbors=int(self.cfg.neighbors),
            minSize=(min_sz, min_sz),
            flags=cv2.CASCADE_SCALE_IMAGE,
        )
        # detectMultiScale may return an ndarray, an empty tuple, or a tuple
        # wrapping the array depending on the OpenCV version — normalize it.
        try:
            arr = np.asarray(raw if not isinstance(raw, tuple) else raw[0])
        except Exception:
            arr = np.empty((0, 4), dtype=int)
        if arr.size == 0:
            return []
        if arr.ndim == 1:
            arr = arr.reshape(1, -1)
        boxes = []
        for r in arr:
            x, y, w, h = [int(v) for v in r[:4]]
            if w <= 0 or h <= 0:
                continue  # drop degenerate detections
            boxes.append((x, y, x + w, y + h))
        return boxes

    @staticmethod
    def _pick_best_face(boxes):
        """Return the largest-area face box from ``boxes``, or None when empty.

        BUG FIX: this was declared without ``self`` (and without ``@staticmethod``)
        but invoked as ``self._pick_best_face(boxes)``, which bound the instance
        to ``boxes`` and crashed in ``max``.
        """
        if not boxes:
            return None
        # choose largest-area face
        def area(b):
            x1, y1, x2, y2 = b
            return max(1, (x2 - x1) * (y2 - y1))
        return max(boxes, key=area)

    def extract(self, whole_rgb: np.ndarray) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Args:
            whole_rgb: HWC RGB uint8
        Returns:
            (face_rgb, eyes_rgb) as RGB uint8 crops (or None if not found)
        """
        import cv2
        boxes = self._detect_faces(whole_rgb)
        face_box = self._pick_best_face(boxes)
        if face_box is None:
            return None, None
        x1, y1, x2, y2 = face_box
        H0, W0 = whole_rgb.shape[:2]
        # Clamp the face box to the image and reject degenerate boxes.
        x1 = max(0, min(W0, x1))
        x2 = max(0, min(W0, x2))
        y1 = max(0, min(H0, y1))
        y2 = max(0, min(H0, y2))
        if x2 <= x1 or y2 <= y1:
            return None, None
        face = whole_rgb[y1:y2, x1:x2].copy()
        # eye detection on upper ROI
        H, W = face.shape[:2]
        roi_h = int(H * float(self.cfg.eye_roi_frac))
        roi = face[0: max(1, roi_h), :]
        # Downscale both candidate regions for detection speed; boxes are mapped back below.
        roi_small, s_roi = _shrink(roi, int(self.cfg.eye_downscale_limit_roi))
        face_small, s_face = _shrink(face, int(self.cfg.eye_downscale_limit_face))
        eyes_roi = self._detect_eyes_in_roi(roi_small)
        eyes_roi = [(int(a / s_roi), int(b / s_roi), int(c / s_roi), int(d / s_roi)) for (a, b, c, d) in eyes_roi]
        labs = _best_pair(eyes_roi, W, roi.shape[0])
        origin = "roi" if labs else None
        eyes_full = []
        # Fallback 1: no pair in the ROI — rerun detection on the whole face.
        if self.cfg.eye_fallback_to_face and (not labs):
            eyes_full = self._detect_eyes_in_roi(face_small)
            eyes_full = [(int(a / s_face), int(b / s_face), int(c / s_face), int(d / s_face)) for (a, b, c, d) in eyes_full]
            if len(eyes_full) >= 2:
                labs = _best_pair(eyes_full, W, H)
                origin = "face" if labs else origin
        # Fallback 2: still no scored pair — take the largest one or two candidates.
        if not labs:
            cand = eyes_roi
            cand_origin = "roi"
            if self.cfg.eye_fallback_to_face and len(eyes_full) >= 1:
                cand = eyes_full
                cand_origin = "face"
            if len(cand) >= 2:
                top2 = sorted(cand, key=lambda b: (b[2] - b[0]) * (b[3] - b[1]), reverse=True)[:2]
                top2 = sorted(top2, key=lambda b: (b[0] + b[2]))  # order left-to-right
                labs = [("left", top2[0]), ("right", top2[1])]
                origin = cand_origin
            elif len(cand) == 1:
                labs = [("left", cand[0])]
                origin = cand_origin
        eyes_crop = None
        if labs:
            # Boxes are in the coordinate frame they were detected in.
            src_img = roi if origin == "roi" else face
            bound_h = roi.shape[0] if origin == "roi" else H
            boxes_only = [b for _, b in labs]
            # union of eye boxes -> single eyes crop (works for the "eyes" view encoder)
            ux1 = min(b[0] for b in boxes_only)
            uy1 = min(b[1] for b in boxes_only)
            ux2 = max(b[2] for b in boxes_only)
            uy2 = max(b[3] for b in boxes_only)
            ex1, ey1, ex2, ey2 = _expand((ux1, uy1, ux2, uy2), float(self.cfg.eye_margin), W, bound_h)
            crop = src_img[ey1:ey2, ex1:ex2]
            if crop.size > 0 and min(crop.shape[0], crop.shape[1]) >= int(self.cfg.eye_min_size):
                eyes_crop = crop.copy()
        return face, eyes_crop