import logging
import os
import cv2
import numpy as np
from google.cloud import videointelligence_v1 as vi
from scipy.spatial.distance import cdist
from sklearn.metrics import silhouette_score
# Module-level logger; attaches to uvicorn's error logger so messages from
# this pipeline show up in the FastAPI/uvicorn server log.
logger = logging.getLogger("uvicorn.error")
def detect_shot_intervals_local(video_path: str) -> list[tuple[float, float]]:
    """Run Google Video Intelligence shot-change detection on a local video.

    The file is read into memory and sent inline (``input_content``), so this
    is suitable for local files rather than GCS URIs.

    Args:
        video_path: Path to the video file on disk.

    Returns:
        A list of ``(start_seconds, end_seconds)`` tuples, one per detected
        shot; empty if the API response carries no annotation results.
    """
    logger.info(f"Detecting shot intervals for video: {video_path}")
    client = vi.VideoIntelligenceServiceClient()
    with open(video_path, "rb") as f:
        content = f.read()
    operation = client.annotate_video(
        request={
            "input_content": content,
            "features": [vi.Feature.SHOT_CHANGE_DETECTION],
        }
    )
    # Block until the long-running operation completes (up to 5 minutes).
    response = operation.result(timeout=300)
    if not response or not response.annotation_results:
        logger.error("No annotation_results found in video intelligence response.")
        return []
    shots = response.annotation_results[0].shot_annotations
    intervals: list[tuple[float, float]] = []
    for shot in shots:
        begin = shot.start_time_offset
        finish = shot.end_time_offset
        intervals.append(
            (
                begin.seconds + begin.microseconds / 1e6,
                finish.seconds + finish.microseconds / 1e6,
            )
        )
    logger.info(f"Detected {len(intervals)} shot intervals.")
    return intervals
def color_histogram(img: np.ndarray) -> np.ndarray:
    """Compute a flattened, normalized 8x8x8 HSV color histogram of a BGR image."""
    hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    bins = [8, 8, 8]
    ranges = [0, 180, 0, 256, 0, 256]  # H in [0,180), S and V in [0,256)
    hist = cv2.calcHist([hsv_img], [0, 1, 2], None, bins, ranges)
    cv2.normalize(hist, hist)  # normalizes in place
    return hist.flatten()
def sample_frames_per_shot(
    video_path: str, start: float, end: float, step: float = 1.0
) -> list[np.ndarray]:
    """Grab one frame every ``step`` seconds in ``[start, end)`` from the video.

    Seeks by millisecond timestamp; stops early (with a warning) at the first
    frame that cannot be decoded.

    Returns:
        The sampled frames as BGR arrays, in temporal order.
    """
    capture = cv2.VideoCapture(video_path)
    sampled: list[np.ndarray] = []
    timestamp = start
    while timestamp < end:
        capture.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
        ok, frame = capture.read()
        if not ok:
            logger.warning(f"Failed to read frame at {timestamp:.2f}s")
            break
        sampled.append(frame)
        timestamp += step
    capture.release()
    return sampled
def kmeans_init(features: np.ndarray):
    """Seed ~sqrt(n) random cluster centers and assign each row to its nearest.

    Args:
        features: 2-D array, one feature vector per row.

    Returns:
        ``(labels, centers)``: per-row cluster labels (indices into
        ``centers``) and the randomly chosen center rows.
    """
    n_samples = features.shape[0]
    n_clusters = int(np.sqrt(n_samples)) or 1  # at least one cluster
    chosen = np.random.choice(n_samples, n_clusters, replace=False)
    centers = features[chosen]
    labels = cdist(features, centers).argmin(axis=1)
    return labels, centers
def kmeans_silhouette(features: np.ndarray):
    """Choose a cluster count by greedily merging centers, scored by silhouette.

    Starts from a random ~sqrt(n)-cluster assignment (kmeans_init) and
    repeatedly merges the two closest centers; each merged configuration is
    scored with silhouette_score and the best-scoring one is kept.

    Args:
        features: 2-D array, one feature vector per row.

    Returns:
        ``(best_k, best_centers, center_indices)``: the chosen cluster count,
        the representative feature rows used as centers, and the row indices
        of those centers in ``features`` (matched by exact value, so a center
        that no longer equals any row is silently dropped).
    """
    k = max(int(np.sqrt(len(features))), 2)
    # best_score starts at -1 (silhouette's minimum), so the first valid
    # merged configuration almost always replaces the initial assignment.
    best_k, best_score = k, -1
    clusters, centers = kmeans_init(features)
    best_centers = centers.copy()
    while k > 2:
        # Find the two closest centers (i, j) and merge cluster j into i.
        d = cdist(centers, centers)
        np.fill_diagonal(d, np.inf)
        i, j = np.unravel_index(np.argmin(d), d.shape)
        clusters = np.where(clusters == j, i, clusters)
        # Shift labels above j down so labels stay contiguous in [0, k-1).
        clusters = np.where(clusters > j, clusters - 1, clusters)
        new_centers = []
        for cid in range(k - 1):
            cluster_feats = features[clusters == cid]
            if cluster_feats.size == 0:
                # NOTE(review): skipping an empty cluster shortens new_centers
                # without relabeling clusters, so labels and center positions
                # can drift out of alignment if this ever fires — verify.
                continue
            # Use the member nearest the cluster mean (a medoid) as the
            # new center, so centers stay actual rows of `features`.
            mean_vec = np.mean(cluster_feats, axis=0)
            idx_close = np.argmin(np.linalg.norm(cluster_feats - mean_vec, axis=1))
            new_centers.append(cluster_feats[idx_close])
        # NOTE: centers becomes a plain list here; cdist and list.copy()
        # both accept it on subsequent iterations.
        centers = new_centers
        k -= 1
        if len(np.unique(clusters)) > 1:
            score = silhouette_score(features, clusters)
            if score > best_score:
                best_score, best_k = score, k
                best_centers = centers.copy()
    # Map each winning center back to its row index in `features`.
    center_indices = []
    for c in best_centers:
        matches = np.where((features == c).all(axis=1))[0]
        if matches.size > 0:
            center_indices.append(int(matches[0]))
    # logger.info(f"KMeans silhouette: best_k={best_k}, best_score={best_score:.4f}")
    return best_k, best_centers, center_indices
def redundancy_filter(
    video_path: str, indices: list[int], threshold: float
) -> list[int]:
    """Drop frames whose color histogram is too similar to an earlier frame.

    Args:
        video_path: Path to the source video.
        indices: Global frame numbers to consider, in order.
        threshold: Cosine-similarity cutoff; a frame is redundant when its
            histogram's cosine similarity to ANY earlier readable frame
            exceeds this value.

    Returns:
        The subset of ``indices`` judged non-redundant, in original order.
    """
    # Pair each histogram with the frame index it came from. The original
    # collected histograms in a bare list, so a single failed read shifted
    # the alignment between `histos` and `indices` and the wrong frame
    # numbers were returned; unreadable frames are now simply excluded.
    entries: list[tuple[int, np.ndarray]] = []
    cap = cv2.VideoCapture(video_path)
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            entries.append((idx, color_histogram(frame)))
    cap.release()
    keep: list[int] = []
    for i, (idx, h) in enumerate(entries):
        h_norm = np.linalg.norm(h)  # hoisted: invariant over the inner loop
        if not any(
            np.dot(h, prev) / (h_norm * np.linalg.norm(prev)) > threshold
            for _, prev in entries[:i]
        ):
            keep.append(idx)
    return keep
def extract_and_save_keyframes(
    video_path: str,
    output_dir: str,
    start_index: int = 0,
    step: float = 1.0,
    threshold: float = 0.7,
    k_min: int = 2,
    k_max: int = 8,
) -> int:
    """Detect shots, pick representative keyframes per shot, save them as JPEGs.

    Pipeline: shot boundaries via Google Video Intelligence, frames sampled
    every ``step`` seconds per shot, intra-shot clustering on color histograms
    to pick representatives, cross-frame redundancy filtering, then sequential
    ``image_NNN.jpg`` writes.

    Args:
        video_path: Path to the source video.
        output_dir: Directory for the JPEG outputs (created if missing).
        start_index: First numeric suffix used for saved images.
        step: Seconds between sampled frames within a shot.
        threshold: Cosine-similarity cutoff passed to redundancy_filter.
        k_min: Minimum sampled-frame count before clustering is attempted;
            shots with fewer frames keep all of them.
        k_max: Currently unused; retained for interface compatibility.

    Returns:
        The next unused image index (``start_index`` + number of frames saved).
    """
    logger.info(f"Starting keyframe extraction for {video_path}")
    os.makedirs(output_dir, exist_ok=True)
    # FPS is needed to convert shot timestamps (seconds) to frame numbers;
    # `or 1.0` guards against backends that report 0 fps.
    cap_meta = cv2.VideoCapture(video_path)
    video_fps = cap_meta.get(cv2.CAP_PROP_FPS) or 1.0
    cap_meta.release()
    intervals = detect_shot_intervals_local(video_path)
    cap = cv2.VideoCapture(video_path)
    output_idx = start_index
    for shot_idx, (start, end) in enumerate(intervals):
        frames = sample_frames_per_shot(video_path, start, end, step)
        # Too few frames to cluster: keep them all. (The original compared
        # feats.size — the TOTAL element count, n * histogram_length — to
        # k_min, so the guard effectively never fired for non-empty shots.)
        if len(frames) < k_min:
            idxs = list(range(len(frames)))
        else:
            feats = np.vstack([color_histogram(f) for f in frames])
            _, _, idxs = kmeans_silhouette(feats)
        # Sample i sits at start + i*step seconds, so it must be scaled by
        # step before multiplying by fps. (The original added the raw sample
        # index to int(start*fps), which is only correct when step*fps == 1.)
        global_idxs = [int((start + i * step) * video_fps) for i in idxs]
        filtered = redundancy_filter(video_path, global_idxs, threshold)
        # Save each surviving keyframe sequentially into output_dir.
        for frame_no in filtered:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
            ret, frame = cap.read()
            if not ret:
                continue  # skip unreadable frames rather than abort the shot
            out_path = os.path.join(output_dir, f"image_{output_idx:03d}.jpg")
            cv2.imwrite(out_path, frame)
            output_idx += 1
        logger.info(
            f"Shot {shot_idx + 1}: saved {len(filtered)} keyframes. Total so far: {output_idx}"
        )
    cap.release()
    logger.info(f"Extraction complete. Total frames saved: {output_idx}")
    return output_idx