Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import os
|
| 2 |
-
from typing import Dict, Tuple
|
| 3 |
|
| 4 |
import cv2
|
| 5 |
import gradio as gr
|
|
@@ -17,6 +17,7 @@ combined (deduplicated) tags using a selected **WD-style tagging model**.
|
|
| 17 |
- Extract every N-th frame (e.g., every 10th frame).
|
| 18 |
- Control thresholds for **General Tags** and **Character Tags**.
|
| 19 |
- All tags from all sampled frames are merged into **one unique, comma-separated string**.
|
|
|
|
| 20 |
"""
|
| 21 |
|
| 22 |
DEFAULT_MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
|
|
@@ -218,6 +219,8 @@ class VideoTagger:
|
|
| 218 |
frame_interval: int,
|
| 219 |
general_thresh: float,
|
| 220 |
character_thresh: float,
|
|
|
|
|
|
|
| 221 |
progress=None,
|
| 222 |
) -> Tuple[str, Dict]:
|
| 223 |
"""
|
|
@@ -242,7 +245,7 @@ class VideoTagger:
|
|
| 242 |
# Estimate total frames and how many will be processed
|
| 243 |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
|
| 244 |
if total_frames <= 0:
|
| 245 |
-
total_frames = 1 # avoid division
|
| 246 |
|
| 247 |
frames_to_process = max(1, (total_frames + frame_interval - 1) // frame_interval)
|
| 248 |
|
|
@@ -298,8 +301,35 @@ class VideoTagger:
|
|
| 298 |
|
| 299 |
# Merge character + general tags, sorted by score (desc)
|
| 300 |
all_tags_with_scores = {**aggregated_general, **aggregated_character}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
sorted_tags = sorted(
|
| 302 |
-
|
| 303 |
key=lambda kv: kv[1],
|
| 304 |
reverse=True,
|
| 305 |
)
|
|
@@ -313,12 +343,14 @@ class VideoTagger:
|
|
| 313 |
"frames_processed": int(processed_frames),
|
| 314 |
"estimated_total_frames": int(total_frames),
|
| 315 |
"estimated_frames_to_process": int(frames_to_process),
|
| 316 |
-
"
|
| 317 |
-
"
|
| 318 |
-
"
|
| 319 |
"frame_interval": int(frame_interval),
|
| 320 |
"general_threshold": float(general_thresh),
|
| 321 |
"character_threshold": float(character_thresh),
|
|
|
|
|
|
|
| 322 |
}
|
| 323 |
|
| 324 |
return combined_tags_str, debug_info
|
|
@@ -334,12 +366,57 @@ def get_tagger(model_repo: str) -> VideoTagger:
|
|
| 334 |
return _tagger_cache[model_repo]
|
| 335 |
|
| 336 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
def tag_video_interface(
|
| 338 |
video_path: str,
|
| 339 |
frame_interval: int,
|
| 340 |
general_thresh: float,
|
| 341 |
character_thresh: float,
|
| 342 |
model_repo: str,
|
|
|
|
|
|
|
| 343 |
progress=gr.Progress(track_tqdm=False),
|
| 344 |
):
|
| 345 |
if video_path is None:
|
|
@@ -347,11 +424,17 @@ def tag_video_interface(
|
|
| 347 |
|
| 348 |
try:
|
| 349 |
tagger = get_tagger(model_repo)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
return tagger.tag_video(
|
| 351 |
video_path=video_path,
|
| 352 |
frame_interval=frame_interval,
|
| 353 |
general_thresh=general_thresh,
|
| 354 |
character_thresh=character_thresh,
|
|
|
|
|
|
|
| 355 |
progress=progress,
|
| 356 |
)
|
| 357 |
except Exception as e:
|
|
@@ -362,60 +445,105 @@ with gr.Blocks(title=TITLE) as demo:
|
|
| 362 |
gr.Markdown(f"## {TITLE}")
|
| 363 |
gr.Markdown(DESCRIPTION)
|
| 364 |
|
| 365 |
-
with gr.
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
|
|
|
|
|
|
|
|
|
| 372 |
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
)
|
| 403 |
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
label="Combined Unique Tags (All Frames)",
|
| 410 |
-
lines=6,
|
| 411 |
)
|
| 412 |
-
|
| 413 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
)
|
| 415 |
|
|
|
|
| 416 |
run_button.click(
|
| 417 |
fn=tag_video_interface,
|
| 418 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
outputs=[combined_tags, debug_info],
|
| 420 |
)
|
| 421 |
|
|
|
|
| 1 |
import os
|
| 2 |
+
from typing import Dict, Tuple, List, Set
|
| 3 |
|
| 4 |
import cv2
|
| 5 |
import gradio as gr
|
|
|
|
| 17 |
- Extract every N-th frame (e.g., every 10th frame).
|
| 18 |
- Control thresholds for **General Tags** and **Character Tags**.
|
| 19 |
- All tags from all sampled frames are merged into **one unique, comma-separated string**.
|
| 20 |
+
- Use the **Tag Control** tab to define tag substitutions and exclusions for the final output.
|
| 21 |
"""
|
| 22 |
|
| 23 |
DEFAULT_MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
|
|
|
|
| 219 |
frame_interval: int,
|
| 220 |
general_thresh: float,
|
| 221 |
character_thresh: float,
|
| 222 |
+
tag_substitutes: Dict[str, str],
|
| 223 |
+
tag_exclusions: Set[str],
|
| 224 |
progress=None,
|
| 225 |
) -> Tuple[str, Dict]:
|
| 226 |
"""
|
|
|
|
| 245 |
# Estimate total frames and how many will be processed
|
| 246 |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
|
| 247 |
if total_frames <= 0:
|
| 248 |
+
total_frames = 1 # avoid division issues
|
| 249 |
|
| 250 |
frames_to_process = max(1, (total_frames + frame_interval - 1) // frame_interval)
|
| 251 |
|
|
|
|
| 301 |
|
| 302 |
# Merge character + general tags, sorted by score (desc)
|
| 303 |
all_tags_with_scores = {**aggregated_general, **aggregated_character}
|
| 304 |
+
|
| 305 |
+
# Apply substitutions & exclusions BEFORE final dedup
|
| 306 |
+
adjusted_all_tags: Dict[str, float] = {}
|
| 307 |
+
|
| 308 |
+
# Normalize keys in substitutes/exclusions (strip whitespace)
|
| 309 |
+
normalized_subs = {k.strip(): v.strip() for k, v in tag_substitutes.items() if k and v}
|
| 310 |
+
normalized_exclusions = {t.strip() for t in tag_exclusions if t}
|
| 311 |
+
|
| 312 |
+
for tag, score in all_tags_with_scores.items():
|
| 313 |
+
original_tag = tag.strip()
|
| 314 |
+
|
| 315 |
+
# Skip if original tag is excluded
|
| 316 |
+
if original_tag in normalized_exclusions:
|
| 317 |
+
continue
|
| 318 |
+
|
| 319 |
+
# Apply substitution (if any)
|
| 320 |
+
new_tag = normalized_subs.get(original_tag, original_tag)
|
| 321 |
+
|
| 322 |
+
# Skip if substituted tag is excluded
|
| 323 |
+
if new_tag in normalized_exclusions:
|
| 324 |
+
continue
|
| 325 |
+
|
| 326 |
+
# Keep max score for each resulting tag
|
| 327 |
+
if new_tag not in adjusted_all_tags or score > adjusted_all_tags[new_tag]:
|
| 328 |
+
adjusted_all_tags[new_tag] = score
|
| 329 |
+
|
| 330 |
+
# Sort by score descending
|
| 331 |
sorted_tags = sorted(
|
| 332 |
+
adjusted_all_tags.items(),
|
| 333 |
key=lambda kv: kv[1],
|
| 334 |
reverse=True,
|
| 335 |
)
|
|
|
|
| 343 |
"frames_processed": int(processed_frames),
|
| 344 |
"estimated_total_frames": int(total_frames),
|
| 345 |
"estimated_frames_to_process": int(frames_to_process),
|
| 346 |
+
"num_general_tags_raw": len(aggregated_general),
|
| 347 |
+
"num_character_tags_raw": len(aggregated_character),
|
| 348 |
+
"total_unique_tags_after_control": len(unique_tags),
|
| 349 |
"frame_interval": int(frame_interval),
|
| 350 |
"general_threshold": float(general_thresh),
|
| 351 |
"character_threshold": float(character_thresh),
|
| 352 |
+
"num_substitution_rules": len(normalized_subs),
|
| 353 |
+
"num_exclusions": len(normalized_exclusions),
|
| 354 |
}
|
| 355 |
|
| 356 |
return combined_tags_str, debug_info
|
|
|
|
| 366 |
return _tagger_cache[model_repo]
|
| 367 |
|
| 368 |
|
| 369 |
+
def _normalize_tag_substitutes(data) -> Dict[str, str]:
|
| 370 |
+
"""
|
| 371 |
+
Convert Dataframe (as array: list[list]) into {original: substitute}.
|
| 372 |
+
"""
|
| 373 |
+
mapping: Dict[str, str] = {}
|
| 374 |
+
if data is None:
|
| 375 |
+
return mapping
|
| 376 |
+
|
| 377 |
+
# Expect data as list of [original, substitute]
|
| 378 |
+
for row in data:
|
| 379 |
+
if not row or len(row) < 2:
|
| 380 |
+
continue
|
| 381 |
+
orig = (row[0] or "").strip()
|
| 382 |
+
sub = (row[1] or "").strip()
|
| 383 |
+
if orig and sub:
|
| 384 |
+
mapping[orig] = sub
|
| 385 |
+
return mapping
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def _normalize_tag_exclusions(data) -> Set[str]:
|
| 389 |
+
"""
|
| 390 |
+
Convert Dataframe (as array: list[list]) into set of tags to exclude.
|
| 391 |
+
"""
|
| 392 |
+
exclusions: Set[str] = set()
|
| 393 |
+
if data is None:
|
| 394 |
+
return exclusions
|
| 395 |
+
|
| 396 |
+
# Expect data as list of [tag] rows
|
| 397 |
+
for row in data:
|
| 398 |
+
if row is None:
|
| 399 |
+
continue
|
| 400 |
+
if isinstance(row, (list, tuple)):
|
| 401 |
+
if not row:
|
| 402 |
+
continue
|
| 403 |
+
val = row[0]
|
| 404 |
+
else:
|
| 405 |
+
val = row
|
| 406 |
+
val = (val or "").strip()
|
| 407 |
+
if val:
|
| 408 |
+
exclusions.add(val)
|
| 409 |
+
return exclusions
|
| 410 |
+
|
| 411 |
+
|
| 412 |
def tag_video_interface(
|
| 413 |
video_path: str,
|
| 414 |
frame_interval: int,
|
| 415 |
general_thresh: float,
|
| 416 |
character_thresh: float,
|
| 417 |
model_repo: str,
|
| 418 |
+
tag_substitutes_df,
|
| 419 |
+
tag_exclusions_df,
|
| 420 |
progress=gr.Progress(track_tqdm=False),
|
| 421 |
):
|
| 422 |
if video_path is None:
|
|
|
|
| 424 |
|
| 425 |
try:
|
| 426 |
tagger = get_tagger(model_repo)
|
| 427 |
+
|
| 428 |
+
tag_substitutes = _normalize_tag_substitutes(tag_substitutes_df)
|
| 429 |
+
tag_exclusions = _normalize_tag_exclusions(tag_exclusions_df)
|
| 430 |
+
|
| 431 |
return tagger.tag_video(
|
| 432 |
video_path=video_path,
|
| 433 |
frame_interval=frame_interval,
|
| 434 |
general_thresh=general_thresh,
|
| 435 |
character_thresh=character_thresh,
|
| 436 |
+
tag_substitutes=tag_substitutes,
|
| 437 |
+
tag_exclusions=tag_exclusions,
|
| 438 |
progress=progress,
|
| 439 |
)
|
| 440 |
except Exception as e:
|
|
|
|
| 445 |
gr.Markdown(f"## {TITLE}")
|
| 446 |
gr.Markdown(DESCRIPTION)
|
| 447 |
|
| 448 |
+
with gr.Tabs():
|
| 449 |
+
# ---------------- TAB 1: TAGGING ----------------
|
| 450 |
+
with gr.Tab("Tagging"):
|
| 451 |
+
with gr.Row():
|
| 452 |
+
with gr.Column():
|
| 453 |
+
video_input = gr.Video(
|
| 454 |
+
label="Video (.mp4 or .mov)",
|
| 455 |
+
sources=["upload"],
|
| 456 |
+
format="mp4",
|
| 457 |
+
)
|
| 458 |
|
| 459 |
+
model_choice = gr.Dropdown(
|
| 460 |
+
choices=MODEL_OPTIONS,
|
| 461 |
+
value=DEFAULT_MODEL_REPO,
|
| 462 |
+
label="Tagging Model",
|
| 463 |
+
)
|
| 464 |
|
| 465 |
+
frame_interval = gr.Slider(
|
| 466 |
+
minimum=1,
|
| 467 |
+
maximum=60,
|
| 468 |
+
step=1,
|
| 469 |
+
value=10,
|
| 470 |
+
label="Extract Every N Frames",
|
| 471 |
+
info="For example, 10 = use every 10th frame.",
|
| 472 |
+
)
|
| 473 |
|
| 474 |
+
general_thresh = gr.Slider(
|
| 475 |
+
minimum=0.0,
|
| 476 |
+
maximum=1.0,
|
| 477 |
+
step=0.01,
|
| 478 |
+
value=0.35,
|
| 479 |
+
label="General Tags Threshold",
|
| 480 |
+
)
|
| 481 |
|
| 482 |
+
character_thresh = gr.Slider(
|
| 483 |
+
minimum=0.0,
|
| 484 |
+
maximum=1.0,
|
| 485 |
+
step=0.01,
|
| 486 |
+
value=0.85,
|
| 487 |
+
label="Character Tags Threshold",
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
run_button = gr.Button("Generate Tags", variant="primary")
|
| 491 |
+
|
| 492 |
+
with gr.Column():
|
| 493 |
+
# Output box for the merged tag string. The original argument line was a
# type annotation pasted from the API docs (`buttons: list[Literal['copy']]`),
# which is a syntax error inside a call; pass the value as a keyword instead.
# NOTE(review): `buttons=` matches the signature the annotation was copied
# from; on Gradio versions without it, use `show_copy_button=True` — confirm
# against the installed version.
combined_tags = gr.Textbox(
    label="Combined Unique Tags (All Frames)",
    lines=6,
    buttons=["copy"],
)
|
| 498 |
+
debug_info = gr.JSON(
|
| 499 |
+
label="Details / Debug Info",
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
# ---------------- TAB 2: TAG CONTROL ----------------
|
| 503 |
+
with gr.Tab("Tag Control"):
|
| 504 |
+
gr.Markdown("### Tag Substitutes")
|
| 505 |
+
gr.Markdown(
|
| 506 |
+
"Add rows where **Original Tag** will be replaced by **Substitute Tag** "
|
| 507 |
+
"in the final combined output (after all frames are processed)."
|
| 508 |
)
|
| 509 |
|
| 510 |
+
tag_substitutes_df = gr.Dataframe(
|
| 511 |
+
headers=["Original Tag", "Substitute Tag"],
|
| 512 |
+
datatype=["str", "str"],
|
| 513 |
+
row_count=3,
|
| 514 |
+
col_count=2,
|
| 515 |
+
type="array",
|
| 516 |
+
label="Tag Substitutes",
|
| 517 |
+
interactive=True,
|
| 518 |
+
)
|
| 519 |
|
| 520 |
+
gr.Markdown("### Tag Exclusions")
|
| 521 |
+
gr.Markdown(
|
| 522 |
+
"Add tags that should be **removed entirely** from the final combined output."
|
|
|
|
|
|
|
| 523 |
)
|
| 524 |
+
|
| 525 |
+
tag_exclusions_df = gr.Dataframe(
|
| 526 |
+
headers=["Tag to Exclude"],
|
| 527 |
+
datatype=["str"],
|
| 528 |
+
row_count=3,
|
| 529 |
+
col_count=1,
|
| 530 |
+
type="array",
|
| 531 |
+
label="Tag Exclusions",
|
| 532 |
+
interactive=True,
|
| 533 |
)
|
| 534 |
|
| 535 |
+
# Wiring the button AFTER all components are defined
|
| 536 |
run_button.click(
|
| 537 |
fn=tag_video_interface,
|
| 538 |
+
inputs=[
|
| 539 |
+
video_input,
|
| 540 |
+
frame_interval,
|
| 541 |
+
general_thresh,
|
| 542 |
+
character_thresh,
|
| 543 |
+
model_choice,
|
| 544 |
+
tag_substitutes_df,
|
| 545 |
+
tag_exclusions_df,
|
| 546 |
+
],
|
| 547 |
outputs=[combined_tags, debug_info],
|
| 548 |
)
|
| 549 |
|