# geolocation/src/prompt/preprocess/keyframe_extract.py
#
# Keyframe extraction pipeline: shot detection (Google Video Intelligence),
# per-shot sampling, color-histogram clustering, and redundancy filtering.
import logging
import os
import cv2
import numpy as np
from google.cloud import videointelligence_v1 as vi
from scipy.spatial.distance import cdist
from sklearn.metrics import silhouette_score
# Set up logger
logger = logging.getLogger("uvicorn.error")
def detect_shot_intervals_local(video_path: str) -> list[tuple[float, float]]:
    """Detect shot-change intervals in a local video via Google Video Intelligence.

    Args:
        video_path: Path to a local video file; its raw bytes are uploaded
            inline as ``input_content``.

    Returns:
        A list of ``(start_seconds, end_seconds)`` tuples, one per detected
        shot, or an empty list when the API returns no annotation results.
    """
    logger.info(f"Detecting shot intervals for video: {video_path}")
    client = vi.VideoIntelligenceServiceClient()
    with open(video_path, "rb") as f:
        input_content = f.read()
    op = client.annotate_video(
        request={
            "input_content": input_content,
            "features": [vi.Feature.SHOT_CHANGE_DETECTION],
        }
    )
    # Long-running operation; block up to 5 minutes for annotation to finish.
    response = op.result(timeout=300)
    if not response or not response.annotation_results:
        logger.error("No annotation_results found in video intelligence response.")
        return []
    result = response.annotation_results[0]
    intervals = []
    for shot in result.shot_annotations:
        # proto-plus maps Duration offsets to datetime.timedelta.
        # total_seconds() equals seconds + microseconds/1e6 for sub-day
        # offsets, and stays correct for offsets of a day or more (where
        # .seconds alone would wrap).
        start = shot.start_time_offset.total_seconds()
        end = shot.end_time_offset.total_seconds()
        intervals.append((start, end))
    logger.info(f"Detected {len(intervals)} shot intervals.")
    return intervals
def color_histogram(img: np.ndarray) -> np.ndarray:
    """Return a flattened, normalized 8x8x8 HSV color histogram of a BGR image."""
    as_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    # 8 bins per channel over the full H (0-180) and S/V (0-256) ranges.
    histogram = cv2.calcHist(
        [as_hsv], [0, 1, 2], None, [8, 8, 8], [0, 180, 0, 256, 0, 256]
    )
    cv2.normalize(histogram, histogram)
    return histogram.flatten()
def sample_frames_per_shot(
    video_path: str, start: float, end: float, step: float = 1.0
) -> list[np.ndarray]:
    """Sample frames from the interval ``[start, end)`` at a fixed time step.

    Args:
        video_path: Path to the video file.
        start: Interval start in seconds (inclusive).
        end: Interval end in seconds (exclusive).
        step: Sampling period in seconds.

    Returns:
        BGR frames read at each sampled timestamp. Sampling stops early
        (with a warning) at the first failed read.
    """
    cap = cv2.VideoCapture(video_path)
    frames: list[np.ndarray] = []
    try:
        t = start
        while t < end:
            # Seek by timestamp rather than frame number so the step is
            # independent of the video's FPS.
            cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
            ret, frame = cap.read()
            if not ret:
                logger.warning(f"Failed to read frame at {t:.2f}s")
                break
            frames.append(frame)
            t += step
    finally:
        # Always release the capture, even if a decode error raises mid-loop
        # (previously the handle leaked on exception).
        cap.release()
    return frames
def kmeans_init(features: np.ndarray):
    """Random k-means initialization with k = max(floor(sqrt(n)), 1).

    Centers are drawn without replacement from the feature rows themselves;
    every row is then assigned to its nearest center (Euclidean distance).

    Returns:
        (clusters, centers): per-row cluster ids and the chosen center rows.
    """
    n = features.shape[0]
    k = int(np.sqrt(n)) or 1
    chosen = np.random.choice(n, k, replace=False)
    centers = features[chosen]
    assignments = np.argmin(cdist(features, centers), axis=1)
    return assignments, centers
def kmeans_silhouette(features: np.ndarray):
    """Agglomeratively merge clusters, keeping the silhouette-best partition.

    Starts from a random assignment with k ~ floor(sqrt(n)) clusters (via
    kmeans_init), then repeatedly merges the two closest centers down to
    k = 2, scoring each intermediate clustering with the silhouette
    coefficient and remembering the best one.

    Returns:
        ``(best_k, best_centers, center_indices)`` — the winning cluster
        count, its representative feature rows, and those rows' positions
        in ``features``.

    NOTE(review): the initial (pre-merge) clustering is never silhouette-
    scored, and when an intermediate cluster ends up empty its center is
    skipped while the labels are not re-compacted, so ``len(centers)`` can
    drift below the number of distinct labels — confirm this is intended.
    """
    k = max(int(np.sqrt(len(features))), 2)
    best_k, best_score = k, -1
    clusters, centers = kmeans_init(features)
    best_centers = centers.copy()
    while k > 2:
        # Find the two closest centers (i, j), then merge cluster j into i.
        d = cdist(centers, centers)
        np.fill_diagonal(d, np.inf)
        i, j = np.unravel_index(np.argmin(d), d.shape)
        clusters = np.where(clusters == j, i, clusters)
        # Shift labels above j down so ids stay contiguous in [0, k-2].
        clusters = np.where(clusters > j, clusters - 1, clusters)
        new_centers = []
        for cid in range(k - 1):
            cluster_feats = features[clusters == cid]
            if cluster_feats.size == 0:
                continue
            # Representative = actual member closest to the cluster mean
            # (medoid-like), so it can later be matched back to a row index.
            mean_vec = np.mean(cluster_feats, axis=0)
            idx_close = np.argmin(np.linalg.norm(cluster_feats - mean_vec, axis=1))
            new_centers.append(cluster_feats[idx_close])
        centers = new_centers
        k -= 1
        # Silhouette is only defined for >= 2 distinct labels.
        if len(np.unique(clusters)) > 1:
            score = silhouette_score(features, clusters)
            if score > best_score:
                best_score, best_k = score, k
                best_centers = centers.copy()
    # Map each winning representative back to its first matching feature row.
    center_indices = []
    for c in best_centers:
        matches = np.where((features == c).all(axis=1))[0]
        if matches.size > 0:
            center_indices.append(int(matches[0]))
    return best_k, best_centers, center_indices
def redundancy_filter(
    video_path: str, indices: list[int], threshold: float
) -> list[int]:
    """Drop candidate frames whose histogram resembles an earlier candidate's.

    Args:
        video_path: Path to the video file.
        indices: Global frame numbers of candidate keyframes.
        threshold: Cosine-similarity cutoff; a frame is redundant when its
            color histogram's similarity to ANY earlier candidate exceeds it.

    Returns:
        The subset of ``indices`` considered non-redundant. Frames that fail
        to decode are dropped.
    """
    histos: list[np.ndarray] = []
    valid_indices: list[int] = []
    cap = cv2.VideoCapture(video_path)
    try:
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                histos.append(color_histogram(frame))
                # Track the index alongside its histogram: previously a failed
                # read desynchronized histos from indices, so the filter could
                # return the WRONG frame numbers.
                valid_indices.append(idx)
    finally:
        # Release the capture even if decoding raises.
        cap.release()
    keep = []
    for i, h in enumerate(histos):
        h_norm = np.linalg.norm(h)
        redundant = any(
            np.dot(h, other) / (h_norm * np.linalg.norm(other)) > threshold
            for other in histos[:i]
        )
        if not redundant:
            keep.append(valid_indices[i])
    return keep
def extract_and_save_keyframes(
    video_path: str,
    output_dir: str,
    start_index: int = 0,
    step: float = 1.0,
    threshold: float = 0.7,
    k_min: int = 2,
    k_max: int = 8,
) -> int:
    """Extract keyframes per detected shot and save them as sequential JPEGs.

    Pipeline: detect shots (Video Intelligence) -> sample frames per shot ->
    cluster color histograms (silhouette-guided k-means) -> drop redundant
    frames -> write ``image_{NNN}.jpg`` files into ``output_dir``.

    Args:
        video_path: Path to the input video.
        output_dir: Directory for saved keyframes (created if missing).
        start_index: Index of the first output file.
        step: Per-shot sampling period in seconds.
        threshold: Cosine-similarity cutoff for the redundancy filter.
        k_min: Minimum sampled-frame count before clustering is attempted;
            smaller shots keep every sampled frame.
        k_max: Currently unused; kept for interface compatibility.

    Returns:
        The next unused output index (start_index + frames saved).
    """
    logger.info(f"Starting keyframe extraction for {video_path}")
    os.makedirs(output_dir, exist_ok=True)
    # FPS is needed to convert shot timestamps (seconds) to frame numbers.
    cap_meta = cv2.VideoCapture(video_path)
    video_fps = cap_meta.get(cv2.CAP_PROP_FPS) or 1.0
    cap_meta.release()
    intervals = detect_shot_intervals_local(video_path)
    cap = cv2.VideoCapture(video_path)
    output_idx = start_index
    try:
        for shot_idx, (start, end) in enumerate(intervals):
            # Sample frames for this shot and pick intra-shot keyframes.
            frames = sample_frames_per_shot(video_path, start, end, step)
            if len(frames) < k_min:
                # Too few frames to cluster: keep them all. (The old check
                # compared feats.size — rows * histogram length — to k_min,
                # which was effectively never true for nonempty shots.)
                idxs = list(range(len(frames)))
            else:
                feats = np.vstack([color_histogram(f) for f in frames])
                _, _, idxs = kmeans_silhouette(feats)
            # Map sample positions to global frame numbers: sample i sits at
            # time start + i*step. (The old start*fps + i mapping treated the
            # sample index as a frame offset, which was only correct when
            # step equaled one frame period.)
            global_idxs = [int((start + i * step) * video_fps) for i in idxs]
            filtered = redundancy_filter(video_path, global_idxs, threshold)
            # Save each surviving keyframe sequentially into output_dir.
            for frame_no in filtered:
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
                ret, frame = cap.read()
                if not ret:
                    continue
                out_path = os.path.join(output_dir, f"image_{output_idx:03d}.jpg")
                cv2.imwrite(out_path, frame)
                output_idx += 1
            logger.info(
                f"Shot {shot_idx + 1}: saved {len(filtered)} keyframes. Total so far: {output_idx}"
            )
    finally:
        # Release the capture even if an upstream step raises.
        cap.release()
    logger.info(f"Extraction complete. Total frames saved: {output_idx}")
    return output_idx