# dev_caio/core/frame_sampler.py
# Initialized from local project (Chaitanya-aitf, commit ad4e58a, verified)
"""
ShortSmith v2 - Frame Sampler Module
Hierarchical frame sampling strategy:
1. Coarse pass: Sample 1 frame per N seconds to identify candidate regions
2. Dense pass: Sample at higher FPS only on promising segments
3. Dynamic FPS: Adjust sampling based on motion/content
"""
# Standard library
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Generator, List, Optional, Tuple

# Third-party
import numpy as np

# Local
from config import get_config, ProcessingConfig
from core.video_processor import VideoProcessor, VideoMetadata
from utils.helpers import VideoProcessingError, batch_list
from utils.logger import get_logger, LogTimer
logger = get_logger("core.frame_sampler")
@dataclass
class SampledFrame:
    """A single frame extracted from a video, plus sampling metadata.

    Instances are produced by FrameSampler; pixel data is only present in
    `frame_data` when a caller has explicitly loaded it into memory.
    """
    frame_path: Path  # Path to the frame image file on disk
    timestamp: float  # Timestamp in seconds from the start of the video
    frame_index: int  # Frame index in the video (computed as timestamp * fps, truncated)
    is_dense_sample: bool  # True if this frame came from the dense sampling pass
    scene_id: Optional[int] = None  # Associated scene ID (populated by get_keyframes)
    # Optional: frame data loaded into memory (excluded from repr to keep logs small)
    frame_data: Optional[np.ndarray] = field(default=None, repr=False)

    @property
    def filename(self) -> str:
        """Get the frame filename (basename of frame_path)."""
        return self.frame_path.name
@dataclass
class SamplingRegion:
    """A time window of the video selected for dense (high-FPS) sampling."""
    start_time: float  # Region start, in seconds
    end_time: float  # Region end, in seconds
    priority_score: float  # Higher = more likely to contain highlights

    @property
    def duration(self) -> float:
        """Length of the region in seconds."""
        return self.end_time - self.start_time
class FrameSampler:
    """
    Intelligent frame sampler using a hierarchical strategy.

    Optimizes compute by:
    1. Sparse (coarse) sampling to identify candidate regions
    2. Dense sampling only on promising areas
    3. Skipping static/low-motion content
    """

    def __init__(
        self,
        video_processor: VideoProcessor,
        config: Optional[ProcessingConfig] = None,
    ):
        """
        Initialize frame sampler.

        Args:
            video_processor: VideoProcessor instance for frame extraction
            config: Processing configuration (uses default if None)
        """
        self.video_processor = video_processor
        self.config = config or get_config().processing
        logger.info(
            f"FrameSampler initialized (coarse={self.config.coarse_sample_interval}s, "
            f"dense_fps={self.config.dense_sample_fps})"
        )

    def sample_coarse(
        self,
        video_path: str | Path,
        output_dir: str | Path,
        metadata: Optional[VideoMetadata] = None,
        start_time: float = 0,
        end_time: Optional[float] = None,
    ) -> List[SampledFrame]:
        """
        Perform the coarse sampling pass.

        Samples 1 frame every N seconds (config.coarse_sample_interval)
        across the requested time range.

        Args:
            video_path: Path to the video file
            output_dir: Directory to save extracted frames
            metadata: Video metadata (fetched if not provided)
            start_time: Start sampling from this timestamp (seconds)
            end_time: End sampling at this timestamp (seconds); defaults to
                the full video duration

        Returns:
            List of SampledFrame objects in chronological order.

        Raises:
            VideoProcessingError: If the resolved time range is empty.
        """
        video_path = Path(video_path)
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Get metadata if not provided
        if metadata is None:
            metadata = self.video_processor.get_metadata(video_path)

        # BUG FIX: the previous `end_time or metadata.duration` treated an
        # explicit end_time of 0.0 as "not provided" (falsy); compare against
        # None so a zero end_time reaches the range validation below instead.
        if end_time is None:
            end_time = metadata.duration

        # Clamp the requested range to the actual video bounds.
        end_time = min(end_time, metadata.duration)
        start_time = max(0.0, start_time)
        if start_time >= end_time:
            raise VideoProcessingError(
                f"Invalid time range: {start_time} to {end_time}"
            )

        with LogTimer(logger, f"Coarse sampling {video_path.name}"):
            # One timestamp every `interval` seconds over [start_time, end_time).
            interval = self.config.coarse_sample_interval
            timestamps = []
            current = start_time
            while current < end_time:
                timestamps.append(current)
                current += interval

            logger.info(
                f"Coarse sampling: {len(timestamps)} frames "
                f"({interval}s interval over {end_time - start_time:.1f}s)"
            )

            # Extract frames into a dedicated "coarse" subdirectory.
            frame_paths = self.video_processor.extract_frames(
                video_path,
                output_dir / "coarse",
                timestamps=timestamps,
            )

            # Pair each extracted file with the timestamp it was requested at.
            frames = []
            for path, ts in zip(frame_paths, timestamps):
                frames.append(SampledFrame(
                    frame_path=path,
                    timestamp=ts,
                    frame_index=int(ts * metadata.fps),
                    is_dense_sample=False,
                ))
            return frames

    def sample_dense(
        self,
        video_path: str | Path,
        output_dir: str | Path,
        regions: List[SamplingRegion],
        metadata: Optional[VideoMetadata] = None,
    ) -> List[SampledFrame]:
        """
        Perform dense (high-FPS) sampling on specific regions.

        Args:
            video_path: Path to the video file
            output_dir: Directory to save extracted frames
            regions: List of regions to sample densely
            metadata: Video metadata (fetched if not provided)

        Returns:
            List of SampledFrame objects from all dense regions.
        """
        video_path = Path(video_path)
        output_dir = Path(output_dir)
        if metadata is None:
            metadata = self.video_processor.get_metadata(video_path)

        all_frames: List[SampledFrame] = []
        with LogTimer(logger, f"Dense sampling {len(regions)} regions"):
            for i, region in enumerate(regions):
                # Each region gets its own subdirectory so frame filenames
                # from different regions cannot clash.
                region_dir = output_dir / f"dense_region_{i:03d}"
                region_dir.mkdir(parents=True, exist_ok=True)
                logger.debug(
                    f"Dense sampling region {i}: "
                    f"{region.start_time:.1f}s - {region.end_time:.1f}s"
                )
                # Extract at the configured dense FPS.
                frame_paths = self.video_processor.extract_frames(
                    video_path,
                    region_dir,
                    fps=self.config.dense_sample_fps,
                    start_time=region.start_time,
                    end_time=region.end_time,
                )
                # Reconstruct each frame's timestamp from its position within
                # the region (frames are assumed evenly spaced at dense FPS).
                for j, path in enumerate(frame_paths):
                    timestamp = region.start_time + (j / self.config.dense_sample_fps)
                    all_frames.append(SampledFrame(
                        frame_path=path,
                        timestamp=timestamp,
                        frame_index=int(timestamp * metadata.fps),
                        is_dense_sample=True,
                    ))
            logger.info(f"Dense sampling extracted {len(all_frames)} frames")
        return all_frames

    def sample_hierarchical(
        self,
        video_path: str | Path,
        output_dir: str | Path,
        candidate_scorer: Optional[Callable[[SampledFrame], float]] = None,
        top_k_regions: int = 5,
        metadata: Optional[VideoMetadata] = None,
    ) -> Tuple[List[SampledFrame], List[SampledFrame]]:
        """
        Perform full hierarchical sampling.

        1. Coarse pass to identify candidates
        2. Score candidate regions
        3. Dense pass on top-k regions

        Args:
            video_path: Path to the video file
            output_dir: Directory to save extracted frames
            candidate_scorer: Function mapping a SampledFrame to a score in
                [0, 1]; when omitted, regions are spread uniformly instead
            top_k_regions: Number of top regions to densely sample
            metadata: Video metadata (fetched if not provided)

        Returns:
            Tuple of (coarse_frames, dense_frames)
        """
        video_path = Path(video_path)
        output_dir = Path(output_dir)
        if metadata is None:
            metadata = self.video_processor.get_metadata(video_path)

        with LogTimer(logger, "Hierarchical sampling"):
            # Step 1: Coarse sampling across the whole video.
            coarse_frames = self.sample_coarse(
                video_path, output_dir, metadata=metadata
            )
            # Step 2: Identify candidate regions.
            if candidate_scorer is not None:
                # Use the provided scorer to pick promising regions.
                regions = self._identify_candidate_regions(
                    coarse_frames, candidate_scorer, top_k_regions
                )
            else:
                # No scorer available: fall back to uniform coverage.
                regions = self._create_uniform_regions(
                    metadata.duration, top_k_regions
                )
            # Step 3: Dense sampling on the selected regions.
            dense_frames = self.sample_dense(
                video_path, output_dir, regions, metadata
            )
            logger.info(
                f"Hierarchical sampling complete: "
                f"{len(coarse_frames)} coarse, {len(dense_frames)} dense frames"
            )
            return coarse_frames, dense_frames

    def _identify_candidate_regions(
        self,
        frames: List[SampledFrame],
        scorer: Callable[[SampledFrame], float],
        top_k: int,
    ) -> List[SamplingRegion]:
        """
        Identify top candidate regions based on scoring.

        Args:
            frames: List of coarse sampled frames
            scorer: Function that takes a frame and returns a score (0-1)
            top_k: Maximum number of regions to keep (before merging)

        Returns:
            List of merged, non-overlapping SamplingRegion objects.
        """
        # Score each frame; a failing scorer demotes that frame to 0.0
        # rather than aborting the whole pass.
        scores = []
        for frame in frames:
            try:
                score = scorer(frame)
                scores.append((frame, score))
            except Exception as e:
                logger.warning(f"Failed to score frame {frame.timestamp}s: {e}")
                scores.append((frame, 0.0))

        # Highest-scoring frames first.
        scores.sort(key=lambda x: x[1], reverse=True)

        # Build a window of +/- one coarse interval around each top frame.
        # NOTE(review): `end` is not clamped to the video duration here; the
        # frame extractor is assumed to tolerate an end_time past EOF.
        interval = self.config.coarse_sample_interval
        regions = []
        for frame, score in scores[:top_k]:
            start = max(0, frame.timestamp - interval)
            end = frame.timestamp + interval
            regions.append(SamplingRegion(
                start_time=start,
                end_time=end,
                priority_score=score,
            ))

        # Adjacent winners produce overlapping windows; merge them.
        return self._merge_overlapping_regions(regions)

    def _create_uniform_regions(
        self,
        duration: float,
        num_regions: int,
    ) -> List[SamplingRegion]:
        """
        Create uniformly distributed sampling regions.

        Args:
            duration: Total video duration in seconds
            num_regions: Number of regions to create

        Returns:
            List of uniformly spaced SamplingRegion objects (possibly fewer
            than requested when the video is too short; empty for degenerate
            input).
        """
        # Guard degenerate input: previously a duration <= 0 produced a
        # single zero-length region.
        if num_regions <= 0 or duration <= 0:
            return []

        region_duration = self.config.coarse_sample_interval * 2
        if region_duration * num_regions > duration:
            # Video too short for the requested count: create fewer regions.
            num_regions = max(1, int(duration / region_duration))

        # Even gap before / between / after regions. BUG FIX: recompute after
        # shrinking num_regions (the old code forced gap to 0, piling the
        # regions up at the start); clamp for the single-region remainder case.
        gap = max(0.0, (duration - region_duration * num_regions) / (num_regions + 1))

        regions = []
        current = gap
        for _ in range(num_regions):
            regions.append(SamplingRegion(
                start_time=current,
                end_time=min(current + region_duration, duration),
                priority_score=1.0 / num_regions,
            ))
            current += region_duration + gap
        return regions

    def _merge_overlapping_regions(
        self,
        regions: List[SamplingRegion],
    ) -> List[SamplingRegion]:
        """
        Merge overlapping sampling regions.

        Args:
            regions: List of potentially overlapping regions

        Returns:
            List of merged regions, sorted by start time. A merged region
            keeps the maximum priority_score of its constituents.
        """
        if not regions:
            return []

        # Sort by start time so overlaps are always with the previous region.
        sorted_regions = sorted(regions, key=lambda r: r.start_time)
        merged = [sorted_regions[0]]
        for region in sorted_regions[1:]:
            last = merged[-1]
            if region.start_time <= last.end_time:
                # Overlap (or touch): extend the previous merged region.
                merged[-1] = SamplingRegion(
                    start_time=last.start_time,
                    end_time=max(last.end_time, region.end_time),
                    priority_score=max(last.priority_score, region.priority_score),
                )
            else:
                merged.append(region)
        return merged

    def sample_at_timestamps(
        self,
        video_path: str | Path,
        output_dir: str | Path,
        timestamps: List[float],
        metadata: Optional[VideoMetadata] = None,
    ) -> List[SampledFrame]:
        """
        Sample frames at specific timestamps.

        Args:
            video_path: Path to the video file
            output_dir: Directory to save extracted frames
            timestamps: List of timestamps (seconds) to sample
            metadata: Video metadata (fetched if not provided)

        Returns:
            List of SampledFrame objects, one per extracted frame, in the
            same order as `timestamps`.
        """
        video_path = Path(video_path)
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        if metadata is None:
            metadata = self.video_processor.get_metadata(video_path)

        with LogTimer(logger, f"Sampling {len(timestamps)} specific timestamps"):
            frame_paths = self.video_processor.extract_frames(
                video_path,
                output_dir / "specific",
                timestamps=timestamps,
            )
            frames = []
            for path, ts in zip(frame_paths, timestamps):
                frames.append(SampledFrame(
                    frame_path=path,
                    timestamp=ts,
                    frame_index=int(ts * metadata.fps),
                    is_dense_sample=False,
                ))
            return frames

    def get_keyframes(
        self,
        video_path: str | Path,
        output_dir: str | Path,
        scenes: Optional[List] = None,
    ) -> List[SampledFrame]:
        """
        Extract keyframes (one per scene, at the scene midpoint).

        Args:
            video_path: Path to the video file
            output_dir: Directory to save extracted frames
            scenes: List of Scene objects (detected if not provided)

        Returns:
            List of keyframe SampledFrame objects with scene_id populated.
        """
        # Local import kept from the original code — presumably avoids a
        # circular import with core.scene_detector; verify before hoisting.
        from core.scene_detector import SceneDetector

        video_path = Path(video_path)
        if scenes is None:
            detector = SceneDetector()
            scenes = detector.detect_scenes(video_path)

        # Use the midpoint of each scene as its representative keyframe.
        timestamps = [scene.midpoint for scene in scenes]
        with LogTimer(logger, f"Extracting {len(timestamps)} keyframes"):
            frames = self.sample_at_timestamps(
                video_path, output_dir, timestamps
            )
        # Tag each frame with the scene it represents (frames are returned
        # in the same order as `scenes`).
        for scene_id, frame in enumerate(frames):
            frame.scene_id = scene_id
        return frames
# Export public interface
# NOTE: internal helpers (prefixed with "_") are intentionally excluded.
__all__ = ["FrameSampler", "SampledFrame", "SamplingRegion"]