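"""Scene-change detection with SAM, plus YouTube transcript grouping.

Extracts frames from an uploaded video, combines Segment Anything (SAM)
background masks with color-histogram correlation to detect scene changes,
then groups a YouTube transcript by those scene boundaries behind a Gradio UI.
"""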
import os
import cv2
import numpy as np
import gradio as gr
from segment_anything import sam_model_registry, SamPredictor
from youtube_transcript_api import YouTubeTranscriptApi

# The SAM predictor is created lazily on the first request and reused afterwards,
# so the large ViT-H checkpoint is not reloaded on every call.
predictor = None

def video_to_frames(video_path, output_dir, frame_rate=0.7):
    """Extract frames from a video at roughly `frame_rate` frames per second.

    Returns the video's native FPS so frame indices can later be mapped back
    to timestamps.
    """
    os.makedirs(output_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back if FPS metadata is missing
    # Sample one frame every `frame_interval` native frames (at least every frame).
    frame_interval = max(1, int(round(fps / frame_rate)))
    frame_count = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frame_count % frame_interval == 0:
            cv2.imwrite(os.path.join(output_dir, f'frame_{frame_count:05d}.jpg'), frame)
        frame_count += 1
    cap.release()
    return fps

def select_background_points(image, num_points=4):
    """Pick background prompt points for SAM along the image border.

    SAM's predictor expects point coordinates as (x, y), i.e. (column, row).
    """
    h, w, _ = image.shape
    points = np.array([
        [0, 0],          # top-left corner
        [w - 1, 0],      # top-right corner
        [0, h - 1],      # bottom-left corner
        [w - 1, h - 1],  # bottom-right corner
    ])

    if num_points > 4:
        # Add the midpoint of each edge for denser background coverage.
        points = np.vstack([points,
                            [w // 2, 0],
                            [0, h // 2],
                            [w // 2, h - 1],
                            [w - 1, h // 2]])

    return points

def compare_histograms(frame1, frame2, threshold=0.4):
    """Return True when two frames' color histograms are dissimilar.

    HISTCMP_CORREL measures similarity (1.0 = identical), so a correlation
    below the threshold indicates a likely scene change.
    """
    hist1 = cv2.calcHist([frame1], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256])
    hist2 = cv2.calcHist([frame2], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256])
    hist1 = cv2.normalize(hist1, hist1).flatten()
    hist2 = cv2.normalize(hist2, hist2).flatten()
    correlation = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CORREL)
    return correlation < threshold

def detect_scene_changes(frame_dir, fps, threshold=0.15, hist_threshold=0.3):
    """Flag scene changes by combining SAM background-mask drift with
    histogram dissimilarity between consecutive sampled frames."""
    # Only consider extracted frames; frame_dir may also hold the uploaded video.
    frames = sorted(f for f in os.listdir(frame_dir) if f.endswith('.jpg'))
    scene_changes = []
    prev_mask = None
    prev_frame = None

    for frame_name in frames:
        frame = cv2.imread(os.path.join(frame_dir, frame_name))
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        predictor.set_image(frame_rgb)  # `predictor` is initialized in process_video_and_transcript
        background_points = select_background_points(frame_rgb)
        point_labels = np.zeros(background_points.shape[0], dtype=int)  # 0 = background prompt
        masks, _, _ = predictor.predict(point_coords=background_points,
                                        point_labels=point_labels,
                                        multimask_output=False)
        # Fraction of pixels whose background/foreground assignment flipped.
        mask_diff = 0
        if prev_mask is not None:
            mask_diff = np.logical_xor(masks[0], prev_mask).mean()
        hist_diff = False
        if prev_frame is not None:
            hist_diff = compare_histograms(prev_frame, frame, threshold=hist_threshold)

        if mask_diff > threshold or hist_diff:
            # Recover the original frame index from the filename, then convert to seconds.
            timestamp = int(frame_name.split('_')[1].split('.')[0]) / fps
            scene_changes.append(timestamp)

        prev_mask = masks[0]
        prev_frame = frame

    return scene_changes

def get_transcript(video_id):
    """Fetch the transcript for a YouTube video; return [] if unavailable."""
    try:
        return YouTubeTranscriptApi.get_transcript(video_id)
    except Exception as e:
        print(f"Could not fetch transcript for {video_id}: {e}")
        return []

def group_transcripts_by_scenes(transcripts, scene_changes):
    """Split the transcript into text groups, one per detected scene."""
    grouped_transcripts = []
    scene_index = 0
    current_group = []

    for transcript in transcripts:
        start_time = transcript['start']
        # Advance past every boundary before this caption; a `while` handles
        # multiple scene changes falling between two consecutive captions.
        while scene_index < len(scene_changes) and start_time > scene_changes[scene_index]:
            grouped_transcripts.append(' '.join(t['text'] for t in current_group))
            current_group = []
            scene_index += 1
        current_group.append(transcript)

    if current_group:
        grouped_transcripts.append(' '.join(t['text'] for t in current_group))

    return grouped_transcripts

def process_video_and_transcript(video_file, youtube_video_id):
    output_dir = "Output_frames"
    os.makedirs(output_dir, exist_ok=True)

    # Gradio's Video component hands us a filepath, not a file object.
    video_path = video_file

    fps = video_to_frames(video_path, output_dir, frame_rate=0.7)

    # Initialize the SAM predictor once and cache it at module level
    global predictor
    if predictor is None:
        model = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
        predictor = SamPredictor(model)

    # Detect scene changes
    scene_changes = detect_scene_changes(output_dir, fps, threshold=0.15, hist_threshold=0.3)

    # Get YouTube transcript
    transcripts = get_transcript(youtube_video_id)

    # Group transcripts by scene changes
    grouped_transcripts = group_transcripts_by_scenes(transcripts, scene_changes)

    return "\n\n".join(f"Scene {i + 1}: {text}" for i, text in enumerate(grouped_transcripts))
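# A minimal sketch of exercising the pipeline without the UI; the clip path
# and video ID below are hypothetical placeholders:
#
#     result = process_video_and_transcript("clips/lecture.mp4", "VIDEO_ID")
#     print(result)
#
# The ViT-H checkpoint "sam_vit_h_4b8939.pth" is assumed to be present in the
# working directory; it can be downloaded from the segment-anything repository.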

# Gradio Interface
interface = gr.Interface(
    fn=process_video_and_transcript,
    inputs=[
        gr.Video(label="Upload Video File (.mp4)"),
        gr.Textbox(label="YouTube Video ID")
    ],
    outputs="text",
    title="Scene Change Detection & Transcript Grouping",
    description="Upload a video file and input a YouTube video ID. The app will detect scene changes in the video and group the transcript text according to these scene changes."
)

if __name__ == "__main__":
    interface.launch()