throaway2854 committed
Commit f1419d3 · verified · 1 parent: 82b4129

Update app.py

Files changed (1)
  app.py: +76 -91
app.py CHANGED
@@ -399,33 +399,22 @@ class VideoTagger:
     ) -> Tuple[str, Dict]:
         """
         Tag a video by sampling every N-th frame and aggregating tags.
-
-        Returns:
-            combined_tags_str: one unique comma-separated tag string
-            debug_info: dict with some stats
         """
+
         if not video_path or not os.path.exists(video_path):
             raise FileNotFoundError("Video file not found.")
 
         frame_interval = max(int(frame_interval), 1)
-
-        # Detect if this is the first time the model is being loaded
         is_first_load = self.model is None
 
-        if progress is not None:
-            if is_first_load:
-                progress(0.0, desc="Loading model (first run may take a while)...")
-            else:
-                progress(0.0, desc="Opening video...")
+        if progress:
+            progress(0.0, desc="Loading model..." if is_first_load else "Opening video...")
 
-        # Lazy-load model and labels once per process
+        # Lazy-load model & labels once per process
         self._load_model_if_needed()
 
-        if progress is not None:
-            if is_first_load:
-                progress(0.0, desc="Model loaded, opening video...")
-            else:
-                progress(0.0, desc="Opening video...")
+        if progress and is_first_load:
+            progress(0.0, desc="Model loaded. Opening video...")
 
         cap = cv2.VideoCapture(video_path)
         if not cap.isOpened():
@@ -435,15 +424,17 @@ class VideoTagger:
         if total_frames <= 0:
             total_frames = 1
 
-        # Number of frames we will actually process (sampled every N frames)
-        frames_to_process = max(1, (total_frames + frame_interval - 1) // frame_interval)
+        # How many frames we will actually process (sampled every N frames)
+        sampled_frames = max(1, (total_frames + frame_interval - 1) // frame_interval)
+        total_batches = max(1, (sampled_frames + self.batch_size - 1) // self.batch_size)
 
         aggregated_general: Dict[str, float] = {}
         aggregated_character: Dict[str, float] = {}
 
-        frame_idx = 0  # index over all video frames
-        processed_frames = 0  # count of sampled frames fully processed by the model
+        frame_idx = 0  # raw video frame index
+        processed_frames = 0  # sampled frames fully processed by the model
         batch_tensors: List[np.ndarray] = []
+        current_batch = 1
 
         try:
             while True:
@@ -452,115 +443,109 @@ class VideoTagger:
                     break
 
                 if frame_idx % frame_interval == 0:
-                    # This is a sampled frame
-                    sampled_index = processed_frames + len(batch_tensors) + 1  # 1-based index among sampled frames
-                    arr = self._prepare_frame_bgr(frame)  # (H, W, 3) float32
-                    batch_tensors.append(arr)
-
-                    if progress is not None:
-                        # Show which sampled frame we're preparing, and which raw video frame it is.
-                        ratio = min(
-                            (processed_frames + len(batch_tensors)) / frames_to_process,
-                            0.99,
-                        )
+                    # This is a sampled frame – add to current batch
+                    batch_tensors.append(self._prepare_frame_bgr(frame))
+
+                    # For the current batch, compute how many sampled frames it *should* contain
+                    remaining_frames = sampled_frames - processed_frames
+                    current_batch_size = min(self.batch_size, remaining_frames)
+
+                    # While we are still building the batch, keep percent based on *completed* frames only
+                    if progress:
+                        pct = processed_frames / sampled_frames
                         progress(
-                            ratio,
+                            pct,
                             desc=(
-                                f"Preparing sampled frame {sampled_index}/{frames_to_process} "
-                                f"(video frame {frame_idx + 1}/{total_frames})..."
+                                f"Preparing batch {current_batch}/{total_batches} "
+                                f"({len(batch_tensors)}/{current_batch_size} frames)..."
                             ),
                         )
 
-                # If batch is full, run inference on it
+                # If batch is full, run inference
                 if len(batch_tensors) >= self.batch_size:
-                    # Inform the user we're now running the model on this batch
-                    if progress is not None:
-                        start_sample = processed_frames + 1
-                        end_sample = processed_frames + len(batch_tensors)
-                        ratio = min(
-                            (processed_frames + len(batch_tensors)) / frames_to_process,
-                            0.99,
-                        )
+                    if progress:
+                        beg = processed_frames + 1
+                        end = processed_frames + len(batch_tensors)
+                        pct = processed_frames / sampled_frames  # still only count completed frames
                         progress(
-                            ratio,
+                            pct,
                             desc=(
-                                f"Running model on batch: sampled frames "
-                                f"{start_sample}-{end_sample}/{frames_to_process}..."
+                                f"Processing batch {current_batch}/{total_batches} "
+                                f"(frames {beg}-{end}/{sampled_frames})..."
                             ),
                         )
 
-                    num_done = self._run_batch_and_aggregate(
+                    done = self._run_batch_and_aggregate(
                         batch_tensors,
-                        general_thresh=general_thresh,
-                        character_thresh=character_thresh,
-                        aggregated_general=aggregated_general,
-                        aggregated_character=aggregated_character,
+                        general_thresh,
+                        character_thresh,
+                        aggregated_general,
+                        aggregated_character,
                     )
-                    processed_frames += num_done
+
+                    processed_frames += done
                     batch_tensors = []
+                    if current_batch < total_batches:
+                        current_batch += 1
 
-                if progress is not None:
-                    ratio = min(processed_frames / frames_to_process, 0.99)
+                if progress:
+                    pct = processed_frames / sampled_frames
                     progress(
-                        ratio,
+                        pct,
                         desc=(
-                            f"Finished processing sampled frames "
-                            f"{processed_frames}/{frames_to_process}..."
+                            f"Completed batch {current_batch - 1}/{total_batches} "
+                            f"({processed_frames}/{sampled_frames} frames processed)"
                        ),
                    )
 
                frame_idx += 1
+
        finally:
            cap.release()
 
-        # Process any leftover frames in the last partial batch
+        # Process any leftover frames in the final partial batch
        if batch_tensors:
-            if progress is not None:
-                start_sample = processed_frames + 1
-                end_sample = processed_frames + len(batch_tensors)
-                ratio = min(
-                    (processed_frames + len(batch_tensors)) / frames_to_process,
-                    0.99,
-                )
+            if progress:
+                beg = processed_frames + 1
+                end = processed_frames + len(batch_tensors)
+                pct = processed_frames / sampled_frames  # still only completed frames
                progress(
-                    ratio,
+                    pct,
                    desc=(
-                        f"Running model on final batch: sampled frames "
-                        f"{start_sample}-{end_sample}/{frames_to_process}..."
+                        f"Processing final batch {current_batch}/{total_batches} "
+                        f"(frames {beg}-{end}/{sampled_frames})..."
                    ),
                )
 
-            num_done = self._run_batch_and_aggregate(
+            done = self._run_batch_and_aggregate(
                batch_tensors,
-                general_thresh=general_thresh,
-                character_thresh=character_thresh,
-                aggregated_general=aggregated_general,
-                aggregated_character=aggregated_character,
+                general_thresh,
+                character_thresh,
+                aggregated_general,
+                aggregated_character,
            )
-            processed_frames += num_done
+            processed_frames += done
 
-        if progress is not None:
-            ratio = min(processed_frames / frames_to_process, 0.99)
+        if progress:
+            pct = processed_frames / sampled_frames
            progress(
-                ratio,
+                pct,
                desc=(
-                    f"Finished processing all sampled frames "
-                    f"{processed_frames}/{frames_to_process}..."
+                    f"Completed batch {current_batch}/{total_batches} "
+                    f"({processed_frames}/{sampled_frames} frames processed)"
                ),
            )
 
-        if progress is not None:
+        if progress:
            progress(1.0, desc="Finalizing tags...")
 
-        # Merge character + general tags, sorted by score (desc)
+        # Merge & finalize tags
        all_tags_with_scores = {**aggregated_general, **aggregated_character}
 
-        # Apply substitutions & exclusions BEFORE final dedup
-        adjusted_all_tags: Dict[str, float] = {}
-
        normalized_subs = {k.strip(): v.strip() for k, v in tag_substitutes.items() if k and v}
        normalized_exclusions = {t.strip() for t in tag_exclusions if t}
 
+        adjusted_all_tags: Dict[str, float] = {}
        for tag, score in all_tags_with_scores.items():
            original_tag = tag.strip()
 
@@ -588,17 +573,17 @@ class VideoTagger:
            "model_repo": self.model_repo,
            "frames_read": int(frame_idx),
            "frames_processed": int(processed_frames),
-            "estimated_total_frames": int(total_frames),
-            "estimated_frames_to_process": int(frames_to_process),
-            "num_general_tags_raw": len(aggregated_general),
-            "num_character_tags_raw": len(aggregated_character),
-            "total_unique_tags_after_control": len(unique_tags),
+            "sampled_frames": int(sampled_frames),
+            "total_batches": int(total_batches),
+            "batch_size": int(self.batch_size),
            "frame_interval": int(frame_interval),
            "general_threshold": float(general_thresh),
            "character_threshold": float(character_thresh),
+            "num_general_tags_raw": len(aggregated_general),
+            "num_character_tags_raw": len(aggregated_character),
+            "total_unique_tags_after_control": len(unique_tags),
            "num_substitution_rules": len(normalized_subs),
            "num_exclusions": len(normalized_exclusions),
-            "batch_size": int(self.batch_size),
        }
 
        return combined_tags_str, debug_info
 
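For context on the `progress` handling in this commit: the callback is invoked as `progress(fraction, desc=...)`, which matches Gradio's `gr.Progress` convention, and the new truthiness guard (`if progress:`) lets callers omit it entirely. A minimal sketch of that contract; the `console_progress` and `run_steps` names below are hypothetical, not part of app.py:

from typing import Callable, Optional

# Hypothetical stand-in for a Gradio-style progress object; anything
# callable as progress(fraction, desc=...) satisfies the same contract.
def console_progress(fraction: float, desc: str = "") -> None:
    print(f"[{fraction:6.1%}] {desc}")

def run_steps(n_steps: int, progress: Optional[Callable] = None) -> None:
    # The truthiness guard means the callback is strictly optional.
    for i in range(n_steps):
        if progress:
            progress(i / n_steps, desc=f"Step {i + 1}/{n_steps}...")
    if progress:
        progress(1.0, desc="Done.")

run_steps(4, progress=console_progress)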
 
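The new `sampled_frames` and `total_batches` lines both use the integer ceiling-division idiom `(n + d - 1) // d`, which equals `ceil(n / d)` without any floating-point math. A worked example with arbitrary values (none of these numbers come from the commit):

# Arbitrary example values, not taken from the commit.
total_frames = 300
frame_interval = 7
batch_size = 8

# (n + d - 1) // d is ceil(n / d) in pure integer arithmetic.
sampled_frames = max(1, (total_frames + frame_interval - 1) // frame_interval)
total_batches = max(1, (sampled_frames + batch_size - 1) // batch_size)

print(sampled_frames)  # 43 -> frames 0, 7, 14, ..., 294
print(total_batches)   # 6  -> five full batches of 8, one final batch of 3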
 
 
 
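The hunk cuts off just after `original_tag = tag.strip()`, so the exact substitution and exclusion semantics are not visible here. Only the two normalization comprehensions below are taken verbatim from the commit; the rest of this sketch is an assumption about how such a filtering pass could work:

from typing import Dict, List

def apply_tag_controls(
    tags_with_scores: Dict[str, float],
    tag_substitutes: Dict[str, str],
    tag_exclusions: List[str],
) -> Dict[str, float]:
    # These two normalizations mirror the commit; everything after is assumed.
    subs = {k.strip(): v.strip() for k, v in tag_substitutes.items() if k and v}
    excl = {t.strip() for t in tag_exclusions if t}

    adjusted: Dict[str, float] = {}
    for tag, score in tags_with_scores.items():
        tag = tag.strip()
        if tag in excl:
            continue                      # drop excluded tags outright
        tag = subs.get(tag, tag)          # apply a substitution rule if present
        # If a substitution collapses two tags into one, keep the higher score.
        adjusted[tag] = max(score, adjusted.get(tag, 0.0))
    return adjusted

print(apply_tag_controls(
    {"1girl": 0.98, "solo": 0.92, "watermark": 0.41},
    tag_substitutes={"1girl": "girl"},
    tag_exclusions=["watermark"],
))
# -> {'girl': 0.98, 'solo': 0.92}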