Chaitanya-aitf committed on
Commit
374f92b
·
verified ·
1 Parent(s): c71de09

Update pipeline/orchestrator.py

Browse files
Files changed (1) hide show
  1. pipeline/orchestrator.py +132 -5
pipeline/orchestrator.py CHANGED
@@ -14,7 +14,7 @@ Manages the flow between all components:
14
  """
15
 
16
  from pathlib import Path
17
- from typing import List, Optional, Callable, Dict, Any, Generator
18
  from dataclasses import dataclass, field
19
  from enum import Enum
20
  import time
@@ -39,7 +39,7 @@ from models.face_recognizer import FaceRecognizer
39
  from models.body_recognizer import BodyRecognizer
40
  from models.motion_detector import MotionDetector
41
  from scoring.hype_scorer import HypeScorer, SegmentScore
42
- from scoring.domain_presets import get_domain_preset, Domain
43
  from scoring.viral_hooks import ViralHookDetector, HookSignal
44
 
45
  logger = get_logger("pipeline.orchestrator")
@@ -294,6 +294,7 @@ class PipelineOrchestrator:
294
 
295
  # Visual analysis (if enabled)
296
  visual_features = []
 
297
  if self._visual_analyzer is not None:
298
  self._update_progress(PipelineStage.ANALYZING_VISUAL, 0.0, "Analyzing visual content...")
299
  try:
@@ -302,11 +303,26 @@ class PipelineOrchestrator:
302
  frame.frame_path, timestamp=frame.timestamp
303
  )
304
  visual_features.append(features)
 
 
 
 
 
 
 
 
 
 
305
  self._update_progress(
306
  PipelineStage.ANALYZING_VISUAL,
307
  (i + 1) / len(frames),
308
  f"Analyzing frame {i+1}/{len(frames)}"
309
  )
 
 
 
 
 
310
  except Exception as e:
311
  logger.warning(f"Visual analysis failed, continuing without: {e}")
312
  self._update_progress(PipelineStage.ANALYZING_VISUAL, 1.0, "Visual analysis complete")
@@ -346,9 +362,12 @@ class PipelineOrchestrator:
346
  logger.warning(f"Person detection failed: {e}")
347
  self._update_progress(PipelineStage.DETECTING_PERSON, 1.0, "Person detection complete")
348
 
349
- # Motion analysis (simplified)
350
  self._update_progress(PipelineStage.ANALYZING_MOTION, 0.0, "Analyzing motion...")
351
- motion_scores = self._estimate_motion_from_visual(visual_features)
 
 
 
352
  self._update_progress(PipelineStage.ANALYZING_MOTION, 1.0, "Motion analysis complete")
353
 
354
  # Scoring
@@ -466,6 +485,16 @@ class PipelineOrchestrator:
466
  logger.warning(f"Visual analyzer not available: {e}")
467
  self._visual_analyzer = None
468
 
 
 
 
 
 
 
 
 
 
 
469
  # Person recognition (only if needed)
470
  if person_filter:
471
  try:
@@ -586,11 +615,62 @@ class PipelineOrchestrator:
586
  for s in scores
587
  ]
588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
589
  def _estimate_motion_from_visual(
590
  self,
591
  visual_features: List[VisualFeatures],
592
  ) -> List[float]:
593
- """Estimate motion scores from visual analysis."""
594
  if not visual_features:
595
  return []
596
 
@@ -608,6 +688,53 @@ class PipelineOrchestrator:
608
 
609
  return [motion_map.get(f.action_detected, 0.4) for f in visual_features]
610
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
611
  def _detect_viral_hooks(
612
  self,
613
  frames: List[SampledFrame],
 
14
  """
15
 
16
  from pathlib import Path
17
+ from typing import List, Optional, Callable, Dict, Any
18
  from dataclasses import dataclass, field
19
  from enum import Enum
20
  import time
 
39
  from models.body_recognizer import BodyRecognizer
40
  from models.motion_detector import MotionDetector
41
  from scoring.hype_scorer import HypeScorer, SegmentScore
42
+ from scoring.domain_presets import get_domain_preset
43
  from scoring.viral_hooks import ViralHookDetector, HookSignal
44
 
45
  logger = get_logger("pipeline.orchestrator")
 
294
 
295
  # Visual analysis (if enabled)
296
  visual_features = []
297
+ custom_analysis_results = []
298
  if self._visual_analyzer is not None:
299
  self._update_progress(PipelineStage.ANALYZING_VISUAL, 0.0, "Analyzing visual content...")
300
  try:
 
303
  frame.frame_path, timestamp=frame.timestamp
304
  )
305
  visual_features.append(features)
306
+
307
+ # Apply custom prompt analysis if provided
308
+ if custom_prompt:
309
+ custom_result = self._visual_analyzer.analyze_with_custom_prompt(
310
+ frame.frame_path,
311
+ prompt=custom_prompt,
312
+ timestamp=frame.timestamp,
313
+ )
314
+ custom_analysis_results.append(custom_result)
315
+
316
  self._update_progress(
317
  PipelineStage.ANALYZING_VISUAL,
318
  (i + 1) / len(frames),
319
  f"Analyzing frame {i+1}/{len(frames)}"
320
  )
321
+
322
+ # Boost scores based on custom prompt matches
323
+ if custom_analysis_results:
324
+ self._apply_custom_prompt_boost(visual_features, custom_analysis_results)
325
+
326
  except Exception as e:
327
  logger.warning(f"Visual analysis failed, continuing without: {e}")
328
  self._update_progress(PipelineStage.ANALYZING_VISUAL, 1.0, "Visual analysis complete")
 
362
  logger.warning(f"Person detection failed: {e}")
363
  self._update_progress(PipelineStage.DETECTING_PERSON, 1.0, "Person detection complete")
364
 
365
+ # Motion analysis
366
  self._update_progress(PipelineStage.ANALYZING_MOTION, 0.0, "Analyzing motion...")
367
+ motion_scores = self._compute_motion_scores(frames)
368
+ # Fallback to visual estimation if motion detector failed or unavailable
369
+ if not motion_scores and visual_features:
370
+ motion_scores = self._estimate_motion_from_visual(visual_features)
371
  self._update_progress(PipelineStage.ANALYZING_MOTION, 1.0, "Motion analysis complete")
372
 
373
  # Scoring
 
485
  logger.warning(f"Visual analyzer not available: {e}")
486
  self._visual_analyzer = None
487
 
488
+ # Motion detector (optional, falls back to visual estimation)
489
+ try:
490
+ self._motion_detector = MotionDetector(
491
+ self.config.model,
492
+ use_raft=True, # Use high-quality RAFT if available
493
+ )
494
+ except Exception as e:
495
+ logger.warning(f"Motion detector not available, using visual estimation: {e}")
496
+ self._motion_detector = None
497
+
498
  # Person recognition (only if needed)
499
  if person_filter:
500
  try:
 
615
  for s in scores
616
  ]
617
 
618
def _compute_motion_scores(
    self,
    frames: "List[SampledFrame]",
) -> "List[float]":
    """
    Compute per-frame motion scores using the optical-flow MotionDetector.

    Uses the real motion detector (RAFT/Farneback) when it is available and
    there are at least two frames; otherwise returns an empty list so the
    caller falls back to visual estimation.

    Args:
        frames: Sampled frames with ``frame_path`` and ``timestamp``.

    Returns:
        One motion score (0-1) per input frame, or ``[]`` to signal that
        the caller should use visual estimation instead.
    """
    if not frames:
        return []

    # Use the real motion detector if available (needs at least one pair).
    if self._motion_detector is not None and len(frames) >= 2:
        try:
            import cv2

            motion_scores = []
            prev_frame = None
            for frame in frames:
                curr_frame = cv2.imread(str(frame.frame_path))

                if prev_frame is not None and curr_frame is not None:
                    motion_result = self._motion_detector.analyze_motion(
                        prev_frame, curr_frame, timestamp=frame.timestamp
                    )
                    motion_scores.append(motion_result.magnitude)
                else:
                    # First frame has no predecessor; an unreadable frame
                    # (cv2.imread returns None on failure) gets a neutral
                    # 0.0 so the result stays aligned 1:1 with `frames`
                    # (previously a failed read mid-sequence silently
                    # dropped a score and desynchronized the list).
                    motion_scores.append(0.0)

                # Only advance the reference frame when the current one
                # loaded; otherwise the next valid frame would be compared
                # against None and also score 0.0 spuriously.
                if curr_frame is not None:
                    prev_frame = curr_frame

            logger.info(f"Computed motion scores for {len(motion_scores)} frames using RAFT/Farneback")
            return motion_scores

        except Exception as e:
            logger.warning(f"Motion detection failed, falling back to visual estimation: {e}")

    # Fallback: return an empty list so the caller estimates motion from
    # visual features instead.
    logger.info("Using visual estimation for motion scores")
    return []
668
+
669
  def _estimate_motion_from_visual(
670
  self,
671
  visual_features: List[VisualFeatures],
672
  ) -> List[float]:
673
+ """Estimate motion scores from visual analysis (fallback)."""
674
  if not visual_features:
675
  return []
676
 
 
688
 
689
  return [motion_map.get(f.action_detected, 0.4) for f in visual_features]
690
 
691
def _apply_custom_prompt_boost(
    self,
    visual_features: "List[VisualFeatures]",
    custom_results: "List[Dict]",
) -> None:
    """
    Boost visual hype scores based on custom prompt responses.

    Analyzes custom prompt responses and boosts hype scores for frames
    where the response indicates a match with the user's criteria.
    Modifies ``visual_features`` in place; returns nothing.

    Args:
        visual_features: Visual features to modify (in-place).
        custom_results: Results from custom prompt analysis; each entry is
            a dict whose "response" key holds the model's text answer.
    """
    # Require a 1:1 pairing between features and custom results.
    if not custom_results or len(custom_results) != len(visual_features):
        return

    import re

    # Keywords that indicate positive matches.
    positive_keywords = {
        "yes", "true", "found", "detected", "present", "visible",
        "showing", "contains", "includes", "displays", "features",
        "action", "exciting", "highlight", "important", "key",
        "peak", "climax", "intense", "dramatic", "significant",
    }

    for i, (features, custom) in enumerate(zip(visual_features, custom_results)):
        response = custom.get("response", "").lower()

        # Tokenize into whole words so e.g. "key" does not match
        # "monkey" and "yes" does not match "eyes" — the previous
        # substring check over-counted such false positives.
        words = set(re.findall(r"[a-z]+", response))

        # +0.1 per distinct matching keyword, capped at a 50% boost.
        match_score = 0.1 * len(positive_keywords & words)
        boost = min(0.5, match_score)

        if boost > 0:
            # Multiplicative boost, clamped to the [0, 1] score range.
            original_score = features.hype_score
            features.hype_score = min(1.0, features.hype_score * (1 + boost))
            logger.debug(
                f"Frame {i}: custom prompt boost {boost:.2f} "
                f"({original_score:.2f} -> {features.hype_score:.2f})"
            )
+ )
737
+
738
  def _detect_viral_hooks(
739
  self,
740
  frames: List[SampledFrame],