throaway2854 commited on
Commit
2e12519
·
verified ·
1 Parent(s): 2ffce11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -31
app.py CHANGED
@@ -399,51 +399,76 @@ class VideoTagger:
399
  ) -> Tuple[str, Dict]:
400
  """
401
  Tag a video by sampling every N-th frame and aggregating tags.
402
-
403
  Returns:
404
  combined_tags_str: one unique comma-separated tag string
405
  debug_info: dict with some stats
406
  """
407
  if not video_path or not os.path.exists(video_path):
408
  raise FileNotFoundError("Video file not found.")
409
-
410
  frame_interval = max(int(frame_interval), 1)
411
-
 
 
 
 
 
 
 
 
 
 
412
  self._load_model_if_needed()
413
-
414
  if progress is not None:
415
- progress(0.0, desc="Opening video...")
416
-
 
 
 
 
 
417
  cap = cv2.VideoCapture(video_path)
418
  if not cap.isOpened():
419
  raise RuntimeError("Unable to open video file.")
420
-
421
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
422
  if total_frames <= 0:
423
  total_frames = 1
424
-
425
  frames_to_process = max(1, (total_frames + frame_interval - 1) // frame_interval)
426
-
427
  aggregated_general: Dict[str, float] = {}
428
  aggregated_character: Dict[str, float] = {}
429
-
430
  frame_idx = 0
431
  processed_frames = 0
432
-
433
  batch_tensors: List[np.ndarray] = []
434
-
435
  try:
436
  while True:
437
  ret, frame = cap.read()
438
  if not ret:
439
  break
440
-
441
  # Only process every N-th frame
442
  if frame_idx % frame_interval == 0:
443
  # frame is BGR uint8 from OpenCV
444
  arr = self._prepare_frame_bgr(frame) # (H, W, 3) float32
445
  batch_tensors.append(arr)
446
-
 
 
 
 
 
 
 
 
 
 
 
447
  # If batch is full, run inference
448
  if len(batch_tensors) >= self.batch_size:
449
  num_done = self._run_batch_and_aggregate(
@@ -455,18 +480,21 @@ class VideoTagger:
455
  )
456
  processed_frames += num_done
457
  batch_tensors = []
458
-
459
  if progress is not None:
460
  ratio = min(processed_frames / frames_to_process, 0.99)
461
  progress(
462
  ratio,
463
- desc=f"Processing frames {processed_frames}/{frames_to_process}...",
 
 
 
464
  )
465
-
466
  frame_idx += 1
467
  finally:
468
  cap.release()
469
-
470
  # Process any leftover frames in the last partial batch
471
  if batch_tensors:
472
  num_done = self._run_batch_and_aggregate(
@@ -477,42 +505,52 @@ class VideoTagger:
477
  aggregated_character=aggregated_character,
478
  )
479
  processed_frames += num_done
480
-
 
 
 
 
 
 
 
 
 
 
481
  if progress is not None:
482
  progress(1.0, desc="Finalizing tags...")
483
-
484
  # Merge character + general tags, sorted by score (desc)
485
  all_tags_with_scores = {**aggregated_general, **aggregated_character}
486
-
487
  # Apply substitutions & exclusions BEFORE final dedup
488
  adjusted_all_tags: Dict[str, float] = {}
489
-
490
  normalized_subs = {k.strip(): v.strip() for k, v in tag_substitutes.items() if k and v}
491
  normalized_exclusions = {t.strip() for t in tag_exclusions if t}
492
-
493
  for tag, score in all_tags_with_scores.items():
494
  original_tag = tag.strip()
495
-
496
  if original_tag in normalized_exclusions:
497
  continue
498
-
499
  new_tag = normalized_subs.get(original_tag, original_tag)
500
-
501
  if new_tag in normalized_exclusions:
502
  continue
503
-
504
  if new_tag not in adjusted_all_tags or score > adjusted_all_tags[new_tag]:
505
  adjusted_all_tags[new_tag] = score
506
-
507
  sorted_tags = sorted(
508
  adjusted_all_tags.items(),
509
  key=lambda kv: kv[1],
510
  reverse=True,
511
  )
512
  unique_tags = [tag for tag, _ in sorted_tags]
513
-
514
  combined_tags_str = ", ".join(unique_tags)
515
-
516
  debug_info = {
517
  "model_repo": self.model_repo,
518
  "frames_read": int(frame_idx),
@@ -529,7 +567,7 @@ class VideoTagger:
529
  "num_exclusions": len(normalized_exclusions),
530
  "batch_size": int(self.batch_size),
531
  }
532
-
533
  return combined_tags_str, debug_info
534
 
535
 
 
399
  ) -> Tuple[str, Dict]:
400
  """
401
  Tag a video by sampling every N-th frame and aggregating tags.
402
+
403
  Returns:
404
  combined_tags_str: one unique comma-separated tag string
405
  debug_info: dict with some stats
406
  """
407
  if not video_path or not os.path.exists(video_path):
408
  raise FileNotFoundError("Video file not found.")
409
+
410
  frame_interval = max(int(frame_interval), 1)
411
+
412
+ # Detect if this is the first time the model is being loaded
413
+ is_first_load = self.model is None
414
+
415
+ if progress is not None:
416
+ if is_first_load:
417
+ progress(0.0, desc="Loading model (first run may take a while)...")
418
+ else:
419
+ progress(0.0, desc="Opening video...")
420
+
421
+ # Lazy-load model and labels once per process
422
  self._load_model_if_needed()
423
+
424
  if progress is not None:
425
+ if is_first_load:
426
+ # Model just finished loading
427
+ progress(0.0, desc="Model loaded, opening video...")
428
+ else:
429
+ # Keep the message but make clear we're past model loading
430
+ progress(0.0, desc="Opening video...")
431
+
432
  cap = cv2.VideoCapture(video_path)
433
  if not cap.isOpened():
434
  raise RuntimeError("Unable to open video file.")
435
+
436
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
437
  if total_frames <= 0:
438
  total_frames = 1
439
+
440
  frames_to_process = max(1, (total_frames + frame_interval - 1) // frame_interval)
441
+
442
  aggregated_general: Dict[str, float] = {}
443
  aggregated_character: Dict[str, float] = {}
444
+
445
  frame_idx = 0
446
  processed_frames = 0
 
447
  batch_tensors: List[np.ndarray] = []
448
+
449
  try:
450
  while True:
451
  ret, frame = cap.read()
452
  if not ret:
453
  break
454
+
455
  # Only process every N-th frame
456
  if frame_idx % frame_interval == 0:
457
  # frame is BGR uint8 from OpenCV
458
  arr = self._prepare_frame_bgr(frame) # (H, W, 3) float32
459
  batch_tensors.append(arr)
460
+
461
+ # While building the FIRST batch, keep user informed
462
+ if progress is not None and processed_frames == 0:
463
+ frames_in_first_batch = min(self.batch_size, frames_to_process)
464
+ progress(
465
+ 0.0,
466
+ desc=(
467
+ f"Collecting frames for first batch "
468
+ f"({len(batch_tensors)}/{frames_in_first_batch})..."
469
+ ),
470
+ )
471
+
472
  # If batch is full, run inference
473
  if len(batch_tensors) >= self.batch_size:
474
  num_done = self._run_batch_and_aggregate(
 
480
  )
481
  processed_frames += num_done
482
  batch_tensors = []
483
+
484
  if progress is not None:
485
  ratio = min(processed_frames / frames_to_process, 0.99)
486
  progress(
487
  ratio,
488
+ desc=(
489
+ f"Processing frames {processed_frames}/"
490
+ f"{frames_to_process}..."
491
+ ),
492
  )
493
+
494
  frame_idx += 1
495
  finally:
496
  cap.release()
497
+
498
  # Process any leftover frames in the last partial batch
499
  if batch_tensors:
500
  num_done = self._run_batch_and_aggregate(
 
505
  aggregated_character=aggregated_character,
506
  )
507
  processed_frames += num_done
508
+
509
+ if progress is not None:
510
+ ratio = min(processed_frames / frames_to_process, 0.99)
511
+ progress(
512
+ ratio,
513
+ desc=(
514
+ f"Processing frames {processed_frames}/"
515
+ f"{frames_to_process} (final batch)..."
516
+ ),
517
+ )
518
+
519
  if progress is not None:
520
  progress(1.0, desc="Finalizing tags...")
521
+
522
  # Merge character + general tags, sorted by score (desc)
523
  all_tags_with_scores = {**aggregated_general, **aggregated_character}
524
+
525
  # Apply substitutions & exclusions BEFORE final dedup
526
  adjusted_all_tags: Dict[str, float] = {}
527
+
528
  normalized_subs = {k.strip(): v.strip() for k, v in tag_substitutes.items() if k and v}
529
  normalized_exclusions = {t.strip() for t in tag_exclusions if t}
530
+
531
  for tag, score in all_tags_with_scores.items():
532
  original_tag = tag.strip()
533
+
534
  if original_tag in normalized_exclusions:
535
  continue
536
+
537
  new_tag = normalized_subs.get(original_tag, original_tag)
538
+
539
  if new_tag in normalized_exclusions:
540
  continue
541
+
542
  if new_tag not in adjusted_all_tags or score > adjusted_all_tags[new_tag]:
543
  adjusted_all_tags[new_tag] = score
544
+
545
  sorted_tags = sorted(
546
  adjusted_all_tags.items(),
547
  key=lambda kv: kv[1],
548
  reverse=True,
549
  )
550
  unique_tags = [tag for tag, _ in sorted_tags]
551
+
552
  combined_tags_str = ", ".join(unique_tags)
553
+
554
  debug_info = {
555
  "model_repo": self.model_repo,
556
  "frames_read": int(frame_idx),
 
567
  "num_exclusions": len(normalized_exclusions),
568
  "batch_size": int(self.batch_size),
569
  }
570
+
571
  return combined_tags_str, debug_info
572
 
573