ShortSmith_v3 / models /motion_detector.py
chaitanya.musale
Fix models folder issues: bugs and code cleanup
15c68da
"""
ShortSmith v2 - Motion Detector Module
Motion analysis using optical flow for:
- Detecting action-heavy segments
- Identifying camera movement vs subject movement
- Dynamic FPS scaling based on motion level
Uses RAFT (Recurrent All-Pairs Field Transforms) for high-quality
optical flow, with fallback to Farneback for speed.
"""
from typing import List, Optional, Tuple
from dataclasses import dataclass
import numpy as np
from utils.logger import get_logger, LogTimer
from config import get_config, ModelConfig
# Module-level logger shared by MotionDetector for model loading, fallback
# warnings, and per-segment progress/summary messages.
logger = get_logger("models.motion_detector")
@dataclass
class MotionScore:
    """
    Result of optical-flow motion analysis for one pair of consecutive frames.

    Instances are built by ``MotionDetector.analyze_motion``; the two
    properties below turn the raw numbers into simple yes/no signals.
    """

    timestamp: float          # timestamp of the later frame in the pair
    magnitude: float          # mean flow magnitude, normalized to 0-1
    direction: float          # dominant flow direction, in radians
    uniformity: float         # directional consistency; 1 means every vector agrees
    is_camera_motion: bool    # True when the flow looks like camera movement, not subject movement

    @property
    def is_high_motion(self) -> bool:
        """Return True when the normalized magnitude is above 0.3."""
        return 0.3 < self.magnitude

    @property
    def is_action(self) -> bool:
        """Return True for noticeable motion (> 0.2) that is not directionally uniform (< 0.7)."""
        return self.uniformity < 0.7 if self.magnitude > 0.2 else False
class MotionDetector:
    """
    Motion detection using optical flow.

    Supports:
    - RAFT optical flow (high quality, GPU) via torchvision's ``raft_small``
    - Farneback optical flow (faster, CPU) via OpenCV; also used as the
      fallback whenever RAFT cannot be loaded or fails on a frame pair
    - Motion magnitude scoring
    - Camera vs subject motion detection
    """

    # torchvision's RAFT rejects inputs whose height/width are not multiples of 8.
    _RAFT_SIZE_MULTIPLE = 8

    def __init__(
        self,
        config: Optional["ModelConfig"] = None,
        use_raft: bool = True,
    ):
        """
        Initialize motion detector.

        Args:
            config: Model configuration; defaults to ``get_config().model``.
                Its ``device`` attribute decides whether CUDA is used.
            use_raft: Whether to use RAFT (True) or Farneback (False). If the
                RAFT model fails to load, ``self.use_raft`` is reset to False.
        """
        self.config = config or get_config().model
        self.use_raft = use_raft
        self.raft_model = None

        if use_raft:
            self._load_raft()

        # Report the effective mode (RAFT may have been disabled by a load failure).
        logger.info(f"MotionDetector initialized (RAFT={self.use_raft})")

    def _load_raft(self) -> None:
        """Load the pretrained RAFT-small model; on any failure fall back to Farneback."""
        try:
            import torch
            from torchvision.models.optical_flow import raft_small, Raft_Small_Weights

            logger.info("Loading RAFT optical flow model...")
            weights = Raft_Small_Weights.DEFAULT
            self.raft_model = raft_small(weights=weights)

            if self.config.device == "cuda" and torch.cuda.is_available():
                self.raft_model = self.raft_model.cuda()

            self.raft_model.eval()
            logger.info("RAFT model loaded successfully")

        except Exception as e:
            # Deliberate best-effort: missing torch/torchvision, weight download
            # errors, or device problems all degrade to the CPU algorithm.
            logger.warning(f"Failed to load RAFT model, using Farneback: {e}")
            self.use_raft = False
            self.raft_model = None

    def compute_flow(
        self,
        frame1: np.ndarray,
        frame2: np.ndarray,
    ) -> np.ndarray:
        """
        Compute optical flow between two frames.

        Args:
            frame1: First frame. 3-channel frames are treated as BGR (OpenCV
                order); a 2-D grayscale frame also works (the RAFT path fails
                on it and falls back to Farneback).
            frame2: Second frame, same layout as ``frame1``.

        Returns:
            Optical flow array (HxWx2) in pixels, flow[y, x] = (dx, dy)
        """
        if self.use_raft and self.raft_model is not None:
            return self._compute_raft_flow(frame1, frame2)
        return self._compute_farneback_flow(frame1, frame2)

    @staticmethod
    def _to_raft_tensor(frame: np.ndarray) -> "torch.Tensor":
        """Convert an HxWx3 RGB frame to a 1x3xHxW float tensor scaled to [-1, 1]."""
        import torch

        # ascontiguousarray: torch.from_numpy cannot wrap the negative-stride
        # view produced by the BGR->RGB channel flip.
        tensor = torch.from_numpy(np.ascontiguousarray(frame)).permute(2, 0, 1).unsqueeze(0)
        if tensor.dtype == torch.uint8:
            tensor = tensor.float() / 255.0
        else:
            # NOTE(review): non-uint8 frames are assumed to already be in [0, 1].
            tensor = tensor.float()
        # RAFT's pretrained weights expect inputs normalized with mean=0.5, std=0.5.
        return tensor * 2.0 - 1.0

    def _compute_raft_flow(
        self,
        frame1: np.ndarray,
        frame2: np.ndarray,
    ) -> np.ndarray:
        """
        Compute flow using RAFT.

        Frames are converted to the preprocessing the pretrained weights expect
        (RGB, float in [-1, 1]) and replicate-padded so height and width are
        multiples of 8; the predicted flow is cropped back to the input size.
        Any failure falls back to Farneback for this frame pair.
        """
        import torch
        import torch.nn.functional as F

        try:
            # Inputs are assumed BGR; flip 3-channel frames to RGB for RAFT.
            if frame1.shape[2] == 3:
                frame1_rgb = frame1[:, :, ::-1]
                frame2_rgb = frame2[:, :, ::-1]
            else:
                frame1_rgb = frame1
                frame2_rgb = frame2

            img1 = self._to_raft_tensor(frame1_rgb)
            img2 = self._to_raft_tensor(frame2_rgb)

            height, width = img1.shape[-2:]
            multiple = self._RAFT_SIZE_MULTIPLE
            pad_h = (-height) % multiple
            pad_w = (-width) % multiple
            if pad_h or pad_w:
                # F.pad pads the last two dims as (left, right, top, bottom).
                img1 = F.pad(img1, (0, pad_w, 0, pad_h), mode="replicate")
                img2 = F.pad(img2, (0, pad_w, 0, pad_h), mode="replicate")

            if self.config.device == "cuda" and torch.cuda.is_available():
                img1 = img1.cuda()
                img2 = img2.cuda()

            with torch.no_grad():
                flow_predictions = self.raft_model(img1, img2)
            flow = flow_predictions[-1]  # final (most refined) prediction

            # (1, 2, H', W') -> (H, W, 2), discarding the padded border.
            return flow[0, :, :height, :width].permute(1, 2, 0).cpu().numpy()

        except Exception as e:
            logger.warning(f"RAFT flow failed, using Farneback: {e}")
            return self._compute_farneback_flow(frame1, frame2)

    def _compute_farneback_flow(
        self,
        frame1: np.ndarray,
        frame2: np.ndarray,
    ) -> np.ndarray:
        """Compute dense flow with OpenCV's Farneback algorithm on grayscale frames."""
        import cv2

        # Farneback works on single-channel images; colour frames are treated as BGR.
        if len(frame1.shape) == 3:
            gray1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY)
            gray2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY)
        else:
            gray1 = frame1
            gray2 = frame2

        flow = cv2.calcOpticalFlowFarneback(
            gray1, gray2,
            None,
            pyr_scale=0.5,
            levels=3,
            winsize=15,
            iterations=3,
            poly_n=5,
            poly_sigma=1.2,
            flags=0,
        )
        return flow

    @staticmethod
    def _direction_stats(flow: np.ndarray) -> Tuple[float, float]:
        """
        Magnitude-weighted circular statistics of a flow field's direction.

        Angles wrap at +/-pi, so a plain mean/std of ``arctan2`` values is wrong
        (e.g. leftward motion straddles +pi and -pi and would look maximally
        scattered). Instead the flow vectors are summed: the angle of the sum
        is the dominant direction, and R = |sum| / sum(|v|) in [0, 1] measures
        alignment. R is turned into a circular standard deviation
        ``sqrt(-2 ln R)`` and mapped through ``1 / (1 + std)``, so 1 still
        means every vector points the same way.

        Args:
            flow: Array whose last axis is (dx, dy).

        Returns:
            (dominant_direction in radians, uniformity in [0, 1])
        """
        dx = flow[..., 0].astype(np.float64)
        dy = flow[..., 1].astype(np.float64)
        sum_dx = float(dx.sum())
        sum_dy = float(dy.sum())
        total_magnitude = float(np.hypot(dx, dy).sum())

        dominant_direction = float(np.arctan2(sum_dy, sum_dx))

        if total_magnitude <= 1e-12:
            # No motion at all: nothing is dispersed.
            return dominant_direction, 1.0

        resultant = min(1.0, float(np.hypot(sum_dx, sum_dy)) / total_magnitude)
        if resultant <= 0.0:
            # Vectors cancel out exactly: maximally dispersed.
            return dominant_direction, 0.0

        circular_std = float(np.sqrt(-2.0 * np.log(resultant)))
        return dominant_direction, 1.0 / (1.0 + circular_std)

    def analyze_motion(
        self,
        frame1: np.ndarray,
        frame2: np.ndarray,
        timestamp: float = 0.0,
    ) -> "MotionScore":
        """
        Analyze motion between two frames.

        Args:
            frame1: First frame
            frame2: Second frame
            timestamp: Timestamp of second frame

        Returns:
            MotionScore with analysis results
        """
        flow = self.compute_flow(frame1, frame2)
        magnitude = np.hypot(flow[:, :, 0], flow[:, :, 1])

        # Average magnitude normalized by the image diagonal (resolution-independent).
        h, w = frame1.shape[:2]
        diagonal = np.hypot(h, w)
        avg_magnitude = float(np.mean(magnitude) / diagonal)

        dominant_direction, uniformity = self._direction_stats(flow)

        # Uniform direction across the frame with non-trivial motion suggests
        # the camera moved rather than a subject.
        is_camera = uniformity > 0.7 and avg_magnitude > 0.05

        return MotionScore(
            timestamp=timestamp,
            magnitude=min(1.0, avg_magnitude * 10),  # scale up, capped at 1
            direction=dominant_direction,
            uniformity=uniformity,
            is_camera_motion=is_camera,
        )

    def analyze_video_segment(
        self,
        frames: List[np.ndarray],
        timestamps: List[float],
    ) -> List["MotionScore"]:
        """
        Analyze motion across a video segment.

        Args:
            frames: List of frames
            timestamps: Timestamps for each frame (index-aligned with ``frames``)

        Returns:
            List of MotionScore objects, one per successfully analyzed
            consecutive frame pair. Pairs that raise are logged and skipped.
        """
        if len(frames) < 2:
            return []

        scores = []
        with LogTimer(logger, f"Analyzing motion in {len(frames)} frames"):
            for i, (prev_frame, curr_frame) in enumerate(zip(frames, frames[1:]), start=1):
                try:
                    scores.append(
                        self.analyze_motion(prev_frame, curr_frame, timestamps[i])
                    )
                except Exception as e:
                    logger.warning(f"Motion analysis failed for frame {i}: {e}")

        return scores

    @staticmethod
    def _magnitude_to_heatmap(magnitude: np.ndarray) -> np.ndarray:
        """Scale a magnitude map to uint8 0-255 using a robust (99th percentile) maximum."""
        max_mag = float(np.percentile(magnitude, 99))
        if max_mag <= 0:
            # Motion confined to under 1% of pixels leaves the percentile at 0;
            # use the true maximum so that motion is still visible.
            max_mag = float(np.max(magnitude))
        if max_mag <= 0:
            return np.zeros(magnitude.shape, dtype=np.uint8)
        return np.clip(magnitude / max_mag * 255, 0, 255).astype(np.uint8)

    def get_motion_heatmap(
        self,
        frame1: np.ndarray,
        frame2: np.ndarray,
    ) -> np.ndarray:
        """
        Get motion magnitude heatmap.

        Args:
            frame1: First frame
            frame2: Second frame

        Returns:
            Heatmap of motion magnitude (HxW, uint8 values 0-255)
        """
        flow = self.compute_flow(frame1, frame2)
        magnitude = np.hypot(flow[:, :, 0], flow[:, :, 1])
        return self._magnitude_to_heatmap(magnitude)

    def compute_aggregate_motion(
        self,
        scores: List["MotionScore"],
    ) -> float:
        """
        Compute aggregate motion score for a segment.

        Camera-motion scores are down-weighted to 30% so subject motion dominates.

        Args:
            scores: List of MotionScore objects

        Returns:
            Aggregate motion score (0-1); 0.0 for an empty list
        """
        if not scores:
            return 0.0

        weighted_sum = sum(
            s.magnitude * (0.3 if s.is_camera_motion else 1.0)
            for s in scores
        )
        return weighted_sum / len(scores)

    def identify_high_motion_segments(
        self,
        scores: List["MotionScore"],
        threshold: float = 0.3,
        min_duration: int = 3,
    ) -> List[Tuple[float, float, float]]:
        """
        Identify runs of consecutive scores with high motion.

        Args:
            scores: List of MotionScore objects, in time order
            threshold: Minimum motion magnitude
            min_duration: Minimum number of consecutive scores in a run

        Returns:
            List of (start_time, end_time, avg_motion) tuples. A run closed by
            a low-motion score ends at that score's timestamp; a run still open
            at the end of ``scores`` ends at the last score's timestamp.
        """
        if not scores:
            return []

        segments = []
        in_segment = False
        segment_start = 0.0
        segment_scores = []

        for score in scores:
            if score.magnitude >= threshold:
                if not in_segment:
                    in_segment = True
                    segment_start = score.timestamp
                    segment_scores = [score.magnitude]
                else:
                    segment_scores.append(score.magnitude)
            elif in_segment:
                if len(segment_scores) >= min_duration:
                    segments.append((
                        segment_start,
                        score.timestamp,
                        sum(segment_scores) / len(segment_scores),
                    ))
                in_segment = False

        # Flush a run that lasts until the final score.
        if in_segment and len(segment_scores) >= min_duration:
            segments.append((
                segment_start,
                scores[-1].timestamp,
                sum(segment_scores) / len(segment_scores),
            ))

        logger.info(f"Found {len(segments)} high-motion segments")
        return segments
# Public API of this module: the names exposed by a star-import.
__all__ = [
    "MotionDetector",
    "MotionScore",
]