Spaces:
Sleeping
Sleeping
iljung1106
commited on
Commit
·
38ad444
1
Parent(s):
b1b0bc5
Add list of artists and make eyes only capture one eye with square.
Browse files- app/view_extractor.py +87 -15
- webui_gradio.py +35 -0
app/view_extractor.py
CHANGED
|
@@ -70,6 +70,69 @@ def _shrink(img: np.ndarray, limit: int):
|
|
| 70 |
return small, s
|
| 71 |
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
def _best_pair(boxes, W: int, H: int):
|
| 74 |
clean = [(int(b[0]), int(b[1]), int(b[2]), int(b[3])) for b in boxes]
|
| 75 |
if len(clean) < 2:
|
|
@@ -268,7 +331,7 @@ class AnimeFaceEyeExtractor:
|
|
| 268 |
Args:
|
| 269 |
whole_rgb: HWC RGB uint8
|
| 270 |
Returns:
|
| 271 |
-
(face_rgb,
|
| 272 |
"""
|
| 273 |
import cv2
|
| 274 |
|
|
@@ -324,22 +387,31 @@ class AnimeFaceEyeExtractor:
|
|
| 324 |
labs = [("left", cand[0])]
|
| 325 |
origin = cand_origin
|
| 326 |
|
| 327 |
-
|
| 328 |
if labs:
|
| 329 |
src_img = roi if origin == "roi" else face
|
| 330 |
bound_h = roi.shape[0] if origin == "roi" else H
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
#
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
if
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
|
| 345 |
|
|
|
|
| 70 |
return small, s
|
| 71 |
|
| 72 |
|
| 73 |
+
def _pad_to_square_rgb(img: np.ndarray) -> np.ndarray:
|
| 74 |
+
"""
|
| 75 |
+
Pad an RGB crop to a square (1:1) using edge-padding.
|
| 76 |
+
This guarantees 1:1 aspect ratio without stretching content.
|
| 77 |
+
"""
|
| 78 |
+
if img is None or img.size == 0:
|
| 79 |
+
return img
|
| 80 |
+
h, w = img.shape[:2]
|
| 81 |
+
if h == w:
|
| 82 |
+
return img
|
| 83 |
+
s = max(h, w)
|
| 84 |
+
pad_y = s - h
|
| 85 |
+
pad_x = s - w
|
| 86 |
+
top = pad_y // 2
|
| 87 |
+
bottom = pad_y - top
|
| 88 |
+
left = pad_x // 2
|
| 89 |
+
right = pad_x - left
|
| 90 |
+
return np.pad(img, ((top, bottom), (left, right), (0, 0)), mode="edge")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _square_box_from_rect(rect, *, scale: float, W: int, H: int):
|
| 94 |
+
"""
|
| 95 |
+
Convert a rectangle (x1,y1,x2,y2) into a square box centered on the rect,
|
| 96 |
+
scaled by `scale`, clamped to image bounds.
|
| 97 |
+
"""
|
| 98 |
+
x1, y1, x2, y2 = [int(v) for v in rect]
|
| 99 |
+
cx = (x1 + x2) / 2.0
|
| 100 |
+
cy = (y1 + y2) / 2.0
|
| 101 |
+
bw = max(1.0, float(x2 - x1))
|
| 102 |
+
bh = max(1.0, float(y2 - y1))
|
| 103 |
+
side = max(bw, bh) * float(scale)
|
| 104 |
+
nx1 = int(round(cx - side / 2.0))
|
| 105 |
+
ny1 = int(round(cy - side / 2.0))
|
| 106 |
+
nx2 = int(round(cx + side / 2.0))
|
| 107 |
+
ny2 = int(round(cy + side / 2.0))
|
| 108 |
+
nx1 = max(0, min(W, nx1))
|
| 109 |
+
ny1 = max(0, min(H, ny1))
|
| 110 |
+
nx2 = max(0, min(W, nx2))
|
| 111 |
+
ny2 = max(0, min(H, ny2))
|
| 112 |
+
if nx2 <= nx1 or ny2 <= ny1:
|
| 113 |
+
return None
|
| 114 |
+
return nx1, ny1, nx2, ny2
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _split_box_by_midline(box, mid_x: int):
|
| 118 |
+
"""
|
| 119 |
+
If a box crosses the vertical midline, split into left/right boxes.
|
| 120 |
+
Returns list of (tag, box).
|
| 121 |
+
"""
|
| 122 |
+
x1, y1, x2, y2 = [int(v) for v in box]
|
| 123 |
+
if x1 < mid_x < x2:
|
| 124 |
+
left = (x1, y1, mid_x, y2)
|
| 125 |
+
right = (mid_x, y1, x2, y2)
|
| 126 |
+
out = []
|
| 127 |
+
if left[2] > left[0]:
|
| 128 |
+
out.append(("left", left))
|
| 129 |
+
if right[2] > right[0]:
|
| 130 |
+
out.append(("right", right))
|
| 131 |
+
return out
|
| 132 |
+
tag = "left" if (x1 + x2) / 2.0 <= mid_x else "right"
|
| 133 |
+
return [(tag, (x1, y1, x2, y2))]
|
| 134 |
+
|
| 135 |
+
|
| 136 |
def _best_pair(boxes, W: int, H: int):
|
| 137 |
clean = [(int(b[0]), int(b[1]), int(b[2]), int(b[3])) for b in boxes]
|
| 138 |
if len(clean) < 2:
|
|
|
|
| 331 |
Args:
|
| 332 |
whole_rgb: HWC RGB uint8
|
| 333 |
Returns:
|
| 334 |
+
(face_rgb, eye_rgb) as RGB uint8 crops (or None if not found)
|
| 335 |
"""
|
| 336 |
import cv2
|
| 337 |
|
|
|
|
| 387 |
labs = [("left", cand[0])]
|
| 388 |
origin = cand_origin
|
| 389 |
|
| 390 |
+
eye_crop = None
|
| 391 |
if labs:
|
| 392 |
src_img = roi if origin == "roi" else face
|
| 393 |
bound_h = roi.shape[0] if origin == "roi" else H
|
| 394 |
+
mid_x = int(round(W / 2.0))
|
| 395 |
+
|
| 396 |
+
# Build candidate eye boxes; split any box that crosses the midline
|
| 397 |
+
candidates = []
|
| 398 |
+
for tag, b in labs:
|
| 399 |
+
candidates.extend(_split_box_by_midline(b, mid_x))
|
| 400 |
+
|
| 401 |
+
# Deterministically choose the LEFT eye if present; otherwise fall back to largest
|
| 402 |
+
left_boxes = [b for (t, b) in candidates if t == "left"]
|
| 403 |
+
pick_from = left_boxes if left_boxes else [b for (_, b) in candidates]
|
| 404 |
+
chosen = max(pick_from, key=lambda bb: max(1, (bb[2] - bb[0]) * (bb[3] - bb[1])))
|
| 405 |
+
|
| 406 |
+
# Square crop around the chosen eye (no stretching); pad to square to guarantee 1:1.
|
| 407 |
+
scale = 1.0 + float(self.cfg.eye_margin)
|
| 408 |
+
sq = _square_box_from_rect(chosen, scale=scale, W=W, H=bound_h)
|
| 409 |
+
if sq is not None:
|
| 410 |
+
ex1, ey1, ex2, ey2 = sq
|
| 411 |
+
crop = src_img[ey1:ey2, ex1:ex2]
|
| 412 |
+
if crop.size > 0 and min(crop.shape[0], crop.shape[1]) >= int(self.cfg.eye_min_size):
|
| 413 |
+
eye_crop = _pad_to_square_rgb(crop.copy())
|
| 414 |
+
|
| 415 |
+
return face, eye_crop
|
| 416 |
|
| 417 |
|
webui_gradio.py
CHANGED
|
@@ -339,6 +339,30 @@ def classify_and_analyze(
|
|
| 339 |
return (f"❌ Failed: {e}",) + empty_result[1:]
|
| 340 |
|
| 341 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
def _gallery_item_to_pil(item) -> Optional[Image.Image]:
|
| 343 |
"""Convert a Gradio gallery item to PIL Image (handles various formats)."""
|
| 344 |
if item is None:
|
|
@@ -600,6 +624,17 @@ def build_ui() -> gr.Blocks:
|
|
| 600 |
uploader.change(_files_to_gallery, inputs=[uploader], outputs=[imgs])
|
| 601 |
add_btn.click(add_prototype, inputs=[label, imgs, k_proto, n_trips], outputs=[add_status])
|
| 602 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
return demo
|
| 604 |
|
| 605 |
|
|
|
|
| 339 |
return (f"❌ Failed: {e}",) + empty_result[1:]
|
| 340 |
|
| 341 |
|
| 342 |
+
def list_artists_in_db():
|
| 343 |
+
"""
|
| 344 |
+
List all artists present in the currently loaded prototype DB.
|
| 345 |
+
Returns: status, rows [artist, prototype_count]
|
| 346 |
+
"""
|
| 347 |
+
if APP_STATE.db is None:
|
| 348 |
+
return "❌ Click **Load** first.", []
|
| 349 |
+
|
| 350 |
+
db = APP_STATE.db
|
| 351 |
+
# Count prototypes per label id
|
| 352 |
+
counts: dict[int, int] = {}
|
| 353 |
+
for lid in db.labels.detach().cpu().tolist():
|
| 354 |
+
counts[int(lid)] = counts.get(int(lid), 0) + 1
|
| 355 |
+
|
| 356 |
+
rows: list[list] = []
|
| 357 |
+
for lid, name in enumerate(db.label_names):
|
| 358 |
+
c = int(counts.get(int(lid), 0))
|
| 359 |
+
if c > 0:
|
| 360 |
+
rows.append([name, c])
|
| 361 |
+
|
| 362 |
+
rows.sort(key=lambda r: (-int(r[1]), str(r[0]).lower()))
|
| 363 |
+
return f"✅ {len(rows)} artists in DB (total prototypes: {int(db.centers.shape[0])}).", rows
|
| 364 |
+
|
| 365 |
+
|
| 366 |
def _gallery_item_to_pil(item) -> Optional[Image.Image]:
|
| 367 |
"""Convert a Gradio gallery item to PIL Image (handles various formats)."""
|
| 368 |
if item is None:
|
|
|
|
| 624 |
uploader.change(_files_to_gallery, inputs=[uploader], outputs=[imgs])
|
| 625 |
add_btn.click(add_prototype, inputs=[label, imgs, k_proto, n_trips], outputs=[add_status])
|
| 626 |
|
| 627 |
+
with gr.Tab("Artists (in DB)"):
|
| 628 |
+
gr.Markdown(
|
| 629 |
+
"### Artists in Prototype DB\n"
|
| 630 |
+
"Shows which artist labels exist in the currently loaded prototype database "
|
| 631 |
+
"(including any temporary prototypes added in this session)."
|
| 632 |
+
)
|
| 633 |
+
refresh_artists = gr.Button("Refresh", variant="secondary")
|
| 634 |
+
artists_status = gr.Markdown("")
|
| 635 |
+
artists_table = gr.Dataframe(headers=["Artist", "#Prototypes"], datatype=["str", "number"], interactive=False)
|
| 636 |
+
refresh_artists.click(list_artists_in_db, inputs=[], outputs=[artists_status, artists_table])
|
| 637 |
+
|
| 638 |
return demo
|
| 639 |
|
| 640 |
|