|
|
import os |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
from segment_anything import sam_model_registry, SamPredictor |
|
|
from youtube_transcript_api import YouTubeTranscriptApi |
|
|
|
|
|
def video_to_frames(video_path, output_dir, frame_rate=0.7):
    """Sample frames from a video and write them to *output_dir* as JPEGs.

    Roughly ``frame_rate`` frames per second of video are kept. Each kept
    frame is saved as ``frame_<index>.jpg`` where ``<index>`` is the
    zero-padded (5 digits) position of the frame in the source video.

    Args:
        video_path: Path of the video file to read.
        output_dir: Directory that receives the JPEG files; created if missing.
        frame_rate: Desired number of saved frames per second of video.

    Returns:
        The frames-per-second value reported by OpenCV for the video.

    Raises:
        ValueError: If ``frame_rate`` is not positive.
        OSError: If the video cannot be opened.
    """
    if frame_rate <= 0:
        raise ValueError("frame_rate must be positive")

    os.makedirs(output_dir, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise OSError(f"Could not open video: {video_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        # The truncated ratio is 0 when fps is unknown (0/NaN) or lower than
        # frame_rate; clamp to 1 so the modulo below never divides by zero.
        frame_interval = max(1, int(fps / frame_rate)) if fps > 0 else 1

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % frame_interval == 0:
                frame_file = os.path.join(output_dir, f'frame_{frame_count:05d}.jpg')
                cv2.imwrite(frame_file, frame)
            frame_count += 1
    finally:
        # Release the capture even if reading or writing fails part-way.
        cap.release()

    return fps
|
|
|
|
|
def select_background_points(image, num_points=4):
    """Return prompt points on the image border as ``(x, y)`` pixel coordinates.

    The four corners are always returned, in the order top-left, top-right,
    bottom-left, bottom-right. When ``num_points`` is greater than 4 the
    midpoints of the top, left, bottom and right edges are appended, giving
    eight points in total.

    Coordinates are emitted as ``(x, y)`` = ``(column, row)`` because that is
    the order SAM's ``point_coords`` prompt expects; ``(row, column)`` pairs
    would fall on the wrong pixels (or outside the frame) for non-square images.

    Args:
        image: Array whose first two dimensions are height and width.
        num_points: Any value above 4 adds the four edge midpoints.

    Returns:
        Integer array of shape ``(4, 2)`` or ``(8, 2)``.
    """
    # Only height and width are needed, so 2-D (grayscale) arrays work too.
    h, w = image.shape[:2]
    right, bottom = w - 1, h - 1

    points = np.array([
        [0, 0],
        [right, 0],
        [0, bottom],
        [right, bottom],
    ])

    if num_points > 4:
        edge_midpoints = [
            [w // 2, 0],
            [0, h // 2],
            [w // 2, bottom],
            [right, h // 2],
        ]
        points = np.vstack([points, edge_midpoints])

    return points
|
|
|
|
|
def compare_histograms(frame1, frame2, threshold=0.4):
    """Tell whether two frames differ strongly in colour distribution.

    A 3-channel histogram with 8 bins per channel over the 0-256 range is
    computed for each frame, normalised with ``cv2.normalize`` and flattened.
    The two histograms are then compared with the correlation metric.

    Args:
        frame1: First image.
        frame2: Second image.
        threshold: Correlation value below which the frames count as different.

    Returns:
        True when the histogram correlation is lower than ``threshold``.
    """
    channels = [0, 1, 2]
    bins = [8, 8, 8]
    ranges = [0, 256, 0, 256, 0, 256]

    signatures = []
    for image in (frame1, frame2):
        histogram = cv2.calcHist([image], channels, None, bins, ranges)
        signatures.append(cv2.normalize(histogram, histogram).flatten())

    # HISTCMP_CORREL yields a similarity score, so a low value means "different".
    correlation = cv2.compareHist(signatures[0], signatures[1], cv2.HISTCMP_CORREL)
    return correlation < threshold
|
|
|
|
|
def detect_scene_changes(frame_dir, fps, threshold=0.15, hist_threshold=0.3):
    """Find the timestamps at which consecutive sampled frames change scene.

    Each frame file in *frame_dir* is segmented with the module-level SAM
    ``predictor`` (which must have been assigned before this is called), using
    the border points from ``select_background_points`` labelled 0. A scene
    change is recorded when either

    * the fraction of pixels whose mask value differs from the previous
      frame's mask exceeds ``threshold``, or
    * ``compare_histograms`` reports the previous and current frame as
      different for ``hist_threshold``.

    Only files named ``<prefix>_<number>.<ext>`` are considered, and they are
    visited in numeric order. Entries without a parsable frame number (for
    example a video stored in the same folder) or that OpenCV cannot decode
    are skipped instead of crashing the run.

    Args:
        frame_dir: Directory containing the extracted frame images.
        fps: Frames per second of the source video, used to turn a frame
            number into seconds.
        threshold: Minimum mask-difference fraction that counts as a change.
        hist_threshold: Threshold forwarded to ``compare_histograms``.

    Returns:
        List of timestamps in seconds (frame number / fps), in frame order.

    Raises:
        ValueError: If ``fps`` is not a positive number.
    """
    # Fail early rather than with a ZeroDivisionError after the costly
    # segmentation work has already been done.
    if not fps > 0:
        raise ValueError("fps must be a positive number")

    def frame_number(frame_name):
        """Return the number embedded in the file name, or None if absent."""
        try:
            return int(frame_name.split('_')[1].split('.')[0])
        except (IndexError, ValueError):
            return None

    numbered_frames = []
    for name in os.listdir(frame_dir):
        number = frame_number(name)
        if number is not None:
            numbered_frames.append((number, name))
    # Numeric ordering stays correct once frame numbers outgrow the 5-digit
    # zero padding, where a plain string sort would misorder them.
    numbered_frames.sort()

    scene_changes = []
    prev_mask = None
    prev_frame = None

    for number, frame_name in numbered_frames:
        frame = cv2.imread(os.path.join(frame_dir, frame_name))
        if frame is None:
            # Not a decodable image; ignore it.
            continue

        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        predictor.set_image(frame_rgb)

        background_points = select_background_points(frame_rgb)
        point_labels = np.zeros(background_points.shape[0], dtype=int)
        masks, _, _ = predictor.predict(point_coords=background_points,
                                        point_labels=point_labels,
                                        multimask_output=False)
        mask = masks[0]

        # The first usable frame has nothing to be compared against.
        mask_diff = 0
        hist_diff = False
        if prev_mask is not None:
            mask_diff = np.logical_xor(mask, prev_mask).mean()
        if prev_frame is not None:
            hist_diff = compare_histograms(prev_frame, frame, threshold=hist_threshold)

        if mask_diff > threshold or hist_diff:
            scene_changes.append(number / fps)

        prev_mask = mask
        prev_frame = frame

    return scene_changes
|
|
|
|
|
def get_transcript(video_id):
    """Fetch the transcript of a YouTube video, or ``[]`` when unavailable.

    This is a best-effort lookup: any failure is logged with its traceback and
    an empty list is returned so the caller can carry on without captions.

    Args:
        video_id: YouTube video identifier. An empty or ``None`` id returns
            ``[]`` without contacting the API.

    Returns:
        Whatever ``YouTubeTranscriptApi.get_transcript`` returns for the video,
        or an empty list on failure.
    """
    import logging  # local import keeps this function self-contained

    if not video_id:
        return []

    try:
        # NOTE(review): newer youtube-transcript-api releases appear to replace
        # the static get_transcript() with an instance fetch() - confirm
        # against the installed version; the log below will show it if so.
        return YouTubeTranscriptApi.get_transcript(video_id)
    except Exception:
        # Deliberately broad: this is the best-effort boundary, but the
        # failure must at least be visible in the logs.
        logging.getLogger(__name__).exception(
            "Could not fetch transcript for video id %r", video_id
        )
        return []
|
|
|
|
|
def group_transcripts_by_scenes(transcripts, scene_changes):
    """Join transcript texts into one string per scene.

    Entries are read in the given order. An entry whose ``'start'`` is
    strictly greater than the next scene-change timestamp closes the current
    scene. If one entry jumps past several scene changes, every skipped scene
    is closed (as an empty string when it holds no text), so the position of
    each string in the result keeps matching the scene it belongs to.

    Args:
        transcripts: Iterable of mappings with ``'start'`` (seconds) and
            ``'text'`` keys, assumed to be in chronological order.
        scene_changes: Scene-change timestamps in seconds; sorted internally.

    Returns:
        List of strings, the texts of each scene joined by single spaces.
        Scenes after the last transcript entry are not emitted.
    """
    boundaries = sorted(scene_changes)
    grouped_transcripts = []
    current_texts = []
    scene_index = 0

    for entry in transcripts:
        start_time = entry['start']
        # `while`, not `if`: a single caption may lie beyond several
        # boundaries, and each of them has to be consumed here or later
        # captions of the same scene would be split apart.
        while scene_index < len(boundaries) and start_time > boundaries[scene_index]:
            grouped_transcripts.append(' '.join(current_texts))
            current_texts = []
            scene_index += 1
        current_texts.append(entry['text'])

    if current_texts:
        grouped_transcripts.append(' '.join(current_texts))

    return grouped_transcripts
|
|
|
|
|
def process_video_and_transcript(video_file, youtube_video_id):
    """Detect scene changes in a video and group its transcript by scene.

    Steps: extract frames into ``Output_frames``, segment them with SAM to
    find scene changes, fetch the YouTube transcript and split it at the
    detected timestamps.

    Args:
        video_file: The uploaded video, either as a file path (what Gradio's
            Video component passes to the callback) or as a binary file-like
            object exposing ``read()``.
        youtube_video_id: Id of the YouTube video whose transcript is used.

    Returns:
        One ``"Scene N: <text>"`` paragraph per group, separated by blank
        lines; an empty string when there is no transcript text.

    Raises:
        gr.Error: If no video was supplied.
    """
    import shutil
    import tempfile

    if video_file is None:
        raise gr.Error("Please upload a video file.")

    output_dir = "Output_frames"
    _prepare_frame_dir(output_dir)

    if hasattr(video_file, "read"):
        # Spool the stream to a temporary file outside the frame directory so
        # the video is never mistaken for a frame image.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as spooled:
            shutil.copyfileobj(video_file, spooled)
        video_path = spooled.name
        remove_video_after = True
    else:
        video_path = os.fspath(video_file)
        remove_video_after = False

    try:
        fps = video_to_frames(video_path, output_dir, frame_rate=0.7)
    finally:
        if remove_video_after:
            os.remove(video_path)

    # Loading the ViT-H checkpoint is expensive; build the predictor once and
    # reuse it for later requests. detect_scene_changes reads this global.
    global predictor
    if globals().get("predictor") is None:
        model = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
        predictor = SamPredictor(model)

    scene_changes = detect_scene_changes(output_dir, fps, threshold=0.15, hist_threshold=0.3)
    transcripts = get_transcript(youtube_video_id)
    grouped_transcripts = group_transcripts_by_scenes(transcripts, scene_changes)

    return "\n\n".join([f"Scene {i + 1}: {text}" for i, text in enumerate(grouped_transcripts)])


def _prepare_frame_dir(frame_dir):
    """Create *frame_dir* and delete artefacts left behind by earlier runs."""
    os.makedirs(frame_dir, exist_ok=True)
    for name in os.listdir(frame_dir):
        is_old_frame = name.startswith("frame_") and name.endswith(".jpg")
        # "uploaded_video.mp4" is where earlier versions stored the upload.
        if is_old_frame or name == "uploaded_video.mp4":
            os.remove(os.path.join(frame_dir, name))
|
|
|
|
|
|
|
|
# Input widgets, listed in the positional order that
# process_video_and_transcript expects its arguments.
video_input = gr.Video(label="Upload Video File (.mp4)")
video_id_input = gr.Textbox(label="YouTube Video ID")

# Web UI wiring: two inputs in, one plain-text result out.
interface = gr.Interface(
    fn=process_video_and_transcript,
    inputs=[video_input, video_id_input],
    outputs="text",
    title="Scene Change Detection & Transcript Grouping",
    description="Upload a video file and input a YouTube video ID. The app will detect scene changes in the video and group the transcript text according to these scene changes.",
)

# The server is started as soon as this module is executed.
interface.launch()
|
|
|