|
|
""" |
|
|
Three-View-Style-Embedder - Inference Utilities |
|
|
Lazy loading for Hugging Face Spaces compatibility |
|
|
""" |
|
|
from pathlib import Path |
|
|
from typing import List, Optional, Tuple |
|
|
import numpy as np |
|
|
import torch |
|
|
from PIL import Image |
|
|
from torchvision import transforms |
|
|
import threading |
|
|
|
|
|
|
|
|
try:
    import spaces
except ImportError:
    class spaces:
        """No-op stand-in used when the Hugging Face ``spaces`` package is absent."""

        @staticmethod
        def GPU(func=None, **_options):
            """Return the decorated function unchanged.

            Supports both decorator forms so code written for the real
            package keeps working locally:

            * bare: ``@spaces.GPU``
            * parameterized: ``@spaces.GPU(duration=60)``

            Keyword options are accepted and ignored.
            """
            if func is None:
                # Called with options only: hand back an identity decorator.
                return lambda wrapped: wrapped
            return func
|
|
|
|
|
def _import_gradio(): |
|
|
try: |
|
|
import gradio as gr |
|
|
return gr |
|
|
except Exception as e: |
|
|
raise RuntimeError( |
|
|
"Failed to import Gradio. This usually means you're running with the wrong Python interpreter " |
|
|
"(e.g., system Python instead of the workspace .venv) or you have incompatible package versions.\n" |
|
|
"Fix: run with the venv interpreter: .\\.venv\\Scripts\\python.exe app.py ...\n" |
|
|
"Or on Windows: run run.bat" |
|
|
) from e |
|
|
|
|
|
def _default_path(path_str: str) -> Path: |
|
|
return (Path(__file__).resolve().parent / path_str).resolve() |
|
|
|
|
|
from config import get_config |
|
|
from model import ArtistStyleModel |
|
|
|
|
|
class FaceEyeExtractor:
    """Crop a face (YOLOv5 detector) and an eye region (OpenCV cascade) from an image.

    Nothing heavy is imported or loaded in ``__init__``; the YOLO model and the
    cascade classifier are created on first use. The cascade is cached per
    thread in a ``threading.local``, which is dropped when the object is
    pickled and recreated when it is unpickled.
    """

    def __init__(
        self,
        yolo_dir: Path,
        weights_path: Path,
        cascade_path: Path,
        device: str = 'cpu',
        imgsz: int = 640,
        conf: float = 0.5,
        iou: float = 0.5,
        eye_roi_frac: float = 0.70,
        eye_min_size: int = 12,
        eye_margin: float = 0.60,
        neighbors: int = 9,
        eye_fallback_to_face: bool = True,
    ):
        """Store configuration only.

        Args:
            yolo_dir: Directory of the YOLOv5 code (added to ``sys.path`` on load).
            weights_path: YOLO checkpoint file.
            cascade_path: OpenCV cascade XML used for eye detection.
            device: Requested device string; any ``cuda*`` value is mapped to
                ``'cpu'`` for the detector when the model is loaded.
            imgsz: Detector input size, rounded up to a multiple of the stride.
            conf: Confidence threshold passed to non-max suppression.
            iou: IoU threshold passed to non-max suppression.
            eye_roi_frac: Fraction of the face height (from the top) searched
                for eyes first.
            eye_min_size: Minimum eye box / crop side in pixels.
            eye_margin: Relative margin added around the chosen eye box.
            neighbors: ``minNeighbors`` for the cascade.
            eye_fallback_to_face: Search the whole face when the top region
                does not yield an eye pair.
        """
        self.yolo_dir = Path(yolo_dir)
        self.weights_path = Path(weights_path)
        self.cascade_path = Path(cascade_path)
        self.device = device
        self.imgsz = imgsz
        self.conf = conf
        self.iou = iou
        self.eye_roi_frac = eye_roi_frac
        self.eye_min_size = eye_min_size
        self.eye_margin = eye_margin
        self.neighbors = neighbors
        self.eye_fallback_to_face = eye_fallback_to_face

        # Lazily populated state.
        self._yolo_model = None
        self._yolo_device = None
        self._stride = 32
        self._tl = threading.local()

    def __getstate__(self):
        """Return the instance dict without the unpicklable thread-local."""
        state = self.__dict__.copy()
        state.pop("_tl", None)
        return state

    def __setstate__(self, state):
        """Restore attributes and create a fresh thread-local cache."""
        self.__dict__.update(state)
        self._tl = threading.local()

    def _patch_torch_load_for_old_ckpt(self):
        """Best-effort registration of NumPy globals for ``torch.load``."""
        import torch as _torch
        import numpy as _np

        try:
            _torch.serialization.add_safe_globals([
                _np.core.multiarray._reconstruct,
                _np.ndarray,
            ])
        except Exception:
            # Deliberately ignored: the API or the NumPy path may be missing,
            # and the loader below forces weights_only=False anyway.
            pass

    def _ensure_ready(self):
        """Make sure the YOLO model and this thread's cascade are loaded.

        Raises:
            RuntimeError: if the YOLO directory, the weights, or the cascade
                file cannot be found or loaded.
        """
        # The cascade lives in the thread-local, so readiness is checked there
        # (there is no ``self._cascade`` attribute).
        if self._yolo_model is None:
            self._load_yolo()
        self._get_cascade()

    def _load_yolo(self):
        """Locate the YOLOv5 code, import it and load the detector weights."""
        import sys

        if not self.yolo_dir.exists():
            # Fall back to ./yolov5_anime, then to a sibling of this file.
            cwd_yolo = Path("yolov5_anime").resolve()
            if cwd_yolo.exists():
                self.yolo_dir = cwd_yolo
            else:
                file_yolo = Path(__file__).parent / "yolov5_anime"
                if file_yolo.exists():
                    self.yolo_dir = file_yolo

        if not self.yolo_dir.exists():
            raise RuntimeError(
                f"yolov5_anime directory not found. Tried: {self.yolo_dir}, "
                f"current dir: {Path.cwd()}, file dir: {Path(__file__).parent}"
            )

        # The YOLOv5 repo is imported as top-level ``models`` / ``utils``.
        yolo_path_str = str(self.yolo_dir.resolve())
        if yolo_path_str not in sys.path:
            sys.path.insert(0, yolo_path_str)

        self._patch_torch_load_for_old_ckpt()

        import torch as _torch

        try:
            from models.experimental import attempt_load
            from utils.torch_utils import select_device
        except ImportError as e:
            raise RuntimeError(
                f"Failed to import YOLOv5 modules. Make sure yolov5_anime directory exists at {self.yolo_dir}. "
                f"sys.path includes: {[p for p in sys.path if 'yolo' in p.lower()]}. "
                f"Original error: {e}"
            ) from e

        orig_load = _torch.load

        def patched_load(*args, **kwargs):
            # Default to full unpickling unless the caller says otherwise.
            kwargs.setdefault('weights_only', False)
            return orig_load(*args, **kwargs)

        # torch.load is swapped only for the duration of the checkpoint load.
        _torch.load = patched_load
        try:
            detector_device = 'cpu' if self.device.startswith('cuda') else self.device
            yolo_device = select_device(detector_device)
            if not self.weights_path.exists():
                raise RuntimeError(f"YOLO weights not found: {self.weights_path}")
            model = attempt_load(str(self.weights_path), map_location=yolo_device)
            model.eval()
            stride = int(model.stride.max())
        finally:
            _torch.load = orig_load

        # Publish only after everything succeeded.
        self._yolo_device = yolo_device
        self._stride = stride
        self._yolo_model = model

    def _get_cascade(self):
        """Return the calling thread's cascade classifier, loading it if needed.

        Raises:
            RuntimeError: if the cascade XML is missing or fails to load.
        """
        cascade = getattr(self._tl, 'cascade', None)
        if cascade is not None:
            return cascade

        import cv2

        if not self.cascade_path.exists():
            raise RuntimeError(f"Cascade xml not found: {self.cascade_path}")
        cascade = cv2.CascadeClassifier(str(self.cascade_path))
        if cascade.empty():
            raise RuntimeError(f"cascade load failed: {self.cascade_path}")
        self._tl.cascade = cascade
        return cascade

    def _letterbox_compat(self, img0, new_shape, stride):
        """Call YOLOv5 ``letterbox`` across differing signatures; return the image."""
        from utils.datasets import letterbox

        try:
            out = letterbox(img0, new_shape, stride=stride, auto=False)
        except TypeError:
            try:
                out = letterbox(img0, new_shape, auto=False)
            except TypeError:
                out = letterbox(img0, new_shape)
        return out[0]

    def _detect_faces(self, rgb: np.ndarray) -> List[Tuple[int, int, int, int]]:
        """Run the detector on an RGB array and return ``(x1, y1, x2, y2)`` boxes."""
        self._ensure_ready()

        import cv2
        import torch as _torch
        from utils.general import non_max_suppression, scale_coords

        img0 = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        h0, w0 = img0.shape[:2]

        imgsz = int(np.ceil(self.imgsz / self._stride) * self._stride)
        img = self._letterbox_compat(img0, imgsz, self._stride)
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR->RGB, HWC->CHW
        img = np.ascontiguousarray(img)

        im = _torch.from_numpy(img).to(self._yolo_device)
        im = im.float() / 255.0
        if im.ndim == 3:
            im = im[None]

        with _torch.no_grad():
            pred = self._yolo_model(im)[0]

        pred = non_max_suppression(
            pred,
            conf_thres=self.conf,
            iou_thres=self.iou,
            classes=None,
            agnostic=False,
        )

        det = pred[0]
        if det is None or not len(det):
            return []

        # Use the real network input shape: the letterbox fallback may not
        # produce an exact (imgsz, imgsz) canvas.
        det[:, :4] = scale_coords(im.shape[2:], det[:, :4], (h0, w0)).round()
        return [
            (int(row[0]), int(row[1]), int(row[2]), int(row[3]))
            for row in det.tolist()
        ]

    def _expand(self, box, margin, W, H):
        """Grow *box* about its centre by ``margin`` and clamp to ``W`` x ``H``."""
        x1, y1, x2, y2 = box
        cx = (x1 + x2) / 2.0
        cy = (y1 + y2) / 2.0
        w = (x2 - x1) * (1 + margin)
        h = (y2 - y1) * (1 + margin)
        nx1 = int(round(cx - w / 2))
        ny1 = int(round(cy - h / 2))
        nx2 = int(round(cx + w / 2))
        ny2 = int(round(cy + h / 2))
        nx1 = max(0, min(W, nx1))
        ny1 = max(0, min(H, ny1))
        nx2 = max(0, min(W, nx2))
        ny2 = max(0, min(H, ny2))
        return nx1, ny1, nx2, ny2

    def _pre(self, gray):
        """Blur then apply CLAHE to a grayscale image before cascade detection."""
        import cv2

        gray = cv2.GaussianBlur(gray, (3, 3), 0)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        return clahe.apply(gray)

    def _shrink_for_eye(self, img, limit=900):
        """Downscale so the longer side is at most *limit*.

        Returns:
            ``(image, scale)``; ``scale`` is ``1.0`` when no resize was needed.
        """
        import cv2

        h, w = img.shape[:2]
        m = max(h, w)
        if m <= limit:
            return img, 1.0
        s = limit / float(m)
        # Guard against a 0-pixel side on extreme aspect ratios.
        nh, nw = max(1, int(h * s)), max(1, int(w * s))
        small = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_AREA)
        return small, s

    def _detect_eyes_in_roi(self, rgb_roi):
        """Run the eye cascade on an RGB region; return ``(x1, y1, x2, y2)`` boxes."""
        import cv2

        gray = cv2.cvtColor(rgb_roi, cv2.COLOR_RGB2GRAY)
        proc = self._pre(gray)

        H, W = proc.shape[:2]
        min_side = max(1, min(W, H))
        dyn_min = int(0.07 * min_side)
        min_sz = max(8, int(self.eye_min_size), dyn_min)

        cascade = self._get_cascade()

        raw = cascade.detectMultiScale(
            proc,
            scaleFactor=1.15,
            minNeighbors=self.neighbors,
            minSize=(min_sz, min_sz),
            flags=cv2.CASCADE_SCALE_IMAGE,
        )

        # The result may be an array or a tuple; normalise to an (N, 4) array.
        try:
            arr = np.asarray(raw if not isinstance(raw, tuple) else raw[0])
        except Exception:
            arr = np.empty((0, 4), dtype=int)
        if arr.size == 0:
            return []
        if arr.ndim == 1:
            arr = arr.reshape(1, -1)

        boxes = []
        for r in arr:
            x, y, w, h = [int(v) for v in r[:4]]
            if w <= 0 or h <= 0:
                continue
            boxes.append((x, y, x + w, y + h))
        return boxes

    def _best_pair(self, boxes, W, H):
        """Pick the most eye-pair-like two boxes.

        Each pair is scored (lower is better) on vertical alignment, area
        similarity, horizontal gap relative to ``W``, position below the upper
        two thirds of ``H``, and ordering.

        Returns:
            ``[("left", box), ("right", box)]`` ordered by x-centre, or ``[]``
            when fewer than two boxes are given.
        """
        import itertools

        clean = [(int(b[0]), int(b[1]), int(b[2]), int(b[3])) for b in boxes]
        if len(clean) < 2:
            return []

        def cxcy(b):
            x1, y1, x2, y2 = b
            return (x1 + x2) / 2.0, (y1 + y2) / 2.0

        def area(b):
            x1, y1, x2, y2 = b
            return max(1, (x2 - x1) * (y2 - y1))

        best = None
        best_s = 1e9
        for b1, b2 in itertools.combinations(clean, 2):
            c1x, c1y = cxcy(b1)
            c2x, c2y = cxcy(b2)
            a1, a2 = area(b1), area(b2)
            horiz = 0.0 if c1x < c2x else 0.5
            y_aln = abs(c1y - c2y) / max(1.0, H)
            szsim = abs(a1 - a2) / float(max(a1, a2))
            gap = abs(c2x - c1x) / max(1.0, W)
            if 0.05 <= gap <= 0.5:
                gap_pen = 0.0
            else:
                gap_pen = 0.5 * ((0.5 + abs(gap - 0.05) * 10) if gap < 0.05 else (gap - 0.5) * 2.0)
            mean_y = (c1y + c2y) / 2.0 / max(1.0, H)
            upper = 0.3 * max(0.0, (mean_y - 0.67) * 2.0)
            s = y_aln + szsim + gap_pen + upper + horiz
            if s < best_s:
                best_s = s
                best = (b1, b2)

        if best is None:
            return []

        b1, b2 = best
        left, right = (b1, b2) if (b1[0] + b1[2]) <= (b2[0] + b2[2]) else (b2, b1)
        return [("left", left), ("right", right)]

    def extract_face(self, full_image: Image.Image) -> Optional[Image.Image]:
        """Return a crop of the largest detected face, or ``None`` if none is found."""
        rgb = np.array(full_image.convert('RGB'))
        boxes = self._detect_faces(rgb)
        if not boxes:
            return None

        def area(b):
            x1, y1, x2, y2 = b
            return max(0, x2 - x1) * max(0, y2 - y1)

        x1, y1, x2, y2 = max(boxes, key=area)
        H, W = rgb.shape[:2]
        x1 = max(0, min(W, x1))
        x2 = max(0, min(W, x2))
        y1 = max(0, min(H, y1))
        y2 = max(0, min(H, y2))
        if x2 <= x1 or y2 <= y1:
            return None
        return Image.fromarray(rgb[y1:y2, x1:x2])

    def extract_eye_region(self, face_image: Image.Image) -> Optional[Image.Image]:
        """Return a crop around the left-most detected eye of a face image.

        Eyes are searched in the top ``eye_roi_frac`` of the face first and,
        if enabled, in the whole face when no pair is found. The chosen box is
        expanded by ``eye_margin`` and padded towards a square inside the
        searched area.

        Returns:
            The eye crop, or ``None`` when no eye is found or the crop is
            smaller than ``eye_min_size``.
        """
        self._ensure_ready()

        rgb_face = np.array(face_image.convert('RGB'))
        H, W = rgb_face.shape[:2]
        if H < 2 or W < 2:
            return None

        roi_h = int(H * float(self.eye_roi_frac))
        roi_h = max(1, min(H, roi_h))
        roi = rgb_face[0:roi_h, :]

        # Pass 1: top part of the face. Boxes are mapped back to face pixels.
        roi_small, s_roi = self._shrink_for_eye(roi, limit=512)
        eyes_roi = [
            (int(x1 / s_roi), int(y1 / s_roi), int(x2 / s_roi), int(y2 / s_roi))
            for (x1, y1, x2, y2) in self._detect_eyes_in_roi(roi_small)
        ]
        labs = self._best_pair(eyes_roi, W, roi_h)
        origin = 'roi' if labs else None

        # Pass 2 (optional): whole face, only when pass 1 found no pair.
        eyes_full = []
        if self.eye_fallback_to_face and (not labs or len(labs) < 2):
            face_small, s_face = self._shrink_for_eye(rgb_face, limit=768)
            eyes_full = [
                (int(x1 / s_face), int(y1 / s_face), int(x2 / s_face), int(y2 / s_face))
                for (x1, y1, x2, y2) in self._detect_eyes_in_roi(face_small)
            ]
            if len(eyes_full) >= 2:
                labs = self._best_pair(eyes_full, W, H)
                origin = 'face' if labs else origin

        # Last resort: the two largest candidates, or a single one.
        if not labs:
            cand = eyes_roi
            cand_origin = 'roi'
            if self.eye_fallback_to_face and len(eyes_full) >= 1:
                cand = eyes_full
                cand_origin = 'face'
            if len(cand) >= 2:
                top2 = sorted(cand, key=lambda b: (b[2] - b[0]) * (b[3] - b[1]), reverse=True)[:2]
                top2 = sorted(top2, key=lambda b: (b[0] + b[2]))
                labs = [("left", top2[0]), ("right", top2[1])]
                origin = cand_origin
            elif len(cand) == 1:
                labs = [("left", cand[0])]
                origin = cand_origin

        if not labs:
            return None

        boxes = [box for _label, box in labs]
        if len(boxes) >= 2:
            boxes = sorted(boxes, key=lambda b: (b[0] + b[2]))[:2]

        src_img = roi if origin == 'roi' else rgb_face
        bound_h = roi_h if origin == 'roi' else H

        # Only the left-most box is cropped.
        ex1, ey1, ex2, ey2 = self._expand(boxes[0], self.eye_margin, W, bound_h)

        # Pad the shorter side towards a square, staying inside the bounds.
        ew = ex2 - ex1
        eh = ey2 - ey1
        if ew > eh:
            diff = ew - eh
            ey1 = max(0, ey1 - diff // 2)
            ey2 = min(bound_h, ey2 + (diff - diff // 2))
        elif eh > ew:
            diff = eh - ew
            ex1 = max(0, ex1 - diff // 2)
            ex2 = min(W, ex2 + (diff - diff // 2))

        crop = src_img[ey1:ey2, ex1:ex2]
        if crop.size == 0 or min(crop.shape[0], crop.shape[1]) < self.eye_min_size:
            return None
        return Image.fromarray(crop)
|
|
|
|
|
|
|
|
class StyleEmbedderApp: |
|
|
"""Web UI app - Lazy loading for Spaces compatibility""" |
|
|
|
|
|
def __init__(
    self,
    checkpoint_path: str,
    embeddings_path: str,
    device: str = 'cuda',
    yolo_dir: Optional[str] = None,
    yolo_weights: Optional[str] = None,
    eyes_cascade: Optional[str] = None,
    detector_device: str = 'cpu',
):
    """Record configuration; the model, embeddings and extractor load lazily.

    Args:
        checkpoint_path: Model checkpoint file.
        embeddings_path: Precomputed artist embeddings file.
        device: Device requested for the style model.
        yolo_dir: Optional override for the YOLOv5 code directory.
        yolo_weights: Optional override for the YOLO weights file.
        eyes_cascade: Optional override for the eye cascade XML.
        detector_device: Device string stored for the detector.
    """
    # Paths and device preferences.
    self.checkpoint_path, self.embeddings_path = checkpoint_path, embeddings_path
    self.requested_device, self.detector_device = device, detector_device

    # Extractor overrides; the extractor itself is built on first use.
    self._extractor_yolo_dir = yolo_dir
    self._extractor_yolo_weights = yolo_weights
    self._extractor_eyes_cascade = eyes_cascade
    self._extractor = None

    # Lazily loaded model state.
    self._model = None
    self._model_loading = False

    # Lazily loaded embedding table.
    self._embeddings = None
    self._artist_names = None
    self._embeddings_loaded = False

    # Resize to 224x224, convert to tensor, normalise per channel.
    pipeline = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    self.transform = transforms.Compose(pipeline)
|
|
|
|
|
def _ensure_model_loaded(self): |
|
|
"""Lazy load model - only called inside @spaces.GPU decorated function""" |
|
|
if self._model is not None: |
|
|
return |
|
|
|
|
|
|
|
|
if self._model_loading: |
|
|
|
|
|
import time |
|
|
while self._model_loading and self._model is None: |
|
|
time.sleep(0.01) |
|
|
return |
|
|
|
|
|
if self._model is not None: |
|
|
return |
|
|
|
|
|
self._model_loading = True |
|
|
try: |
|
|
print("Loading model (lazy)...") |
|
|
|
|
|
checkpoint = torch.load(self.checkpoint_path, map_location='cpu') |
|
|
config = get_config() |
|
|
|
|
|
self._model = ArtistStyleModel( |
|
|
num_classes=len(checkpoint['artist_to_idx']), |
|
|
embedding_dim=config.model.embedding_dim, |
|
|
hidden_dim=config.model.hidden_dim, |
|
|
) |
|
|
self._model.load_state_dict(checkpoint['model_state_dict']) |
|
|
|
|
|
|
|
|
if self.requested_device.startswith('cuda') and torch.cuda.is_available(): |
|
|
device = torch.device(self.requested_device) |
|
|
|
|
|
self._model = self._model.to(dtype=torch.float16) |
|
|
else: |
|
|
device = torch.device('cpu') |
|
|
|
|
|
self._model = self._model.to(device) |
|
|
self._model.eval() |
|
|
self.device = device |
|
|
self.embedding_dim = config.model.embedding_dim |
|
|
|
|
|
print("Model loaded successfully") |
|
|
finally: |
|
|
self._model_loading = False |
|
|
|
|
|
def _ensure_embeddings_loaded(self): |
|
|
"""Lazy load embeddings - no CUDA needed""" |
|
|
if self._embeddings_loaded: |
|
|
return |
|
|
|
|
|
|
|
|
if self._embeddings_loaded: |
|
|
return |
|
|
|
|
|
print("Loading embeddings...") |
|
|
data = np.load(self.embeddings_path) |
|
|
self._artist_names = data['artist_names'].tolist() |
|
|
self._embeddings = data['embeddings'] |
|
|
self._embeddings_loaded = True |
|
|
print(f"Loaded {len(self._artist_names)} artist embeddings") |
|
|
|
|
|
def preprocess_image(self, image: Optional[Image.Image]) -> Optional[torch.Tensor]:
    """Convert an image to RGB and apply ``self.transform``.

    Palette images are converted to RGBA first; images with an alpha channel
    (``RGBA`` / ``LA``) are pasted onto a white background using that channel
    as the mask.

    Returns:
        The transformed tensor, or ``None`` when *image* is ``None`` or
        preprocessing raises an ordinary exception.
    """
    if image is None:
        return None

    try:
        if image.mode == 'P':
            image = image.convert('RGBA')

        if image.mode in ('RGBA', 'LA'):
            background = Image.new('RGB', image.size, (255, 255, 255))
            background.paste(image, mask=image.split()[-1])
            image = background
        else:
            image = image.convert('RGB')

        return self.transform(image)
    except Exception:
        # Best effort: an unreadable image counts as "no image". Unlike a
        # bare ``except``, KeyboardInterrupt / SystemExit still propagate.
        return None
|
|
|
|
|
@spaces.GPU
@torch.no_grad()
def get_embedding(
    self,
    full_image: Image.Image,
    face_image: Optional[Image.Image] = None,
    eye_image: Optional[Image.Image] = None,
) -> np.ndarray:
    """Compute the style embedding for an image (model is loaded lazily).

    Missing face / eye crops are extracted automatically on a best-effort
    basis; if extraction fails a warning is printed and the missing view is
    replaced by a zero tensor with its presence flag set to ``False``.

    Raises:
        ValueError: if the full image cannot be preprocessed.
    """
    self._ensure_model_loaded()

    full_tensor = self.preprocess_image(full_image)
    if full_tensor is None:
        raise ValueError("Full image is required")
    full = full_tensor.unsqueeze(0).to(self.device)

    auto_face_image, auto_eye_image = face_image, eye_image
    if auto_face_image is None or auto_eye_image is None:
        try:
            extractor = self._get_extractor()
            if auto_face_image is None:
                auto_face_image = extractor.extract_face(full_image)
            # The eye is searched inside the face crop, so a face is required.
            if auto_eye_image is None and auto_face_image is not None:
                auto_eye_image = extractor.extract_eye_region(auto_face_image)
        except Exception as e:
            print(f"[WARN] Auto face/eye extraction failed: {e}")

    def _as_batch(view_image):
        # Returns (1x3x224x224 batch, 1-element presence flag) on self.device.
        tensor = self.preprocess_image(view_image)
        present = tensor is not None
        if present:
            batch = tensor.unsqueeze(0).to(self.device)
        else:
            batch = torch.zeros(1, 3, 224, 224).to(self.device)
        return batch, torch.tensor([present]).to(self.device)

    face, has_face = _as_batch(auto_face_image)
    eye, has_eye = _as_batch(auto_eye_image)

    use_amp = self.device.type == 'cuda'
    with torch.cuda.amp.autocast(enabled=use_amp):
        embedding = self._model.get_embeddings(full, face, eye, has_face, has_eye)

    result = embedding.squeeze(0).float().cpu()
    return result.numpy()
|
|
|
|
|
def find_similar_artists(
    self,
    query_embedding: np.ndarray,
    top_k: int = 10,
) -> List[Tuple[str, float]]:
    """Return the ``top_k`` artists ranked by cosine similarity to the query.

    Args:
        query_embedding: 1-D query vector.
        top_k: Number of results; coerced to ``int`` because UI sliders may
            deliver a float. Values ``<= 0`` give an empty list.

    Returns:
        ``(artist_name, similarity)`` pairs, most similar first.
    """
    self._ensure_embeddings_loaded()

    top_k = int(top_k)
    if top_k <= 0:
        return []

    # Zero-length vectors would divide by zero and produce NaN, which then
    # sorts to the top of the reversed argsort; clamp the norms instead.
    eps = 1e-12
    query = np.asarray(query_embedding)
    query_norm = query / max(float(np.linalg.norm(query)), eps)
    row_norms = np.linalg.norm(self._embeddings, axis=1, keepdims=True)
    embeddings_norm = self._embeddings / np.maximum(row_norms, eps)
    similarities = embeddings_norm @ query_norm

    top_indices = np.argsort(similarities)[::-1][:top_k]
    return [(self._artist_names[i], float(similarities[i])) for i in top_indices]
|
|
|
|
|
def _get_extractor(self): |
|
|
"""Lazy load extractor to avoid pickle issues""" |
|
|
if self._extractor is None: |
|
|
self._extractor = FaceEyeExtractor( |
|
|
yolo_dir=_default_path('yolov5_anime') if self._extractor_yolo_dir is None else Path(self._extractor_yolo_dir), |
|
|
weights_path=_default_path('yolov5x_anime.pt') if self._extractor_yolo_weights is None else Path(self._extractor_yolo_weights), |
|
|
cascade_path=_default_path('anime-eyes-cascade.xml') if self._extractor_eyes_cascade is None else Path(self._extractor_eyes_cascade), |
|
|
device='cpu', |
|
|
) |
|
|
return self._extractor |
|
|
|
|
|
def _extract_crops_impl(self, full_image: Image.Image) -> Tuple[Optional[Image.Image], Optional[Image.Image], str]: |
|
|
"""μΌκ΅΄κ³Ό λ μλ μΆμΆ - λ΄λΆ ꡬν""" |
|
|
if full_image is None: |
|
|
return None, None, "β μ 체 μ΄λ―Έμ§λ₯Ό λ¨Όμ μ
λ‘λν΄μ£ΌμΈμ." |
|
|
|
|
|
try: |
|
|
extractor = self._get_extractor() |
|
|
face = extractor.extract_face(full_image) |
|
|
eye = None |
|
|
if face is not None: |
|
|
eye = extractor.extract_eye_region(face) |
|
|
|
|
|
status = "β
μΆμΆ μλ£:\n" |
|
|
status += f"- μΌκ΅΄: {'λ°κ²¬λ¨' if face else 'λ°κ²¬λμ§ μμ'}\n" |
|
|
status += f"- λ: {'λ°κ²¬λ¨' if eye else 'λ°κ²¬λμ§ μμ'}\n\n" |
|
|
if face is None: |
|
|
status += "π‘ μΌκ΅΄μ΄ κ°μ§λμ§ μμμ΅λλ€. μλμΌλ‘ μ
λ‘λν΄μ£ΌμΈμ." |
|
|
elif eye is None: |
|
|
status += "π‘ λμ΄ κ°μ§λμ§ μμμ΅λλ€. μλμΌλ‘ μ
λ‘λν΄μ£ΌμΈμ." |
|
|
|
|
|
return face, eye, status |
|
|
except Exception as e: |
|
|
return None, None, f"β μΆμΆ μ€ν¨: {str(e)}" |
|
|
|
|
|
def extract_crops(self, full_image: Image.Image) -> Tuple[Optional[Image.Image], Optional[Image.Image], str]:
    """Auto-extract face and eye crops; thin wrapper used as the Gradio callback."""
    face, eye, status = self._extract_crops_impl(full_image)
    return face, eye, status
|
|
|
|
|
def search(
    self,
    full_image: Image.Image,
    face_image: Optional[Image.Image],
    eye_image: Optional[Image.Image],
    top_k: int,
) -> str:
    """Run a similarity search and return the result as a Markdown table.

    The user-facing Korean strings and emoji were mis-encoded in this file;
    they are restored here to the text the garbled bytes decode to.
    NOTE(review): confirm the restored wording and the two bar glyphs against
    the upstream source.

    Returns:
        Markdown with one row per result, or an error message string. Errors
        are reported in the returned text rather than raised.
    """
    if full_image is None:
        return "❌ 전체 이미지를 업로드해주세요."

    try:
        # A missing crop means get_embedding will try automatic extraction.
        auto_extracted = face_image is None or eye_image is None

        embedding = self.get_embedding(full_image, face_image, eye_image)
        results = self.find_similar_artists(embedding, top_k=top_k)

        parts = ["## 🎨 검색 결과\n\n"]
        if auto_extracted:
            parts.append("_ℹ️ 얼굴/눈이 업로드되지 않아 자동 추출을 시도했습니다._\n\n")

        parts.append("| 순위 | 작가 | 유사도 |\n")
        parts.append("|:----:|:-----|:------:|\n")

        bar_width = 20
        for i, (name, score) in enumerate(results, 1):
            # Clamp so a negative (or >1) score keeps the bar at 20 cells.
            filled = max(0, min(bar_width, int(score * bar_width)))
            bar = "█" * filled + "░" * (bar_width - filled)
            parts.append(f"| {i} | **{name}** | {score:.4f} {bar} |\n")

        return "".join(parts)

    except Exception as e:
        # UI boundary: show the failure in the output panel.
        return f"❌ 오류 발생: {str(e)}"
|
|
|
|
|
def create_ui(self):
    """Build and return the Gradio ``Blocks`` UI.

    Layout: a required full-image input, an auto-extract button with status
    text, optional face / eye inputs, a result-count slider and a search
    button on the left; the Markdown result on the right.

    The Korean labels and Markdown were mis-encoded in this file (including
    line breaks injected inside string literals); they are restored here to
    the text the garbled bytes decode to.
    NOTE(review): confirm wording, emoji and the nested-list indentation of
    the help text against the upstream source.
    """
    gr = _import_gradio()

    with gr.Blocks(title="Three-View-Style-Embedder", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🎨 Three-View-Style-Embedder

        일러스트 이미지를 업로드하면 가장 유사한 스타일의 작가를 찾아드립니다.

        - **전체 이미지**: 필수 (작품 전체)
        - **얼굴/눈 이미지**: 선택 (자동 추출되거나 수동 업로드)

        💡 **얼굴/눈을 업로드하지 않으면 자동으로 감지하여 추출합니다!**
        """)

        with gr.Row():
            with gr.Column(scale=1):
                full_input = gr.Image(
                    label="전체 이미지 (필수)",
                    type="pil",
                    height=256,
                )

                extract_btn = gr.Button("✂️ 얼굴/눈 자동 추출", variant="secondary")
                extract_status = gr.Markdown(value="")

                with gr.Row():
                    face_input = gr.Image(
                        label="얼굴 (선택 - 자동추출 가능)",
                        type="pil",
                        height=128,
                    )
                    eye_input = gr.Image(
                        label="눈 (선택 - 자동추출 가능)",
                        type="pil",
                        height=128,
                    )

                top_k = gr.Slider(
                    minimum=5,
                    maximum=50,
                    value=10,
                    step=5,
                    label="검색 결과 수",
                )

                search_btn = gr.Button("🔍 검색", variant="primary", size="lg")

            with gr.Column(scale=1):
                output = gr.Markdown(
                    value="이미지를 업로드하고 검색 버튼을 눌러주세요.",
                    label="결과",
                )

        # Auto-extraction fills the optional face / eye inputs and the status.
        extract_btn.click(
            fn=self.extract_crops,
            inputs=[full_input],
            outputs=[face_input, eye_input, extract_status],
        )

        search_btn.click(
            fn=self.search,
            inputs=[full_input, face_input, eye_input, top_k],
            outputs=output,
        )

        gr.Markdown("""
        ---
        ### 💡 사용 방법
        1. **전체 이미지**를 업로드
        2. **[✂️ 얼굴/눈 자동 추출]** 버튼을 클릭 (선택사항)
           - 또는 직접 얼굴/눈 이미지를 업로드
           - 아무것도 하지 않으면 검색 시 자동으로 추출됩니다
        3. **[🔍 검색]** 버튼을 클릭하여 유사 작가 찾기

        ### 💡 팁
        - 얼굴/눈을 수동으로 업로드하면 더 정확한 결과를 얻을 수 있습니다
        - 유사도 1.0에 가까울수록 스타일이 비슷합니다
        """)

    return demo
|
|
|
|
|
|