throaway2854 committed on
Commit
53617e7
·
verified ·
1 Parent(s): f8ca038

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -9
app.py CHANGED
@@ -9,17 +9,28 @@ import onnxruntime as rt
9
  import pandas as pd
10
  from PIL import Image
11
 
12
- TITLE = "Video Tagger (SmilingWolf/wd-eva02-large-tagger-v3)"
13
  DESCRIPTION = """
14
  Upload a .mp4 or .mov video, choose how often to sample frames, and generate
15
- combined (deduplicated) tags using **SmilingWolf/wd-eva02-large-tagger-v3**.
16
 
17
  - Extract every N-th frame (e.g., every 10th frame).
18
  - Control thresholds for **General Tags** and **Character Tags**.
19
  - All tags from all sampled frames are merged into **one unique, comma-separated string**.
20
  """
21
 
22
- MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
 
 
 
 
 
 
 
 
 
 
 
23
  MODEL_FILENAME = "model.onnx"
24
  LABEL_FILENAME = "selected_tags.csv"
25
 
@@ -74,11 +85,11 @@ def load_labels(df: pd.DataFrame):
74
 
75
  class VideoTagger:
76
  """
77
- Wraps wd-eva02-large-tagger-v3 ONNX model and tag metadata,
78
  and exposes helpers to tag PIL images and full videos.
79
  """
80
 
81
- def __init__(self, model_repo: str = MODEL_REPO):
82
  self.model_repo = model_repo
83
  self.model = None
84
  self.model_target_size = None # will be set from ONNX input shape
@@ -207,6 +218,7 @@ class VideoTagger:
207
  frame_interval: int,
208
  general_thresh: float,
209
  character_thresh: float,
 
210
  ) -> Tuple[str, Dict]:
211
  """
212
  Tag a video by sampling every N-th frame and aggregating tags.
@@ -220,10 +232,20 @@ class VideoTagger:
220
 
221
  frame_interval = max(int(frame_interval), 1)
222
 
 
 
 
223
  cap = cv2.VideoCapture(video_path)
224
  if not cap.isOpened():
225
  raise RuntimeError("Unable to open video file.")
226
 
 
 
 
 
 
 
 
227
  # Store max score seen for each tag across all frames
228
  aggregated_general: Dict[str, float] = {}
229
  aggregated_character: Dict[str, float] = {}
@@ -260,10 +282,20 @@ class VideoTagger:
260
 
261
  processed_frames += 1
262
 
 
 
 
 
 
 
 
263
  frame_idx += 1
264
  finally:
265
  cap.release()
266
 
 
 
 
267
  # Merge character + general tags, sorted by score (desc)
268
  all_tags_with_scores = {**aggregated_general, **aggregated_character}
269
  sorted_tags = sorted(
@@ -276,8 +308,11 @@ class VideoTagger:
276
  combined_tags_str = ", ".join(unique_tags)
277
 
278
  debug_info = {
 
279
  "frames_read": int(frame_idx),
280
  "frames_processed": int(processed_frames),
 
 
281
  "num_general_tags": len(aggregated_general),
282
  "num_character_tags": len(aggregated_character),
283
  "total_unique_tags": len(unique_tags),
@@ -289,8 +324,14 @@ class VideoTagger:
289
  return combined_tags_str, debug_info
290
 
291
 
292
- # Global model instance (loaded once per Space)
293
- video_tagger = VideoTagger()
 
 
 
 
 
 
294
 
295
 
296
  def tag_video_interface(
@@ -298,16 +339,20 @@ def tag_video_interface(
298
  frame_interval: int,
299
  general_thresh: float,
300
  character_thresh: float,
 
 
301
  ):
302
  if video_path is None:
303
  return "", {"error": "Please upload a video file."}
304
 
305
  try:
306
- return video_tagger.tag_video(
 
307
  video_path=video_path,
308
  frame_interval=frame_interval,
309
  general_thresh=general_thresh,
310
  character_thresh=character_thresh,
 
311
  )
312
  except Exception as e:
313
  return "", {"error": str(e)}
@@ -325,6 +370,12 @@ with gr.Blocks(title=TITLE) as demo:
325
  format="mp4",
326
  )
327
 
 
 
 
 
 
 
328
  frame_interval = gr.Slider(
329
  minimum=1,
330
  maximum=60,
@@ -363,7 +414,7 @@ with gr.Blocks(title=TITLE) as demo:
363
 
364
  run_button.click(
365
  fn=tag_video_interface,
366
- inputs=[video_input, frame_interval, general_thresh, character_thresh],
367
  outputs=[combined_tags, debug_info],
368
  )
369
 
 
9
  import pandas as pd
10
  from PIL import Image
11
 
12
+ TITLE = "Video Tagger (WD Tagger Variants)"
13
  DESCRIPTION = """
14
  Upload a .mp4 or .mov video, choose how often to sample frames, and generate
15
+ combined (deduplicated) tags using a selected **WD-style tagging model**.
16
 
17
  - Extract every N-th frame (e.g., every 10th frame).
18
  - Control thresholds for **General Tags** and **Character Tags**.
19
  - All tags from all sampled frames are merged into **one unique, comma-separated string**.
20
  """
21
 
22
+ DEFAULT_MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
23
+
24
+ MODEL_OPTIONS = [
25
+ "SmilingWolf/wd-eva02-large-tagger-v3",
26
+ "SmilingWolf/wd-vit-large-tagger-v3",
27
+ "SmilingWolf/wd-vit-tagger-v3",
28
+ "SmilingWolf/wd-convnext-tagger-v3",
29
+ "SmilingWolf/wd-swinv2-tagger-v3",
30
+ "deepghs/idolsankaku-eva02-large-tagger-v1",
31
+ "deepghs/idolsankaku-swinv2-tagger-v1",
32
+ ]
33
+
34
  MODEL_FILENAME = "model.onnx"
35
  LABEL_FILENAME = "selected_tags.csv"
36
 
 
85
 
86
  class VideoTagger:
87
  """
88
+ Wraps a WD-style ONNX model and tag metadata,
89
  and exposes helpers to tag PIL images and full videos.
90
  """
91
 
92
+ def __init__(self, model_repo: str):
93
  self.model_repo = model_repo
94
  self.model = None
95
  self.model_target_size = None # will be set from ONNX input shape
 
218
  frame_interval: int,
219
  general_thresh: float,
220
  character_thresh: float,
221
+ progress=None,
222
  ) -> Tuple[str, Dict]:
223
  """
224
  Tag a video by sampling every N-th frame and aggregating tags.
 
232
 
233
  frame_interval = max(int(frame_interval), 1)
234
 
235
+ if progress is not None:
236
+ progress(0.0, desc="Opening video...")
237
+
238
  cap = cv2.VideoCapture(video_path)
239
  if not cap.isOpened():
240
  raise RuntimeError("Unable to open video file.")
241
 
242
+ # Estimate total frames and how many will be processed
243
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
244
+ if total_frames <= 0:
245
+ total_frames = 1 # avoid division by zero / weird metadata
246
+
247
+ frames_to_process = max(1, (total_frames + frame_interval - 1) // frame_interval)
248
+
249
  # Store max score seen for each tag across all frames
250
  aggregated_general: Dict[str, float] = {}
251
  aggregated_character: Dict[str, float] = {}
 
282
 
283
  processed_frames += 1
284
 
285
+ if progress is not None:
286
+ ratio = min(processed_frames / frames_to_process, 0.99)
287
+ progress(
288
+ ratio,
289
+ desc=f"Processing frame {processed_frames}/{frames_to_process}...",
290
+ )
291
+
292
  frame_idx += 1
293
  finally:
294
  cap.release()
295
 
296
+ if progress is not None:
297
+ progress(1.0, desc="Finalizing tags...")
298
+
299
  # Merge character + general tags, sorted by score (desc)
300
  all_tags_with_scores = {**aggregated_general, **aggregated_character}
301
  sorted_tags = sorted(
 
308
  combined_tags_str = ", ".join(unique_tags)
309
 
310
  debug_info = {
311
+ "model_repo": self.model_repo,
312
  "frames_read": int(frame_idx),
313
  "frames_processed": int(processed_frames),
314
+ "estimated_total_frames": int(total_frames),
315
+ "estimated_frames_to_process": int(frames_to_process),
316
  "num_general_tags": len(aggregated_general),
317
  "num_character_tags": len(aggregated_character),
318
  "total_unique_tags": len(unique_tags),
 
324
  return combined_tags_str, debug_info
325
 
326
 
327
+ # Cache of VideoTagger instances per model repo
328
+ _tagger_cache: Dict[str, VideoTagger] = {}
329
+
330
+
331
+ def get_tagger(model_repo: str) -> VideoTagger:
332
+ if model_repo not in _tagger_cache:
333
+ _tagger_cache[model_repo] = VideoTagger(model_repo=model_repo)
334
+ return _tagger_cache[model_repo]
335
 
336
 
337
  def tag_video_interface(
 
339
  frame_interval: int,
340
  general_thresh: float,
341
  character_thresh: float,
342
+ model_repo: str,
343
+ progress=gr.Progress(track_tqdm=False),
344
  ):
345
  if video_path is None:
346
  return "", {"error": "Please upload a video file."}
347
 
348
  try:
349
+ tagger = get_tagger(model_repo)
350
+ return tagger.tag_video(
351
  video_path=video_path,
352
  frame_interval=frame_interval,
353
  general_thresh=general_thresh,
354
  character_thresh=character_thresh,
355
+ progress=progress,
356
  )
357
  except Exception as e:
358
  return "", {"error": str(e)}
 
370
  format="mp4",
371
  )
372
 
373
+ model_choice = gr.Dropdown(
374
+ choices=MODEL_OPTIONS,
375
+ value=DEFAULT_MODEL_REPO,
376
+ label="Tagging Model",
377
+ )
378
+
379
  frame_interval = gr.Slider(
380
  minimum=1,
381
  maximum=60,
 
414
 
415
  run_button.click(
416
  fn=tag_video_interface,
417
+ inputs=[video_input, frame_interval, general_thresh, character_thresh, model_choice],
418
  outputs=[combined_tags, debug_info],
419
  )
420