|
|
import os
import time
from typing import Dict, List, Set, Tuple

import cv2
import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from PIL import Image
|
|
|
|
|
TITLE = "AI Video Auto-Tagger & Captioner" |
|
|
DESCRIPTION = """ |
|
|
Upload a .mp4 or .mov video, choose how often to sample frames, and generate |
|
|
combined (deduplicated) tags using a selected **tagging/captioning model**. |
|
|
|
|
|
- Extract every N-th frame (e.g., every 10th frame). |
|
|
- Control thresholds for **General Tags** and **Character Tags**. |
|
|
- All tags from all sampled frames are merged into **one unique, comma-separated string**. |
|
|
- Use the **Tag Control** tab to define tag substitutions and exclusions for the final output. |
|
|
|
|
|
**This space is running on the free CPU tier so it can be slow. If you want better speeds, clone the space and host it on more capable hardware.** |
|
|
""" |
|
|
|
|
|
DEFAULT_MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3" |
|
|
|
|
|
# NOTE: the Florence-2 entries below are transformer-based caption models; they
# do not ship the WD-style model.onnx / selected_tags.csv pair this app
# downloads, so selecting them requires a separate captioning code path.
MODEL_OPTIONS = [
|
|
"SmilingWolf/wd-eva02-large-tagger-v3", |
|
|
"SmilingWolf/wd-vit-large-tagger-v3", |
|
|
"SmilingWolf/wd-vit-tagger-v3", |
|
|
"SmilingWolf/wd-convnext-tagger-v3", |
|
|
"SmilingWolf/wd-swinv2-tagger-v3", |
|
|
"deepghs/idolsankaku-eva02-large-tagger-v1", |
|
|
"deepghs/idolsankaku-swinv2-tagger-v1", |
|
|
"gokaygokay/Florence-2-SD3-Captioner", |
|
|
"gokaygokay/Florence-2-Flux", |
|
|
"gokaygokay/Florence-2-Flux-Large", |
|
|
"MiaoshouAI/Florence-2-large-PromptGen-v2.0", |
|
|
"thwri/CogFlorence-2.2-Large", |
|
|
"deepghs/deepgelbooru_onnx", |
|
|
] |
|
|
|
|
|
MODEL_FILENAME = "model.onnx" |
|
|
LABEL_FILENAME = "selected_tags.csv" |
|
|
|
|
|
HF_TOKEN = os.environ.get("HF_TOKEN") |
|
|
|
|
|
|
|
|
# Kaomoji tags contain meaningful underscores, so load_labels() leaves them
# untouched when normalizing other tags' underscores to spaces.
kaomojis = [
|
|
"0_0", |
|
|
"(o)_(o)", |
|
|
"+_+", |
|
|
"+_-", |
|
|
"._.", |
|
|
"<o>_<o>", |
|
|
"<|>_<|>", |
|
|
"=_=", |
|
|
">_<", |
|
|
"3_3", |
|
|
"6_9", |
|
|
">_o", |
|
|
"@_@", |
|
|
"^_^", |
|
|
"o_o", |
|
|
"u_u", |
|
|
"x_x", |
|
|
"|_|", |
|
|
"||_||", |
|
|
] |
|
|
|
|
|
css = """ |
|
|
#tagging-tab-button, |
|
|
#tag-control-tab-button { |
|
|
font-weight: 900 !important; |
|
|
} |
|
|
#tagging-tab-button:hover, |
|
|
#tag-control-tab-button:hover { |
|
|
filter: brightness(0.9); |
|
|
} |
|
|
""" |
|
|
|
|
|
def _format_duration(seconds: float) -> str: |
|
|
""" |
|
|
Format a duration in seconds as MM:SS or HH:MM:SS. |
|
|
""" |
|
|
total_seconds = int(round(seconds)) |
|
|
hours, rem = divmod(total_seconds, 3600) |
|
|
minutes, secs = divmod(rem, 60) |
|
|
|
|
|
if hours > 0: |
|
|
return f"{hours:02d}:{minutes:02d}:{secs:02d}" |
|
|
else: |
|
|
return f"{minutes:02d}:{secs:02d}" |
|
|
|
|
|
|
|
|
def load_labels(df: pd.DataFrame):
    """
    Convert the tag dataframe into:
    - tag_names: list[str]
    - rating_indexes: list[int]
    - general_indexes: list[int]
    - character_indexes: list[int]
    """
    name_series = df["name"]
    name_series = name_series.map(
        lambda x: x.replace("_", " ") if x not in kaomojis else x
    )
    tag_names = name_series.tolist()

    # Category codes used by the WD-style tag CSVs:
    # 9 = rating, 0 = general, 4 = character.
    rating_indexes = list(np.where(df["category"] == 9)[0])
    general_indexes = list(np.where(df["category"] == 0)[0])
    character_indexes = list(np.where(df["category"] == 4)[0])

    return tag_names, rating_indexes, general_indexes, character_indexes
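
# For reference, selected_tags.csv in the WD-style tagger repos carries
# (roughly) the columns tag_id, name, category, count -- an assumed layout,
# not a spec -- with the category codes interpreted as noted above.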
|
|
|
|
|
|
|
|
def add_substitute_row(current): |
|
|
""" |
|
|
Append an empty [original, substitute] row to the substitutes dataframe. |
|
|
Works with type='array' (list of lists). |
|
|
""" |
|
|
if current is None: |
|
|
current = [] |
|
|
|
|
|
current = list(current) |
|
|
current.append(["", ""]) |
|
|
return current |
|
|
|
|
|
|
|
|
def add_exclusion_row(current): |
|
|
""" |
|
|
Append an empty [tag] row to the exclusions dataframe. |
|
|
""" |
|
|
if current is None: |
|
|
current = [] |
|
|
current = list(current) |
|
|
current.append([""]) |
|
|
return current |
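
# Quick sanity examples (derived from the logic above):
#   add_substitute_row(None)           -> [["", ""]]
#   add_exclusion_row([["watermark"]]) -> [["watermark"], [""]]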
|
|
|
|
|
def compute_recommended_batch_size(sampled_frames: int) -> int:
    """
    Heuristic batch-size recommendation based on how many frames
    will actually be processed (after sampling).

    Tuned from rough measurements on the free CPU tier:
    - Small clips -> smaller batches
    - Medium clips -> medium batches
    - Larger clips -> larger batches, capped at 32
    """
    if sampled_frames <= 0:
        return 8

    if sampled_frames <= 20:
        rec = 8
    elif sampled_frames <= 40:
        rec = 16
    elif sampled_frames <= 80:
        rec = 24
    else:
        rec = 32

    return max(1, min(32, rec))
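
# Quick sanity check of the tiers above:
#   [compute_recommended_batch_size(n) for n in (0, 10, 40, 100, 500)]
#   -> [8, 8, 16, 32, 32]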
|
|
|
|
|
def update_batch_recommendation(video_path: str, frame_interval: int) -> str: |
|
|
""" |
|
|
Compute a recommended batch size based on the video length |
|
|
and the current frame sampling interval, and return HTML |
|
|
for the UI. |
|
|
""" |
|
|
if not video_path or not os.path.exists(video_path): |
|
|
return "<span>Upload a video to see a recommended batch size.</span>" |
|
|
|
|
|
try: |
|
|
frame_interval = max(int(frame_interval), 1) |
|
|
except Exception: |
|
|
frame_interval = 1 |
|
|
|
|
|
try: |
|
|
cap = cv2.VideoCapture(video_path) |
|
|
if not cap.isOpened(): |
|
|
return "<span>Could not read video to estimate batch size.</span>" |
|
|
|
|
|
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0 |
|
|
cap.release() |
|
|
|
|
|
if total_frames <= 0: |
|
|
return "<span>Could not determine video length to recommend batch size.</span>" |
|
|
|
|
|
sampled_frames = max(1, (total_frames + frame_interval - 1) // frame_interval) |
|
|
rec = compute_recommended_batch_size(sampled_frames) |
|
|
|
|
|
return ( |
|
|
f"<span>Recommended batch size: <b>{rec}</b> " |
|
|
f"(based on ~{sampled_frames} sampled frames).</span>" |
|
|
) |
|
|
except Exception as e: |
|
|
return f"<span>Could not compute recommendation: {e}</span>" |
|
|
|
|
|
def show_batch_loading() -> str: |
|
|
""" |
|
|
Lightweight UI helper: show a pulsing 'calculating' message |
|
|
while we compute the recommended batch size. |
|
|
""" |
|
|
return "<span class='batch-loading'>Calculating recommended batch size...</span>" |
|
|
|
|
|
|
|
|
class VideoTagger: |
|
|
""" |
|
|
Wraps a WD-style ONNX model and tag metadata, |
|
|
and exposes helpers to tag PIL images and full videos. |
|
|
""" |
|
|
|
|
|
def __init__(self, model_repo: str, batch_size: int = 16): |
|
|
self.model_repo = model_repo |
|
|
self.model = None |
|
|
self.model_target_size = None |
|
|
self.tag_names = None |
|
|
self.rating_indexes = None |
|
|
self.general_indexes = None |
|
|
self.character_indexes = None |
|
|
self.batch_size = batch_size |
|
|
|
|
|
def _download_model_files(self) -> Tuple[str, str]: |
|
|
csv_path = huggingface_hub.hf_hub_download( |
|
|
repo_id=self.model_repo, |
|
|
filename=LABEL_FILENAME, |
|
|
token=HF_TOKEN, |
|
|
) |
|
|
model_path = huggingface_hub.hf_hub_download( |
|
|
repo_id=self.model_repo, |
|
|
filename=MODEL_FILENAME, |
|
|
token=HF_TOKEN, |
|
|
) |
|
|
return csv_path, model_path |
|
|
|
|
|
def _load_model_if_needed(self): |
|
|
if self.model is not None: |
|
|
return |
|
|
|
|
|
csv_path, model_path = self._download_model_files() |
|
|
|
|
|
tags_df = pd.read_csv(csv_path) |
|
|
( |
|
|
self.tag_names, |
|
|
self.rating_indexes, |
|
|
self.general_indexes, |
|
|
self.character_indexes, |
|
|
) = load_labels(tags_df) |
|
|
|
|
|
|
|
|
        # onnxruntime-gpu builds require providers to be passed explicitly.
        self.model = rt.InferenceSession(
            model_path, providers=rt.get_available_providers()
        )
|
|
|
|
|
|
|
|
_, height, width, _ = self.model.get_inputs()[0].shape |
|
|
assert height == width, "Model expects square inputs" |
|
|
self.model_target_size = int(height) |
|
|
|
|
|
def _prepare_image(self, image: Image.Image) -> np.ndarray: |
|
|
""" |
|
|
Convert a PIL image into the model's expected input tensor: |
|
|
- RGBA composited onto white |
|
|
- padded to square |
|
|
- resized to model_target_size |
|
|
- converted to BGR |
|
|
- shape (1, H, W, 3), float32 |
|
|
""" |
|
|
target_size = self.model_target_size |
|
|
|
|
|
|
|
|
canvas = Image.new("RGBA", image.size, (255, 255, 255, 255)) |
|
|
canvas.alpha_composite(image) |
|
|
image_rgb = canvas.convert("RGB") |
|
|
|
|
|
|
|
|
w, h = image_rgb.size |
|
|
max_dim = max(w, h) |
|
|
pad_left = (max_dim - w) // 2 |
|
|
pad_top = (max_dim - h) // 2 |
|
|
|
|
|
padded = Image.new("RGB", (max_dim, max_dim), (255, 255, 255)) |
|
|
padded.paste(image_rgb, (pad_left, pad_top)) |
|
|
|
|
|
|
|
|
if max_dim != target_size: |
|
|
padded = padded.resize((target_size, target_size), Image.BICUBIC) |
|
|
|
|
|
|
|
|
arr = np.asarray(padded, dtype=np.float32) |
|
|
arr = arr[:, :, ::-1] |
|
|
|
|
|
|
|
|
arr = np.expand_dims(arr, axis=0) |
|
|
return arr |
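
    # e.g. for a 448x448 model, _prepare_image() on a 640x360 RGB frame returns
    # a (1, 448, 448, 3) float32 array in BGR channel order (raw 0-255 values;
    # the WD taggers expect unnormalized inputs).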
|
|
|
|
|
def _prepare_frame_bgr(self, frame_bgr: np.ndarray) -> np.ndarray: |
|
|
""" |
|
|
Fast path for OpenCV frames (BGR uint8). |
|
|
Pads to square, resizes to model_target_size, converts to float32. |
|
|
|
|
|
Returns: (H, W, 3) float32 array in BGR format (no batch dim). |
|
|
""" |
|
|
target_size = self.model_target_size |
|
|
|
|
|
h, w, _ = frame_bgr.shape |
|
|
max_dim = max(h, w) |
|
|
|
|
|
|
|
|
pad_vert = max_dim - h |
|
|
pad_horiz = max_dim - w |
|
|
top = pad_vert // 2 |
|
|
bottom = pad_vert - top |
|
|
left = pad_horiz // 2 |
|
|
right = pad_horiz - left |
|
|
|
|
|
|
|
|
frame_square = cv2.copyMakeBorder( |
|
|
frame_bgr, |
|
|
top, bottom, left, right, |
|
|
borderType=cv2.BORDER_CONSTANT, |
|
|
value=(255, 255, 255), |
|
|
) |
|
|
|
|
|
|
|
|
if max_dim != target_size: |
|
|
frame_square = cv2.resize( |
|
|
frame_square, |
|
|
(target_size, target_size), |
|
|
interpolation=cv2.INTER_AREA, |
|
|
) |
|
|
|
|
|
|
|
|
arr = frame_square.astype(np.float32) |
|
|
return arr |
|
|
|
|
|
def _run_batch_and_aggregate( |
|
|
self, |
|
|
batch_tensors: List[np.ndarray], |
|
|
general_thresh: float, |
|
|
character_thresh: float, |
|
|
aggregated_general: Dict[str, float], |
|
|
aggregated_character: Dict[str, float], |
|
|
) -> int: |
|
|
""" |
|
|
Run ONNX inference on a batch of preprocessed frames and |
|
|
update aggregated_general / aggregated_character with max scores. |
|
|
|
|
|
Returns: number of frames processed in this batch. |
|
|
""" |
|
|
if not batch_tensors: |
|
|
return 0 |
|
|
|
|
|
input_name = self.model.get_inputs()[0].name |
|
|
output_name = self.model.get_outputs()[0].name |
|
|
|
|
|
|
|
|
input_tensor = np.stack(batch_tensors, axis=0) |
|
|
|
|
|
preds_batch = self.model.run([output_name], {input_name: input_tensor})[0] |
|
|
|
|
|
|
|
|
for preds in preds_batch: |
|
|
general_res, character_res = self._extract_tags_from_scores( |
|
|
preds, |
|
|
general_thresh=general_thresh, |
|
|
character_thresh=character_thresh, |
|
|
) |
|
|
|
|
|
|
|
|
for tag, score in general_res.items(): |
|
|
if tag not in aggregated_general or score > aggregated_general[tag]: |
|
|
aggregated_general[tag] = score |
|
|
|
|
|
for tag, score in character_res.items(): |
|
|
if tag not in aggregated_character or score > aggregated_character[tag]: |
|
|
aggregated_character[tag] = score |
|
|
|
|
|
return len(batch_tensors) |
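
    # Aggregation note: a tag's final score is the max over all sampled frames;
    # e.g. per-frame scores {"sky": 0.4} and {"sky": 0.7} aggregate to 0.7.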
|
|
|
|
|
    def tag_image(
        self,
        image: Image.Image,
        general_thresh: float,
        character_thresh: float,
    ) -> Tuple[Dict[str, float], Dict[str, float]]:
        """
        Tag a single frame (PIL image).
        Returns:
            general_res: {tag -> score}
            character_res: {tag -> score}
        """
        self._load_model_if_needed()

        input_tensor = self._prepare_image(image)
        input_name = self.model.get_inputs()[0].name
        output_name = self.model.get_outputs()[0].name

        preds = self.model.run([output_name], {input_name: input_tensor})[0]

        # Delegate thresholding to the shared helper instead of duplicating it.
        return self._extract_tags_from_scores(
            preds[0],
            general_thresh=general_thresh,
            character_thresh=character_thresh,
        )
|
|
|
|
|
def _extract_tags_from_scores( |
|
|
self, |
|
|
preds: np.ndarray, |
|
|
general_thresh: float, |
|
|
character_thresh: float, |
|
|
) -> Tuple[Dict[str, float], Dict[str, float]]: |
|
|
""" |
|
|
Given a 1D preds array (num_tags,), return dicts of general/character tags. |
|
|
More efficient than rebuilding label tuples every time. |
|
|
""" |
|
|
|
|
|
preds = preds.astype(float) |
|
|
|
|
|
general_res: Dict[str, float] = {} |
|
|
character_res: Dict[str, float] = {} |
|
|
|
|
|
|
|
|
general_scores = preds[self.general_indexes] |
|
|
general_idx_array = np.array(self.general_indexes) |
|
|
general_mask = general_scores > general_thresh |
|
|
for idx, score in zip(general_idx_array[general_mask], general_scores[general_mask]): |
|
|
tag = self.tag_names[idx] |
|
|
general_res[tag] = float(score) |
|
|
|
|
|
|
|
|
character_scores = preds[self.character_indexes] |
|
|
character_idx_array = np.array(self.character_indexes) |
|
|
character_mask = character_scores > character_thresh |
|
|
for idx, score in zip(character_idx_array[character_mask], character_scores[character_mask]): |
|
|
tag = self.tag_names[idx] |
|
|
character_res[tag] = float(score) |
|
|
|
|
|
return general_res, character_res |
|
|
|
|
|
def tag_video( |
|
|
self, |
|
|
video_path: str, |
|
|
frame_interval: int, |
|
|
general_thresh: float, |
|
|
character_thresh: float, |
|
|
tag_substitutes: Dict[str, str], |
|
|
tag_exclusions: Set[str], |
|
|
progress=None, |
|
|
) -> Tuple[str, Dict]: |
|
|
""" |
|
|
Tag a video by sampling every N-th frame and aggregating tags. |
|
|
""" |
|
|
|
|
|
if not video_path or not os.path.exists(video_path): |
|
|
raise FileNotFoundError("Video file not found.") |
|
|
|
|
|
frame_interval = max(int(frame_interval), 1) |
|
|
is_first_load = self.model is None |
|
|
|
|
|
if progress is not None: |
|
|
progress(0.0, desc="Loading model..." if is_first_load else "Opening video...") |
|
|
|
|
|
|
|
|
self._load_model_if_needed() |
|
|
|
|
|
if progress is not None and is_first_load: |
|
|
progress(0.0, desc="Model loaded. Opening video...") |
|
|
|
|
|
cap = cv2.VideoCapture(video_path) |
|
|
if not cap.isOpened(): |
|
|
raise RuntimeError("Unable to open video file.") |
|
|
|
|
|
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0 |
|
|
if total_frames <= 0: |
|
|
total_frames = 1 |
|
|
|
|
|
|
|
|
sampled_frames = max(1, (total_frames + frame_interval - 1) // frame_interval) |
|
|
total_batches = max(1, (sampled_frames + self.batch_size - 1) // self.batch_size) |
|
|
recommended_batch = compute_recommended_batch_size(sampled_frames) |
|
|
|
|
|
aggregated_general: Dict[str, float] = {} |
|
|
aggregated_character: Dict[str, float] = {} |
|
|
|
|
|
frame_idx = 0 |
|
|
processed_frames = 0 |
|
|
batch_tensors: List[np.ndarray] = [] |
|
|
current_batch = 1 |
|
|
|
|
|
try: |
|
|
while True: |
|
|
ret, frame = cap.read() |
|
|
if not ret: |
|
|
break |
|
|
|
|
|
if frame_idx % frame_interval == 0: |
|
|
|
|
|
batch_tensors.append(self._prepare_frame_bgr(frame)) |
|
|
|
|
|
|
|
|
remaining_frames = sampled_frames - processed_frames |
|
|
current_batch_size = min(self.batch_size, remaining_frames) |
|
|
|
|
|
|
|
|
if progress is not None: |
|
|
pct = processed_frames / sampled_frames |
|
|
progress( |
|
|
pct, |
|
|
desc=( |
|
|
f"Preparing batch {current_batch}/{total_batches} " |
|
|
f"({len(batch_tensors)}/{current_batch_size} frames)" |
|
|
), |
|
|
) |
|
|
|
|
|
|
|
|
if len(batch_tensors) >= self.batch_size: |
|
|
if progress is not None: |
|
|
beg = processed_frames + 1 |
|
|
end = processed_frames + len(batch_tensors) |
|
|
pct = processed_frames / sampled_frames |
|
|
progress( |
|
|
pct, |
|
|
desc=( |
|
|
f"Processing batch {current_batch}/{total_batches} " |
|
|
f"(frames {beg}-{end}/{sampled_frames})" |
|
|
), |
|
|
) |
|
|
|
|
|
done = self._run_batch_and_aggregate( |
|
|
batch_tensors, |
|
|
general_thresh, |
|
|
character_thresh, |
|
|
aggregated_general, |
|
|
aggregated_character, |
|
|
) |
|
|
|
|
|
processed_frames += done |
|
|
batch_tensors = [] |
|
|
if current_batch < total_batches: |
|
|
current_batch += 1 |
|
|
|
|
|
if progress is not None: |
|
|
pct = processed_frames / sampled_frames |
|
|
progress( |
|
|
pct, |
|
|
desc=( |
|
|
f"Completed batch {current_batch - 1}/{total_batches} " |
|
|
f"({processed_frames}/{sampled_frames} frames processed)" |
|
|
), |
|
|
) |
|
|
|
|
|
frame_idx += 1 |
|
|
|
|
|
finally: |
|
|
cap.release() |
|
|
|
|
|
|
|
|
if batch_tensors: |
|
|
if progress is not None: |
|
|
beg = processed_frames + 1 |
|
|
end = processed_frames + len(batch_tensors) |
|
|
pct = processed_frames / sampled_frames |
|
|
progress( |
|
|
pct, |
|
|
desc=( |
|
|
f"Processing final batch {current_batch}/{total_batches} " |
|
|
f"(frames {beg}-{end}/{sampled_frames})" |
|
|
), |
|
|
) |
|
|
|
|
|
done = self._run_batch_and_aggregate( |
|
|
batch_tensors, |
|
|
general_thresh, |
|
|
character_thresh, |
|
|
aggregated_general, |
|
|
aggregated_character, |
|
|
) |
|
|
processed_frames += done |
|
|
|
|
|
if progress is not None: |
|
|
pct = processed_frames / sampled_frames |
|
|
progress( |
|
|
pct, |
|
|
desc=( |
|
|
f"Completed batch {current_batch}/{total_batches} " |
|
|
f"({processed_frames}/{sampled_frames} frames processed)" |
|
|
), |
|
|
) |
|
|
|
|
|
if progress is not None: |
|
|
progress(1.0, desc="Finalizing tags...") |
|
|
|
|
|
|
|
|
all_tags_with_scores = {**aggregated_general, **aggregated_character} |
|
|
|
|
|
normalized_subs = {k.strip(): v.strip() for k, v in tag_substitutes.items() if k and v} |
|
|
normalized_exclusions = {t.strip() for t in tag_exclusions if t} |
|
|
|
|
|
adjusted_all_tags: Dict[str, float] = {} |
|
|
for tag, score in all_tags_with_scores.items(): |
|
|
original_tag = tag.strip() |
|
|
|
|
|
if original_tag in normalized_exclusions: |
|
|
continue |
|
|
|
|
|
new_tag = normalized_subs.get(original_tag, original_tag) |
|
|
|
|
|
if new_tag in normalized_exclusions: |
|
|
continue |
|
|
|
|
|
if new_tag not in adjusted_all_tags or score > adjusted_all_tags[new_tag]: |
|
|
adjusted_all_tags[new_tag] = score |
|
|
|
|
|
sorted_tags = sorted( |
|
|
adjusted_all_tags.items(), |
|
|
key=lambda kv: kv[1], |
|
|
reverse=True, |
|
|
) |
|
|
unique_tags = [tag for tag, _ in sorted_tags] |
|
|
|
|
|
combined_tags_str = ", ".join(unique_tags) |
|
|
|
|
|
debug_info = { |
|
|
"model_repo": self.model_repo, |
|
|
"frames_read": int(frame_idx), |
|
|
"frames_processed": int(processed_frames), |
|
|
"sampled_frames": int(sampled_frames), |
|
|
"total_batches": int(total_batches), |
|
|
"batch_size": int(self.batch_size), |
|
|
"recommended_batch_size": int(recommended_batch), |
|
|
"frame_interval": int(frame_interval), |
|
|
"general_threshold": float(general_thresh), |
|
|
"character_threshold": float(character_thresh), |
|
|
"num_general_tags_raw": len(aggregated_general), |
|
|
"num_character_tags_raw": len(aggregated_character), |
|
|
"total_unique_tags_after_control": len(unique_tags), |
|
|
"num_substitution_rules": len(normalized_subs), |
|
|
"num_exclusions": len(normalized_exclusions), |
|
|
} |
|
|
|
|
|
return combined_tags_str, debug_info |
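

# Minimal programmatic usage sketch (hypothetical local file "clip.mp4"):
#
#   tagger = VideoTagger(DEFAULT_MODEL_REPO, batch_size=8)
#   tags, info = tagger.tag_video(
#       video_path="clip.mp4",
#       frame_interval=10,
#       general_thresh=0.35,
#       character_thresh=0.85,
#       tag_substitutes={},
#       tag_exclusions=set(),
#   )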
|
|
|
|
|
|
|
|
|
|
|
_tagger_cache: Dict[str, VideoTagger] = {} |
|
|
|
|
|
|
|
|
def get_tagger(model_repo: str, batch_size: int | None = None) -> VideoTagger: |
|
|
""" |
|
|
Lazily create and cache a VideoTagger per model repo. |
|
|
Optionally update batch_size on an existing instance. |
|
|
""" |
|
|
tagger = _tagger_cache.get(model_repo) |
|
|
if tagger is None: |
|
|
|
|
|
tagger = VideoTagger(model_repo=model_repo, batch_size=batch_size or 8) |
|
|
_tagger_cache[model_repo] = tagger |
|
|
else: |
|
|
|
|
|
if batch_size is not None: |
|
|
tagger.batch_size = int(batch_size) |
|
|
|
|
|
return tagger |
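
# e.g.:
#   t1 = get_tagger(DEFAULT_MODEL_REPO, batch_size=16)
#   t2 = get_tagger(DEFAULT_MODEL_REPO, batch_size=32)
#   assert t1 is t2 and t2.batch_size == 32  # cached instance, batch size updated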
|
|
|
|
|
|
|
|
def _normalize_tag_substitutes(data) -> Dict[str, str]: |
|
|
""" |
|
|
Convert Dataframe (as array: list[list]) into {original: substitute}. |
|
|
""" |
|
|
mapping: Dict[str, str] = {} |
|
|
if data is None: |
|
|
return mapping |
|
|
|
|
|
|
|
|
for row in data: |
|
|
if not row or len(row) < 2: |
|
|
continue |
|
|
orig = (row[0] or "").strip() |
|
|
sub = (row[1] or "").strip() |
|
|
if orig and sub: |
|
|
mapping[orig] = sub |
|
|
return mapping |
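
# e.g. _normalize_tag_substitutes([["1girl", "one girl"], ["", "x"], None])
# -> {"1girl": "one girl"}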
|
|
|
|
|
|
|
|
def _normalize_tag_exclusions(data) -> Set[str]: |
|
|
""" |
|
|
Convert Dataframe (as array: list[list]) into set of tags to exclude. |
|
|
""" |
|
|
exclusions: Set[str] = set() |
|
|
if data is None: |
|
|
return exclusions |
|
|
|
|
|
|
|
|
for row in data: |
|
|
if row is None: |
|
|
continue |
|
|
if isinstance(row, (list, tuple)): |
|
|
if not row: |
|
|
continue |
|
|
val = row[0] |
|
|
else: |
|
|
val = row |
|
|
val = (val or "").strip() |
|
|
if val: |
|
|
exclusions.add(val) |
|
|
return exclusions |
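
# e.g. _normalize_tag_exclusions([["watermark"], [""], None, "text"])
# -> {"watermark", "text"}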
|
|
|
|
|
|
|
|
def tag_video_interface( |
|
|
video_path: str, |
|
|
frame_interval: int, |
|
|
general_thresh: float, |
|
|
character_thresh: float, |
|
|
model_repo: str, |
|
|
tag_substitutes_df, |
|
|
tag_exclusions_df, |
|
|
batch_size: int, |
|
|
progress=gr.Progress(track_tqdm=False), |
|
|
): |
|
|
if video_path is None: |
|
|
return "", {"error": "Please upload a video file."} |
|
|
|
|
|
start_time = time.time() |
|
|
|
|
|
try: |
|
|
|
|
|
|
|
|
tagger = get_tagger(model_repo, batch_size=batch_size) |
|
|
|
|
|
tag_substitutes = _normalize_tag_substitutes(tag_substitutes_df) |
|
|
tag_exclusions = _normalize_tag_exclusions(tag_exclusions_df) |
|
|
|
|
|
combined_tags_str, debug_info = tagger.tag_video( |
|
|
video_path=video_path, |
|
|
frame_interval=frame_interval, |
|
|
general_thresh=general_thresh, |
|
|
character_thresh=character_thresh, |
|
|
tag_substitutes=tag_substitutes, |
|
|
tag_exclusions=tag_exclusions, |
|
|
progress=progress, |
|
|
) |
|
|
|
|
|
elapsed = time.time() - start_time |
|
|
debug_info["session_duration_seconds"] = round(elapsed, 3) |
|
|
debug_info["session_duration_hms"] = _format_duration(elapsed) |
|
|
|
|
|
return combined_tags_str, debug_info |
|
|
|
|
|
except Exception as e: |
|
|
return "", {"error": str(e)} |
|
|
|
|
|
|
|
|
custom_theme = gr.themes.Default(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.slate,
    radius_size=gr.themes.sizes.radius_xxl,
    font=[gr.themes.GoogleFont("Raleway")],
)

# theme and css are gr.Blocks() arguments, not launch() arguments.
with gr.Blocks(title=TITLE, theme=custom_theme, css=css) as demo:
|
|
|
|
|
gr.HTML( |
|
|
""" |
|
|
<style> |
|
|
.batch-loading { |
|
|
animation: batchPulse 1.2s ease-in-out infinite; |
|
|
color: #888888; |
|
|
} |
|
|
@keyframes batchPulse { |
|
|
0% { color: #666666; } |
|
|
50% { color: #bbbbbb; } |
|
|
100% { color: #666666; } |
|
|
} |
|
|
</style> |
|
|
""" |
|
|
) |
|
|
|
|
|
gr.Markdown(f"## {TITLE}") |
|
|
gr.Markdown(DESCRIPTION) |
|
|
|
|
|
with gr.Tabs(): |
|
|
|
|
|
with gr.Tab("Tagging", elem_id="tagging-tab"): |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
video_input = gr.Video( |
|
|
label="Video (.mp4 or .mov)", |
|
|
sources=["upload"], |
|
|
format="mp4", |
|
|
) |
|
|
|
|
|
model_choice = gr.Dropdown( |
|
|
choices=MODEL_OPTIONS, |
|
|
value=DEFAULT_MODEL_REPO, |
|
|
label="Tagging Model", |
|
|
) |
|
|
|
|
|
general_thresh = gr.Slider( |
|
|
minimum=0.0, |
|
|
maximum=1.0, |
|
|
step=0.01, |
|
|
value=0.35, |
|
|
label="General Tags Threshold", |
|
|
) |
|
|
|
|
|
character_thresh = gr.Slider( |
|
|
minimum=0.0, |
|
|
maximum=1.0, |
|
|
step=0.01, |
|
|
value=0.85, |
|
|
label="Character Tags Threshold", |
|
|
) |
|
|
|
|
|
gr.Markdown("### Processing") |
|
|
|
|
|
frame_interval = gr.Slider( |
|
|
minimum=1, |
|
|
maximum=60, |
|
|
step=1, |
|
|
value=10, |
|
|
label="Extract Every N Frames", |
|
|
info="For example, 10 = use every 10th frame.", |
|
|
) |
|
|
|
|
|
batch_size = gr.Slider( |
|
|
minimum=4, |
|
|
maximum=64, |
|
|
step=4, |
|
|
value=12, |
|
|
label="Batch Size", |
|
|
info=( |
|
|
"Larger batch sizes may increase initial loading time but can significantly " |
|
|
"improve total processing speed, especially for longer videos or high frame counts." |
|
|
), |
|
|
) |
|
|
|
|
|
batch_recommendation = gr.HTML( |
|
|
"<span>Upload a video to see a recommended batch size.</span>" |
|
|
) |
|
|
|
|
|
run_button = gr.Button("Generate Tags", variant="primary") |
|
|
|
|
|
with gr.Column(): |
|
|
combined_tags = gr.Textbox( |
|
|
label="Combined Unique Tags (All Frames)", |
|
|
lines=6, |
|
|
buttons=["copy"], |
|
|
) |
|
|
debug_info = gr.JSON( |
|
|
label="Details / Debug Info", |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
with gr.Tab("Tag Control", elem_id="tag-control-tab"): |
|
|
gr.Markdown("### Tag Substitutes") |
|
|
gr.Markdown( |
|
|
"Add rows where **Original Tag** will be replaced by **Substitute Tag** " |
|
|
"in the final combined output (after all frames are processed)." |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Column(): |
|
|
tag_substitutes_df = gr.Dataframe( |
|
|
headers=["Original Tag", "Substitute Tag"], |
|
|
datatype=["str", "str"], |
|
|
row_count=1, |
|
|
                    col_count=2,
|
|
type="array", |
|
|
label="Tag Substitutes", |
|
|
interactive=True, |
|
|
) |
|
|
add_sub_row_btn = gr.Button("β Add substitute") |
|
|
|
|
|
gr.Markdown("### Tag Exclusions") |
|
|
gr.Markdown( |
|
|
"Add tags that should be **removed entirely** from the final combined output." |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Column(): |
|
|
tag_exclusions_df = gr.Dataframe( |
|
|
headers=["Tag to Exclude"], |
|
|
datatype=["str"], |
|
|
row_count=1, |
|
|
                    col_count=1,
|
|
type="array", |
|
|
label="Tag Exclusions", |
|
|
interactive=True, |
|
|
) |
|
|
add_ex_row_btn = gr.Button("β Add exclusion") |
|
|
|
|
|
|
|
|
add_sub_row_btn.click( |
|
|
fn=add_substitute_row, |
|
|
inputs=tag_substitutes_df, |
|
|
outputs=tag_substitutes_df, |
|
|
) |
|
|
|
|
|
add_ex_row_btn.click( |
|
|
fn=add_exclusion_row, |
|
|
inputs=tag_exclusions_df, |
|
|
outputs=tag_exclusions_df, |
|
|
) |
|
|
|
|
|
|
|
|
video_input.change( |
|
|
fn=show_batch_loading, |
|
|
inputs=[], |
|
|
outputs=batch_recommendation, |
|
|
).then( |
|
|
fn=update_batch_recommendation, |
|
|
inputs=[video_input, frame_interval], |
|
|
outputs=batch_recommendation, |
|
|
) |
|
|
|
|
|
frame_interval.change( |
|
|
fn=show_batch_loading, |
|
|
inputs=[], |
|
|
outputs=batch_recommendation, |
|
|
).then( |
|
|
fn=update_batch_recommendation, |
|
|
inputs=[video_input, frame_interval], |
|
|
outputs=batch_recommendation, |
|
|
) |
|
|
|
|
|
run_button.click( |
|
|
fn=tag_video_interface, |
|
|
inputs=[ |
|
|
video_input, |
|
|
frame_interval, |
|
|
general_thresh, |
|
|
character_thresh, |
|
|
model_choice, |
|
|
tag_substitutes_df, |
|
|
tag_exclusions_df, |
|
|
batch_size, |
|
|
], |
|
|
outputs=[combined_tags, debug_info], |
|
|
) |
|
|
|
|
|
demo.queue(max_size=4).launch()
|
|
|