| import os |
| import cv2 |
| import numpy as np |
| import gradio as gr |
| from segment_anything import sam_model_registry, SamPredictor |
| from youtube_transcript_api import YouTubeTranscriptApi |
|
|
def video_to_frames(video_path, output_dir, frame_rate=0.7):
    """Extract frames from a video at roughly `frame_rate` frames per second.

    Frames are written to `output_dir` as `frame_<index>.jpg`, where <index>
    is the frame's position in the original video (used later to recover
    timestamps via index / fps).

    Args:
        video_path: Path to the input video file.
        output_dir: Directory to write JPEG frames into (created if missing).
        frame_rate: Target sampling rate in frames per second.

    Returns:
        The source video's native FPS.

    Raises:
        ValueError: If the video's FPS cannot be read (unreadable file).
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            # cap.get returns 0 for files OpenCV cannot open; without this
            # guard the modulo below raises ZeroDivisionError.
            raise ValueError(f"Could not read FPS from video: {video_path}")
        # Keep every Nth frame; clamp to 1 so very low-FPS videos still work.
        frame_interval = max(1, int(fps / frame_rate))
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % frame_interval == 0:
                cv2.imwrite(os.path.join(output_dir, f'frame_{frame_count:05d}.jpg'), frame)
            frame_count += 1
    finally:
        # Release the capture even if frame extraction fails midway.
        cap.release()
    return fps
|
|
def select_background_points(image, num_points=4):
    """Return background prompt points for SAM in (x, y) pixel order.

    SAM's `SamPredictor.predict(point_coords=...)` expects coordinates as
    (x, y), i.e. (column, row). The previous version emitted (row, col)
    pairs, which placed points outside the image whenever width != height.

    Args:
        image: HxWxC image array; only its shape is used.
        num_points: If greater than 4, the four edge midpoints are added
            to the four corners (yielding 8 points total).

    Returns:
        An (N, 2) int array of (x, y) coordinates inside the image bounds.
    """
    h, w, _ = image.shape
    # Four image corners, (x, y) order.
    points = np.array([
        [0, 0],
        [w - 1, 0],
        [0, h - 1],
        [w - 1, h - 1],
    ])

    if num_points > 4:
        # Edge midpoints: top, left, bottom, right.
        points = np.vstack([points,
                            [w // 2, 0],
                            [0, h // 2],
                            [w // 2, h - 1],
                            [w - 1, h // 2]])

    return points
|
|
def compare_histograms(frame1, frame2, threshold=0.4):
    """Return True when the two frames' color distributions differ enough.

    Builds an 8x8x8 BGR histogram for each frame, normalizes it, and
    compares the pair by correlation; a correlation below `threshold`
    counts as a significant difference (likely scene change).
    """
    normalized = []
    for frame in (frame1, frame2):
        hist = cv2.calcHist([frame], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256])
        normalized.append(cv2.normalize(hist, hist).flatten())
    correlation = cv2.compareHist(normalized[0], normalized[1], cv2.HISTCMP_CORREL)
    return correlation < threshold
|
|
def detect_scene_changes(frame_dir, fps, threshold=0.15, hist_threshold=0.3):
    """Detect scene-change timestamps (seconds) in a directory of frames.

    A change is flagged when either the SAM background mask flips on more
    than `threshold` of pixels between consecutive frames, or the color
    histograms differ (see `compare_histograms`).

    NOTE: relies on the module-level `predictor` (a SamPredictor) set up by
    `process_video_and_transcript` before this is called.

    Args:
        frame_dir: Directory containing `frame_<index>.jpg` files.
        fps: The source video's FPS, used to convert frame index to seconds.
        threshold: Fraction of mask pixels that must flip to flag a change.
        hist_threshold: Correlation threshold passed to compare_histograms.

    Returns:
        List of timestamps (seconds) where scene changes were detected.
    """
    # Only consider extracted frame images: the caller writes the uploaded
    # video into this same directory, and cv2.imread would return None for
    # it (crashing cvtColor below) if we listed every file.
    frames = sorted(f for f in os.listdir(frame_dir)
                    if f.startswith('frame_') and f.endswith('.jpg'))
    scene_changes = []
    prev_mask = None
    prev_frame = None

    for frame_name in frames:
        frame = cv2.imread(os.path.join(frame_dir, frame_name))
        if frame is None:
            # Unreadable/corrupt frame file — skip rather than crash.
            continue
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        predictor.set_image(frame_rgb)
        background_points = select_background_points(frame_rgb)
        # Label 0 marks every prompt point as background.
        point_labels = np.zeros(background_points.shape[0], dtype=int)
        masks, _, _ = predictor.predict(point_coords=background_points,
                                        point_labels=point_labels,
                                        multimask_output=False)
        mask_diff = 0
        if prev_mask is not None:
            # Fraction of pixels whose background mask flipped between frames.
            mask_diff = np.logical_xor(masks[0], prev_mask).mean()
        hist_diff = False
        if prev_frame is not None:
            hist_diff = compare_histograms(prev_frame, frame, threshold=hist_threshold)

        if mask_diff > threshold or hist_diff:
            # Files are named frame_<index>.jpg; index / fps gives seconds.
            timestamp = int(frame_name.split('_')[1].split('.')[0]) / fps
            scene_changes.append(timestamp)

        prev_mask = masks[0]
        prev_frame = frame

    return scene_changes
|
|
def get_transcript(video_id):
    """Fetch the transcript for a YouTube video, best-effort.

    Args:
        video_id: The YouTube video ID (the part after `v=` in the URL).

    Returns:
        The transcript as returned by YouTubeTranscriptApi (a list of
        dicts with 'text', 'start', 'duration'), or an empty list when
        the transcript is unavailable for any reason.
    """
    try:
        return YouTubeTranscriptApi.get_transcript(video_id)
    except Exception:
        # Deliberate best-effort: a missing/disabled transcript simply
        # yields no text rather than failing the whole request.
        return []
|
|
def group_transcripts_by_scenes(transcripts, scene_changes):
    """Group transcript entries into one text chunk per scene.

    Args:
        transcripts: List of dicts with 'start' (seconds) and 'text' keys,
            in ascending start order.
        scene_changes: Ascending list of scene-change timestamps (seconds).

    Returns:
        List of strings, one per scene. A scene with no transcript entries
        yields an empty string, so later text stays aligned with the
        correct scene number.
    """
    grouped_transcripts = []
    scene_index = 0
    current_group = []

    for transcript in transcripts:
        start_time = transcript['start']
        # Advance past EVERY boundary this entry crosses (`while`, not `if`):
        # with a single `if`, an entry that skips multiple scene boundaries
        # would advance only one scene, dropping empty scenes and attributing
        # subsequent text to the wrong scene.
        while scene_index < len(scene_changes) and start_time > scene_changes[scene_index]:
            grouped_transcripts.append(' '.join(t['text'] for t in current_group))
            current_group = []
            scene_index += 1
        current_group.append(transcript)

    if current_group:
        grouped_transcripts.append(' '.join(t['text'] for t in current_group))

    return grouped_transcripts
|
|
def process_video_and_transcript(video_file, youtube_video_id):
    """Gradio handler: detect scene changes and group the transcript by scene.

    Args:
        video_file: The uploaded video — modern `gr.Video` passes a
            filesystem path (str); older Gradio versions may pass a
            file-like object, which is handled too.
        youtube_video_id: YouTube video ID whose transcript to fetch.

    Returns:
        One text blob with "Scene N: ..." paragraphs, one per scene.
    """
    output_dir = "Output_frames"
    # Create the working directory BEFORE writing the upload into it;
    # previously it was only created later, inside video_to_frames.
    os.makedirs(output_dir, exist_ok=True)

    if isinstance(video_file, str):
        # gr.Video hands us a temp-file path; use it directly.
        video_path = video_file
    else:
        video_path = os.path.join(output_dir, "uploaded_video.mp4")
        with open(video_path, "wb") as f:
            f.write(video_file.read())

    fps = video_to_frames(video_path, output_dir, frame_rate=0.7)

    # NOTE(review): the SAM checkpoint is reloaded on every request, which is
    # slow — consider loading it once at module import time.
    model = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
    global predictor
    predictor = SamPredictor(model)

    scene_changes = detect_scene_changes(output_dir, fps, threshold=0.15, hist_threshold=0.3)

    transcripts = get_transcript(youtube_video_id)

    grouped_transcripts = group_transcripts_by_scenes(transcripts, scene_changes)

    return "\n\n".join(f"Scene {i + 1}: {text}" for i, text in enumerate(grouped_transcripts))
|
|
| |
# Wire up the Gradio UI: a video upload plus a YouTube video ID textbox,
# producing the grouped transcript as plain text.
video_input = gr.Video(label="Upload Video File (.mp4)")
video_id_input = gr.Textbox(label="YouTube Video ID")

interface = gr.Interface(
    fn=process_video_and_transcript,
    inputs=[video_input, video_id_input],
    outputs="text",
    title="Scene Change Detection & Transcript Grouping",
    description=(
        "Upload a video file and input a YouTube video ID. The app will "
        "detect scene changes in the video and group the transcript text "
        "according to these scene changes."
    ),
)

interface.launch()
|
|