throaway2854 committed on
Commit
b00ebae
·
verified ·
1 Parent(s): cd33eb4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -48
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from typing import Dict, Tuple
3
 
4
  import cv2
5
  import gradio as gr
@@ -17,6 +17,7 @@ combined (deduplicated) tags using a selected **WD-style tagging model**.
17
  - Extract every N-th frame (e.g., every 10th frame).
18
  - Control thresholds for **General Tags** and **Character Tags**.
19
  - All tags from all sampled frames are merged into **one unique, comma-separated string**.
 
20
  """
21
 
22
  DEFAULT_MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
@@ -218,6 +219,8 @@ class VideoTagger:
218
  frame_interval: int,
219
  general_thresh: float,
220
  character_thresh: float,
 
 
221
  progress=None,
222
  ) -> Tuple[str, Dict]:
223
  """
@@ -242,7 +245,7 @@ class VideoTagger:
242
  # Estimate total frames and how many will be processed
243
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
244
  if total_frames <= 0:
245
- total_frames = 1 # avoid division by zero / weird metadata
246
 
247
  frames_to_process = max(1, (total_frames + frame_interval - 1) // frame_interval)
248
 
@@ -298,8 +301,35 @@ class VideoTagger:
298
 
299
  # Merge character + general tags, sorted by score (desc)
300
  all_tags_with_scores = {**aggregated_general, **aggregated_character}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  sorted_tags = sorted(
302
- all_tags_with_scores.items(),
303
  key=lambda kv: kv[1],
304
  reverse=True,
305
  )
@@ -313,12 +343,14 @@ class VideoTagger:
313
  "frames_processed": int(processed_frames),
314
  "estimated_total_frames": int(total_frames),
315
  "estimated_frames_to_process": int(frames_to_process),
316
- "num_general_tags": len(aggregated_general),
317
- "num_character_tags": len(aggregated_character),
318
- "total_unique_tags": len(unique_tags),
319
  "frame_interval": int(frame_interval),
320
  "general_threshold": float(general_thresh),
321
  "character_threshold": float(character_thresh),
 
 
322
  }
323
 
324
  return combined_tags_str, debug_info
@@ -334,12 +366,57 @@ def get_tagger(model_repo: str) -> VideoTagger:
334
  return _tagger_cache[model_repo]
335
 
336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  def tag_video_interface(
338
  video_path: str,
339
  frame_interval: int,
340
  general_thresh: float,
341
  character_thresh: float,
342
  model_repo: str,
 
 
343
  progress=gr.Progress(track_tqdm=False),
344
  ):
345
  if video_path is None:
@@ -347,11 +424,17 @@ def tag_video_interface(
347
 
348
  try:
349
  tagger = get_tagger(model_repo)
 
 
 
 
350
  return tagger.tag_video(
351
  video_path=video_path,
352
  frame_interval=frame_interval,
353
  general_thresh=general_thresh,
354
  character_thresh=character_thresh,
 
 
355
  progress=progress,
356
  )
357
  except Exception as e:
@@ -362,60 +445,105 @@ with gr.Blocks(title=TITLE) as demo:
362
  gr.Markdown(f"## {TITLE}")
363
  gr.Markdown(DESCRIPTION)
364
 
365
- with gr.Row():
366
- with gr.Column():
367
- video_input = gr.Video(
368
- label="Video (.mp4 or .mov)",
369
- sources=["upload"],
370
- format="mp4",
371
- )
 
 
 
372
 
373
- model_choice = gr.Dropdown(
374
- choices=MODEL_OPTIONS,
375
- value=DEFAULT_MODEL_REPO,
376
- label="Tagging Model",
377
- )
378
 
379
- frame_interval = gr.Slider(
380
- minimum=1,
381
- maximum=60,
382
- step=1,
383
- value=10,
384
- label="Extract Every N Frames",
385
- info="For example, 10 = use every 10th frame.",
386
- )
387
 
388
- general_thresh = gr.Slider(
389
- minimum=0.0,
390
- maximum=1.0,
391
- step=0.01,
392
- value=0.35,
393
- label="General Tags Threshold",
394
- )
395
 
396
- character_thresh = gr.Slider(
397
- minimum=0.0,
398
- maximum=1.0,
399
- step=0.01,
400
- value=0.85,
401
- label="Character Tags Threshold",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
  )
403
 
404
- run_button = gr.Button("Generate Tags", variant="primary")
 
 
 
 
 
 
 
 
405
 
406
- with gr.Column():
407
- combined_tags = gr.Textbox(
408
- show_label=True,
409
- label="Combined Unique Tags (All Frames)",
410
- lines=6,
411
  )
412
- debug_info = gr.JSON(
413
- label="Details / Debug Info",
 
 
 
 
 
 
 
414
  )
415
 
 
416
  run_button.click(
417
  fn=tag_video_interface,
418
- inputs=[video_input, frame_interval, general_thresh, character_thresh, model_choice],
 
 
 
 
 
 
 
 
419
  outputs=[combined_tags, debug_info],
420
  )
421
 
 
1
  import os
2
+ from typing import Dict, Tuple, List, Set
3
 
4
  import cv2
5
  import gradio as gr
 
17
  - Extract every N-th frame (e.g., every 10th frame).
18
  - Control thresholds for **General Tags** and **Character Tags**.
19
  - All tags from all sampled frames are merged into **one unique, comma-separated string**.
20
+ - Use the **Tag Control** tab to define tag substitutions and exclusions for the final output.
21
  """
22
 
23
  DEFAULT_MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
 
219
  frame_interval: int,
220
  general_thresh: float,
221
  character_thresh: float,
222
+ tag_substitutes: Dict[str, str],
223
+ tag_exclusions: Set[str],
224
  progress=None,
225
  ) -> Tuple[str, Dict]:
226
  """
 
245
  # Estimate total frames and how many will be processed
246
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
247
  if total_frames <= 0:
248
+ total_frames = 1 # avoid division issues
249
 
250
  frames_to_process = max(1, (total_frames + frame_interval - 1) // frame_interval)
251
 
 
301
 
302
  # Merge character + general tags, sorted by score (desc)
303
  all_tags_with_scores = {**aggregated_general, **aggregated_character}
304
+
305
+ # Apply substitutions & exclusions BEFORE final dedup
306
+ adjusted_all_tags: Dict[str, float] = {}
307
+
308
+ # Normalize keys in substitutes/exclusions (strip whitespace)
309
+ normalized_subs = {k.strip(): v.strip() for k, v in tag_substitutes.items() if k and v}
310
+ normalized_exclusions = {t.strip() for t in tag_exclusions if t}
311
+
312
+ for tag, score in all_tags_with_scores.items():
313
+ original_tag = tag.strip()
314
+
315
+ # Skip if original tag is excluded
316
+ if original_tag in normalized_exclusions:
317
+ continue
318
+
319
+ # Apply substitution (if any)
320
+ new_tag = normalized_subs.get(original_tag, original_tag)
321
+
322
+ # Skip if substituted tag is excluded
323
+ if new_tag in normalized_exclusions:
324
+ continue
325
+
326
+ # Keep max score for each resulting tag
327
+ if new_tag not in adjusted_all_tags or score > adjusted_all_tags[new_tag]:
328
+ adjusted_all_tags[new_tag] = score
329
+
330
+ # Sort by score descending
331
  sorted_tags = sorted(
332
+ adjusted_all_tags.items(),
333
  key=lambda kv: kv[1],
334
  reverse=True,
335
  )
 
343
  "frames_processed": int(processed_frames),
344
  "estimated_total_frames": int(total_frames),
345
  "estimated_frames_to_process": int(frames_to_process),
346
+ "num_general_tags_raw": len(aggregated_general),
347
+ "num_character_tags_raw": len(aggregated_character),
348
+ "total_unique_tags_after_control": len(unique_tags),
349
  "frame_interval": int(frame_interval),
350
  "general_threshold": float(general_thresh),
351
  "character_threshold": float(character_thresh),
352
+ "num_substitution_rules": len(normalized_subs),
353
+ "num_exclusions": len(normalized_exclusions),
354
  }
355
 
356
  return combined_tags_str, debug_info
 
366
  return _tagger_cache[model_repo]
367
 
368
 
369
+ def _normalize_tag_substitutes(data) -> Dict[str, str]:
370
+ """
371
+ Convert Dataframe (as array: list[list]) into {original: substitute}.
372
+ """
373
+ mapping: Dict[str, str] = {}
374
+ if data is None:
375
+ return mapping
376
+
377
+ # Expect data as list of [original, substitute]
378
+ for row in data:
379
+ if not row or len(row) < 2:
380
+ continue
381
+ orig = (row[0] or "").strip()
382
+ sub = (row[1] or "").strip()
383
+ if orig and sub:
384
+ mapping[orig] = sub
385
+ return mapping
386
+
387
+
388
+ def _normalize_tag_exclusions(data) -> Set[str]:
389
+ """
390
+ Convert Dataframe (as array: list[list]) into set of tags to exclude.
391
+ """
392
+ exclusions: Set[str] = set()
393
+ if data is None:
394
+ return exclusions
395
+
396
+ # Expect data as list of [tag] rows
397
+ for row in data:
398
+ if row is None:
399
+ continue
400
+ if isinstance(row, (list, tuple)):
401
+ if not row:
402
+ continue
403
+ val = row[0]
404
+ else:
405
+ val = row
406
+ val = (val or "").strip()
407
+ if val:
408
+ exclusions.add(val)
409
+ return exclusions
410
+
411
+
412
  def tag_video_interface(
413
  video_path: str,
414
  frame_interval: int,
415
  general_thresh: float,
416
  character_thresh: float,
417
  model_repo: str,
418
+ tag_substitutes_df,
419
+ tag_exclusions_df,
420
  progress=gr.Progress(track_tqdm=False),
421
  ):
422
  if video_path is None:
 
424
 
425
  try:
426
  tagger = get_tagger(model_repo)
427
+
428
+ tag_substitutes = _normalize_tag_substitutes(tag_substitutes_df)
429
+ tag_exclusions = _normalize_tag_exclusions(tag_exclusions_df)
430
+
431
  return tagger.tag_video(
432
  video_path=video_path,
433
  frame_interval=frame_interval,
434
  general_thresh=general_thresh,
435
  character_thresh=character_thresh,
436
+ tag_substitutes=tag_substitutes,
437
+ tag_exclusions=tag_exclusions,
438
  progress=progress,
439
  )
440
  except Exception as e:
 
445
  gr.Markdown(f"## {TITLE}")
446
  gr.Markdown(DESCRIPTION)
447
 
448
+ with gr.Tabs():
449
+ # ---------------- TAB 1: TAGGING ----------------
450
+ with gr.Tab("Tagging"):
451
+ with gr.Row():
452
+ with gr.Column():
453
+ video_input = gr.Video(
454
+ label="Video (.mp4 or .mov)",
455
+ sources=["upload"],
456
+ format="mp4",
457
+ )
458
 
459
+ model_choice = gr.Dropdown(
460
+ choices=MODEL_OPTIONS,
461
+ value=DEFAULT_MODEL_REPO,
462
+ label="Tagging Model",
463
+ )
464
 
465
+ frame_interval = gr.Slider(
466
+ minimum=1,
467
+ maximum=60,
468
+ step=1,
469
+ value=10,
470
+ label="Extract Every N Frames",
471
+ info="For example, 10 = use every 10th frame.",
472
+ )
473
 
474
+ general_thresh = gr.Slider(
475
+ minimum=0.0,
476
+ maximum=1.0,
477
+ step=0.01,
478
+ value=0.35,
479
+ label="General Tags Threshold",
480
+ )
481
 
482
+ character_thresh = gr.Slider(
483
+ minimum=0.0,
484
+ maximum=1.0,
485
+ step=0.01,
486
+ value=0.85,
487
+ label="Character Tags Threshold",
488
+ )
489
+
490
+ run_button = gr.Button("Generate Tags", variant="primary")
491
+
492
+ with gr.Column():
493
+ combined_tags = gr.Textbox(
494
+ label="Combined Unique Tags (All Frames)",
495
+ lines=6,
496
+ buttons=["copy"],
497
+ )
498
+ debug_info = gr.JSON(
499
+ label="Details / Debug Info",
500
+ )
501
+
502
+ # ---------------- TAB 2: TAG CONTROL ----------------
503
+ with gr.Tab("Tag Control"):
504
+ gr.Markdown("### Tag Substitutes")
505
+ gr.Markdown(
506
+ "Add rows where **Original Tag** will be replaced by **Substitute Tag** "
507
+ "in the final combined output (after all frames are processed)."
508
  )
509
 
510
+ tag_substitutes_df = gr.Dataframe(
511
+ headers=["Original Tag", "Substitute Tag"],
512
+ datatype=["str", "str"],
513
+ row_count=3,
514
+ col_count=2,
515
+ type="array",
516
+ label="Tag Substitutes",
517
+ interactive=True,
518
+ )
519
 
520
+ gr.Markdown("### Tag Exclusions")
521
+ gr.Markdown(
522
+ "Add tags that should be **removed entirely** from the final combined output."
 
 
523
  )
524
+
525
+ tag_exclusions_df = gr.Dataframe(
526
+ headers=["Tag to Exclude"],
527
+ datatype=["str"],
528
+ row_count=3,
529
+ col_count=1,
530
+ type="array",
531
+ label="Tag Exclusions",
532
+ interactive=True,
533
  )
534
 
535
+ # Wiring the button AFTER all components are defined
536
  run_button.click(
537
  fn=tag_video_interface,
538
+ inputs=[
539
+ video_input,
540
+ frame_interval,
541
+ general_thresh,
542
+ character_thresh,
543
+ model_choice,
544
+ tag_substitutes_df,
545
+ tag_exclusions_df,
546
+ ],
547
  outputs=[combined_tags, debug_info],
548
  )
549