throaway2854 committed
Commit d254ba7 · verified · 1 Parent(s): d012bfe

Update app.py

Files changed (1): app.py (+192 -50)
app.py CHANGED
@@ -120,7 +120,7 @@ class VideoTagger:
     and exposes helpers to tag PIL images and full videos.
     """
 
-    def __init__(self, model_repo: str):
+    def __init__(self, model_repo: str, batch_size: int = 16):
         self.model_repo = model_repo
         self.model = None
         self.model_target_size = None  # will be set from ONNX input shape
@@ -128,6 +128,7 @@ class VideoTagger:
         self.rating_indexes = None
         self.general_indexes = None
         self.character_indexes = None
+        self.batch_size = batch_size
 
     def _download_model_files(self) -> Tuple[str, str]:
        csv_path = huggingface_hub.hf_hub_download(
@@ -202,6 +203,92 @@ class VideoTagger:
         arr = np.expand_dims(arr, axis=0)
         return arr
 
+    def _prepare_frame_bgr(self, frame_bgr: np.ndarray) -> np.ndarray:
+        """
+        Fast path for OpenCV frames (BGR uint8).
+        Pads to square, resizes to model_target_size, converts to float32.
+
+        Returns: (H, W, 3) float32 array in BGR format (no batch dim).
+        """
+        self._load_model_if_needed()
+        target_size = self.model_target_size
+
+        h, w, _ = frame_bgr.shape
+        max_dim = max(h, w)
+
+        # Compute symmetric padding to make it square
+        pad_vert = max_dim - h
+        pad_horiz = max_dim - w
+        top = pad_vert // 2
+        bottom = pad_vert - top
+        left = pad_horiz // 2
+        right = pad_horiz - left
+
+        # Pad with white background (255, 255, 255) in BGR
+        frame_square = cv2.copyMakeBorder(
+            frame_bgr,
+            top, bottom, left, right,
+            borderType=cv2.BORDER_CONSTANT,
+            value=(255, 255, 255),
+        )
+
+        # Resize if needed
+        if max_dim != target_size:
+            frame_square = cv2.resize(
+                frame_square,
+                (target_size, target_size),
+                interpolation=cv2.INTER_AREA,
+            )
+
+        # To float32, no color channel reordering needed (already BGR)
+        arr = frame_square.astype(np.float32)
+        return arr  # (H, W, 3)
+
+    def _run_batch_and_aggregate(
+        self,
+        batch_tensors: List[np.ndarray],
+        general_thresh: float,
+        character_thresh: float,
+        aggregated_general: Dict[str, float],
+        aggregated_character: Dict[str, float],
+    ) -> int:
+        """
+        Run ONNX inference on a batch of preprocessed frames and
+        update aggregated_general / aggregated_character with max scores.
+
+        Returns: number of frames processed in this batch.
+        """
+        if not batch_tensors:
+            return 0
+
+        self._load_model_if_needed()
+        input_name = self.model.get_inputs()[0].name
+        output_name = self.model.get_outputs()[0].name
+
+        # Stack into shape (B, H, W, 3)
+        input_tensor = np.stack(batch_tensors, axis=0)  # float32
+
+        preds_batch = self.model.run([output_name], {input_name: input_tensor})[0]
+        # preds_batch: (B, num_tags)
+
+        for preds in preds_batch:
+            general_res, character_res = self._extract_tags_from_scores(
+                preds,
+                general_thresh=general_thresh,
+                character_thresh=character_thresh,
+            )
+
+            # Aggregate max score for each tag
+            for tag, score in general_res.items():
+                if tag not in aggregated_general or score > aggregated_general[tag]:
+                    aggregated_general[tag] = score
+
+            for tag, score in character_res.items():
+                if tag not in aggregated_character or score > aggregated_character[tag]:
+                    aggregated_character[tag] = score
+
+        return len(batch_tensors)
+
     def tag_image(
         self,
         image: Image.Image,
@@ -225,6 +312,7 @@ class VideoTagger:
 
         labels = list(zip(self.tag_names, preds))
 
+
         # General tags
         general_names = [labels[i] for i in self.general_indexes]
         general_res = {
@@ -243,6 +331,40 @@ class VideoTagger:
 
         return general_res, character_res
 
+    def _extract_tags_from_scores(
+        self,
+        preds: np.ndarray,
+        general_thresh: float,
+        character_thresh: float,
+    ) -> Tuple[Dict[str, float], Dict[str, float]]:
+        """
+        Given a 1D preds array (num_tags,), return dicts of general/character tags.
+        More efficient than rebuilding label tuples every time.
+        """
+        # Ensure numpy array of floats
+        preds = preds.astype(float)
+
+        general_res: Dict[str, float] = {}
+        character_res: Dict[str, float] = {}
+
+        # General tags
+        general_scores = preds[self.general_indexes]
+        general_idx_array = np.array(self.general_indexes)
+        general_mask = general_scores > general_thresh
+        for idx, score in zip(general_idx_array[general_mask], general_scores[general_mask]):
+            tag = self.tag_names[idx]
+            general_res[tag] = float(score)
+
+        # Character tags
+        character_scores = preds[self.character_indexes]
+        character_idx_array = np.array(self.character_indexes)
+        character_mask = character_scores > character_thresh
+        for idx, score in zip(character_idx_array[character_mask], character_scores[character_mask]):
+            tag = self.tag_names[idx]
+            character_res[tag] = float(score)
+
+        return general_res, character_res
+
     def tag_video(
         self,
         video_path: str,
@@ -265,6 +387,8 @@ class VideoTagger:
 
         frame_interval = max(int(frame_interval), 1)
 
+        self._load_model_if_needed()
+
         if progress is not None:
             progress(0.0, desc="Opening video...")
 
@@ -272,20 +396,20 @@ class VideoTagger:
         if not cap.isOpened():
             raise RuntimeError("Unable to open video file.")
 
-        # Estimate total frames and how many will be processed
         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
         if total_frames <= 0:
-            total_frames = 1  # avoid division issues
+            total_frames = 1
 
         frames_to_process = max(1, (total_frames + frame_interval - 1) // frame_interval)
 
-        # Store max score seen for each tag across all frames
         aggregated_general: Dict[str, float] = {}
         aggregated_character: Dict[str, float] = {}
 
         frame_idx = 0
         processed_frames = 0
 
+        batch_tensors: List[np.ndarray] = []
+
         try:
             while True:
                 ret, frame = cap.read()
@@ -294,38 +418,44 @@ class VideoTagger:
 
                 # Only process every N-th frame
                 if frame_idx % frame_interval == 0:
-                    # Convert OpenCV BGR frame -> PIL image with alpha
-                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                    pil_image = Image.fromarray(frame_rgb).convert("RGBA")
-
-                    general_res, character_res = self.tag_image(
-                        pil_image,
-                        general_thresh=general_thresh,
-                        character_thresh=character_thresh,
-                    )
-
-                    # Aggregate by keeping max score per tag
-                    for tag, score in general_res.items():
-                        if tag not in aggregated_general or score > aggregated_general[tag]:
-                            aggregated_general[tag] = score
-
-                    for tag, score in character_res.items():
-                        if tag not in aggregated_character or score > aggregated_character[tag]:
-                            aggregated_character[tag] = score
-
-                    processed_frames += 1
-
-                    if progress is not None:
-                        ratio = min(processed_frames / frames_to_process, 0.99)
-                        progress(
-                            ratio,
-                            desc=f"Processing frame {processed_frames}/{frames_to_process}...",
+                    # frame is BGR uint8 from OpenCV
+                    arr = self._prepare_frame_bgr(frame)  # (H, W, 3) float32
+                    batch_tensors.append(arr)
+
+                    # If batch is full, run inference
+                    if len(batch_tensors) >= self.batch_size:
+                        num_done = self._run_batch_and_aggregate(
+                            batch_tensors,
+                            general_thresh=general_thresh,
+                            character_thresh=character_thresh,
+                            aggregated_general=aggregated_general,
+                            aggregated_character=aggregated_character,
                         )
+                        processed_frames += num_done
+                        batch_tensors = []
+
+                        if progress is not None:
+                            ratio = min(processed_frames / frames_to_process, 0.99)
+                            progress(
+                                ratio,
+                                desc=f"Processing frames {processed_frames}/{frames_to_process}...",
+                            )
 
                 frame_idx += 1
         finally:
             cap.release()
 
+        # Process any leftover frames in the last partial batch
+        if batch_tensors:
+            num_done = self._run_batch_and_aggregate(
+                batch_tensors,
+                general_thresh=general_thresh,
+                character_thresh=character_thresh,
+                aggregated_general=aggregated_general,
+                aggregated_character=aggregated_character,
+            )
+            processed_frames += num_done
+
         if progress is not None:
             progress(1.0, desc="Finalizing tags...")
 
@@ -335,29 +465,23 @@ class VideoTagger:
         # Apply substitutions & exclusions BEFORE final dedup
         adjusted_all_tags: Dict[str, float] = {}
 
-        # Normalize keys in substitutes/exclusions (strip whitespace)
         normalized_subs = {k.strip(): v.strip() for k, v in tag_substitutes.items() if k and v}
         normalized_exclusions = {t.strip() for t in tag_exclusions if t}
 
         for tag, score in all_tags_with_scores.items():
             original_tag = tag.strip()
 
-            # Skip if original tag is excluded
             if original_tag in normalized_exclusions:
                 continue
 
-            # Apply substitution (if any)
             new_tag = normalized_subs.get(original_tag, original_tag)
 
-            # Skip if substituted tag is excluded
             if new_tag in normalized_exclusions:
                 continue
 
-            # Keep max score for each resulting tag
             if new_tag not in adjusted_all_tags or score > adjusted_all_tags[new_tag]:
                 adjusted_all_tags[new_tag] = score
 
-        # Sort by score descending
         sorted_tags = sorted(
             adjusted_all_tags.items(),
             key=lambda kv: kv[1],
@@ -381,6 +505,7 @@ class VideoTagger:
             "character_threshold": float(character_thresh),
             "num_substitution_rules": len(normalized_subs),
             "num_exclusions": len(normalized_exclusions),
+            "batch_size": int(self.batch_size),
         }
 
         return combined_tags_str, debug_info
@@ -447,6 +572,7 @@ def tag_video_interface(
     model_repo: str,
    tag_substitutes_df,
    tag_exclusions_df,
+    batch_size: int,
    progress=gr.Progress(track_tqdm=False),
 ):
     if video_path is None:
@@ -454,6 +580,7 @@ def tag_video_interface(
 
     try:
         tagger = get_tagger(model_repo)
+        tagger.batch_size = int(batch_size)
 
         tag_substitutes = _normalize_tag_substitutes(tag_substitutes_df)
         tag_exclusions = _normalize_tag_exclusions(tag_exclusions_df)
@@ -485,22 +612,13 @@ with gr.Blocks(title=TITLE) as demo:
                 sources=["upload"],
                 format="mp4",
             )
-
+
             model_choice = gr.Dropdown(
                 choices=MODEL_OPTIONS,
                 value=DEFAULT_MODEL_REPO,
                 label="Tagging Model",
             )
 
-            frame_interval = gr.Slider(
-                minimum=1,
-                maximum=60,
-                step=1,
-                value=10,
-                label="Extract Every N Frames",
-                info="For example, 10 = use every 10th frame.",
-            )
-
             general_thresh = gr.Slider(
                 minimum=0.0,
                 maximum=1.0,
@@ -508,7 +626,7 @@ with gr.Blocks(title=TITLE) as demo:
                 value=0.35,
                 label="General Tags Threshold",
             )
-
+
             character_thresh = gr.Slider(
                 minimum=0.0,
                 maximum=1.0,
@@ -516,9 +634,32 @@ with gr.Blocks(title=TITLE) as demo:
                 value=0.85,
                 label="Character Tags Threshold",
             )
-
+
+            gr.Markdown("### Processing")
+
+            frame_interval = gr.Slider(
+                minimum=1,
+                maximum=60,
+                step=1,
+                value=10,
+                label="Extract Every N Frames",
+                info="For example, 10 = use every 10th frame.",
+            )
+
+            batch_size = gr.Slider(
+                minimum=1,
+                maximum=32,
+                step=1,
+                value=8,
+                label="Batch Size",
+                info=(
+                    "Larger batch sizes may increase initial loading time but can significantly "
+                    "improve total processing speed, especially for longer videos or high frame counts."
+                ),
+            )
+
             run_button = gr.Button("Generate Tags", variant="primary")
-
+
             with gr.Column():
                 combined_tags = gr.Textbox(
                     label="Combined Unique Tags (All Frames)",
@@ -529,6 +670,7 @@ with gr.Blocks(title=TITLE) as demo:
                     label="Details / Debug Info",
                 )
 
+
    # ---------------- TAB 2: TAG CONTROL ----------------
    with gr.Tab("Tag Control"):
        gr.Markdown("### Tag Substitutes")
@@ -582,7 +724,6 @@ with gr.Blocks(title=TITLE) as demo:
            )
 
 
-    # Wiring the button AFTER all components are defined
    run_button.click(
        fn=tag_video_interface,
        inputs=[
@@ -593,6 +734,7 @@ with gr.Blocks(title=TITLE) as demo:
            model_choice,
            tag_substitutes_df,
            tag_exclusions_df,
+            batch_size,
        ],
        outputs=[combined_tags, debug_info],
    )
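
A minimal sketch of driving the new batched path outside the Gradio UI. It is not part of the commit: get_tagger, DEFAULT_MODEL_REPO, and tag_video's keyword names are assumptions read off their call sites in the diff above, not a documented API.

    # Usage sketch only -- names and keyword arguments inferred from the diff.
    from app import get_tagger, DEFAULT_MODEL_REPO

    tagger = get_tagger(DEFAULT_MODEL_REPO)
    tagger.batch_size = 8  # the new slider's default; __init__ defaults to 16

    combined_tags_str, debug_info = tagger.tag_video(
        video_path="clip.mp4",     # hypothetical input file
        frame_interval=10,         # tag every 10th frame
        general_thresh=0.35,
        character_thresh=0.85,
        tag_substitutes={},        # no renames
        tag_exclusions=set(),      # nothing filtered out
        progress=None,             # tag_video guards every progress call with "is not None"
    )
    print(combined_tags_str)
    print(debug_info)

Frames accumulate in batch_tensors and only reach self.model.run(...) once per batch_size frames, so larger batches mean fewer ONNX calls per video; that is where this commit's speedup comes from.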
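
The masked-index pattern in _extract_tags_from_scores can also be sanity-checked in isolation; the tag names and indexes below are made up for illustration.

    import numpy as np

    # Hypothetical label space: two "general" tags at indexes 1 and 2.
    tag_names = ["rating:safe", "sky", "tree", "1girl"]
    general_indexes = [1, 2]
    preds = np.array([0.90, 0.70, 0.20, 0.95])

    scores = preds[general_indexes]    # array([0.7, 0.2])
    idx = np.array(general_indexes)
    mask = scores > 0.35               # array([ True, False])
    general_res = {tag_names[i]: float(s) for i, s in zip(idx[mask], scores[mask])}
    print(general_res)                 # {'sky': 0.7}

Thresholding stays vectorized and tag_names is only touched for scores that pass, instead of zipping every tag with every prediction per frame as tag_image does.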