iljung1106 commited on
Commit
38ad444
·
1 Parent(s): b1b0bc5

Add a list of artists and make the eyes view capture only one eye, using a square crop.

Browse files
Files changed (2) hide show
  1. app/view_extractor.py +87 -15
  2. webui_gradio.py +35 -0
app/view_extractor.py CHANGED
@@ -70,6 +70,69 @@ def _shrink(img: np.ndarray, limit: int):
70
  return small, s
71
 
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def _best_pair(boxes, W: int, H: int):
74
  clean = [(int(b[0]), int(b[1]), int(b[2]), int(b[3])) for b in boxes]
75
  if len(clean) < 2:
@@ -268,7 +331,7 @@ class AnimeFaceEyeExtractor:
268
  Args:
269
  whole_rgb: HWC RGB uint8
270
  Returns:
271
- (face_rgb, eyes_rgb) as RGB uint8 crops (or None if not found)
272
  """
273
  import cv2
274
 
@@ -324,22 +387,31 @@ class AnimeFaceEyeExtractor:
324
  labs = [("left", cand[0])]
325
  origin = cand_origin
326
 
327
- eyes_crop = None
328
  if labs:
329
  src_img = roi if origin == "roi" else face
330
  bound_h = roi.shape[0] if origin == "roi" else H
331
-
332
- boxes_only = [b for _, b in labs]
333
- # union of eye boxes -> single eyes crop (works for the "eyes" view encoder)
334
- ux1 = min(b[0] for b in boxes_only)
335
- uy1 = min(b[1] for b in boxes_only)
336
- ux2 = max(b[2] for b in boxes_only)
337
- uy2 = max(b[3] for b in boxes_only)
338
- ex1, ey1, ex2, ey2 = _expand((ux1, uy1, ux2, uy2), float(self.cfg.eye_margin), W, bound_h)
339
- crop = src_img[ey1:ey2, ex1:ex2]
340
- if crop.size > 0 and min(crop.shape[0], crop.shape[1]) >= int(self.cfg.eye_min_size):
341
- eyes_crop = crop.copy()
342
-
343
- return face, eyes_crop
 
 
 
 
 
 
 
 
 
344
 
345
 
 
70
  return small, s
71
 
72
 
73
+ def _pad_to_square_rgb(img: np.ndarray) -> np.ndarray:
74
+ """
75
+ Pad an RGB crop to a square (1:1) using edge-padding.
76
+ This guarantees 1:1 aspect ratio without stretching content.
77
+ """
78
+ if img is None or img.size == 0:
79
+ return img
80
+ h, w = img.shape[:2]
81
+ if h == w:
82
+ return img
83
+ s = max(h, w)
84
+ pad_y = s - h
85
+ pad_x = s - w
86
+ top = pad_y // 2
87
+ bottom = pad_y - top
88
+ left = pad_x // 2
89
+ right = pad_x - left
90
+ return np.pad(img, ((top, bottom), (left, right), (0, 0)), mode="edge")
91
+
92
+
93
+ def _square_box_from_rect(rect, *, scale: float, W: int, H: int):
94
+ """
95
+ Convert a rectangle (x1,y1,x2,y2) into a square box centered on the rect,
96
+ scaled by `scale`, clamped to image bounds.
97
+ """
98
+ x1, y1, x2, y2 = [int(v) for v in rect]
99
+ cx = (x1 + x2) / 2.0
100
+ cy = (y1 + y2) / 2.0
101
+ bw = max(1.0, float(x2 - x1))
102
+ bh = max(1.0, float(y2 - y1))
103
+ side = max(bw, bh) * float(scale)
104
+ nx1 = int(round(cx - side / 2.0))
105
+ ny1 = int(round(cy - side / 2.0))
106
+ nx2 = int(round(cx + side / 2.0))
107
+ ny2 = int(round(cy + side / 2.0))
108
+ nx1 = max(0, min(W, nx1))
109
+ ny1 = max(0, min(H, ny1))
110
+ nx2 = max(0, min(W, nx2))
111
+ ny2 = max(0, min(H, ny2))
112
+ if nx2 <= nx1 or ny2 <= ny1:
113
+ return None
114
+ return nx1, ny1, nx2, ny2
115
+
116
+
117
+ def _split_box_by_midline(box, mid_x: int):
118
+ """
119
+ If a box crosses the vertical midline, split into left/right boxes.
120
+ Returns list of (tag, box).
121
+ """
122
+ x1, y1, x2, y2 = [int(v) for v in box]
123
+ if x1 < mid_x < x2:
124
+ left = (x1, y1, mid_x, y2)
125
+ right = (mid_x, y1, x2, y2)
126
+ out = []
127
+ if left[2] > left[0]:
128
+ out.append(("left", left))
129
+ if right[2] > right[0]:
130
+ out.append(("right", right))
131
+ return out
132
+ tag = "left" if (x1 + x2) / 2.0 <= mid_x else "right"
133
+ return [(tag, (x1, y1, x2, y2))]
134
+
135
+
136
  def _best_pair(boxes, W: int, H: int):
137
  clean = [(int(b[0]), int(b[1]), int(b[2]), int(b[3])) for b in boxes]
138
  if len(clean) < 2:
 
331
  Args:
332
  whole_rgb: HWC RGB uint8
333
  Returns:
334
+ (face_rgb, eye_rgb) as RGB uint8 crops (or None if not found)
335
  """
336
  import cv2
337
 
 
387
  labs = [("left", cand[0])]
388
  origin = cand_origin
389
 
390
+ eye_crop = None
391
  if labs:
392
  src_img = roi if origin == "roi" else face
393
  bound_h = roi.shape[0] if origin == "roi" else H
394
+ mid_x = int(round(W / 2.0))
395
+
396
+ # Build candidate eye boxes; split any box that crosses the midline
397
+ candidates = []
398
+ for tag, b in labs:
399
+ candidates.extend(_split_box_by_midline(b, mid_x))
400
+
401
+ # Deterministically choose the LEFT eye if present; otherwise fall back to largest
402
+ left_boxes = [b for (t, b) in candidates if t == "left"]
403
+ pick_from = left_boxes if left_boxes else [b for (_, b) in candidates]
404
+ chosen = max(pick_from, key=lambda bb: max(1, (bb[2] - bb[0]) * (bb[3] - bb[1])))
405
+
406
+ # Square crop around the chosen eye (no stretching); pad to square to guarantee 1:1.
407
+ scale = 1.0 + float(self.cfg.eye_margin)
408
+ sq = _square_box_from_rect(chosen, scale=scale, W=W, H=bound_h)
409
+ if sq is not None:
410
+ ex1, ey1, ex2, ey2 = sq
411
+ crop = src_img[ey1:ey2, ex1:ex2]
412
+ if crop.size > 0 and min(crop.shape[0], crop.shape[1]) >= int(self.cfg.eye_min_size):
413
+ eye_crop = _pad_to_square_rgb(crop.copy())
414
+
415
+ return face, eye_crop
416
 
417
 
webui_gradio.py CHANGED
@@ -339,6 +339,30 @@ def classify_and_analyze(
339
  return (f"❌ Failed: {e}",) + empty_result[1:]
340
 
341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  def _gallery_item_to_pil(item) -> Optional[Image.Image]:
343
  """Convert a Gradio gallery item to PIL Image (handles various formats)."""
344
  if item is None:
@@ -600,6 +624,17 @@ def build_ui() -> gr.Blocks:
600
  uploader.change(_files_to_gallery, inputs=[uploader], outputs=[imgs])
601
  add_btn.click(add_prototype, inputs=[label, imgs, k_proto, n_trips], outputs=[add_status])
602
 
 
 
 
 
 
 
 
 
 
 
 
603
  return demo
604
 
605
 
 
339
  return (f"❌ Failed: {e}",) + empty_result[1:]
340
 
341
 
342
+ def list_artists_in_db():
343
+ """
344
+ List all artists present in the currently loaded prototype DB.
345
+ Returns: status, rows [artist, prototype_count]
346
+ """
347
+ if APP_STATE.db is None:
348
+ return "❌ Click **Load** first.", []
349
+
350
+ db = APP_STATE.db
351
+ # Count prototypes per label id
352
+ counts: dict[int, int] = {}
353
+ for lid in db.labels.detach().cpu().tolist():
354
+ counts[int(lid)] = counts.get(int(lid), 0) + 1
355
+
356
+ rows: list[list] = []
357
+ for lid, name in enumerate(db.label_names):
358
+ c = int(counts.get(int(lid), 0))
359
+ if c > 0:
360
+ rows.append([name, c])
361
+
362
+ rows.sort(key=lambda r: (-int(r[1]), str(r[0]).lower()))
363
+ return f"✅ {len(rows)} artists in DB (total prototypes: {int(db.centers.shape[0])}).", rows
364
+
365
+
366
  def _gallery_item_to_pil(item) -> Optional[Image.Image]:
367
  """Convert a Gradio gallery item to PIL Image (handles various formats)."""
368
  if item is None:
 
624
  uploader.change(_files_to_gallery, inputs=[uploader], outputs=[imgs])
625
  add_btn.click(add_prototype, inputs=[label, imgs, k_proto, n_trips], outputs=[add_status])
626
 
627
+ with gr.Tab("Artists (in DB)"):
628
+ gr.Markdown(
629
+ "### Artists in Prototype DB\n"
630
+ "Shows which artist labels exist in the currently loaded prototype database "
631
+ "(including any temporary prototypes added in this session)."
632
+ )
633
+ refresh_artists = gr.Button("Refresh", variant="secondary")
634
+ artists_status = gr.Markdown("")
635
+ artists_table = gr.Dataframe(headers=["Artist", "#Prototypes"], datatype=["str", "number"], interactive=False)
636
+ refresh_artists.click(list_artists_in_db, inputs=[], outputs=[artists_status, artists_table])
637
+
638
  return demo
639
 
640