# NOTE: web-viewer residue (file-size banner, commit hashes and line-number
# gutter) that preceded the code was removed; it was not part of the program.
import streamlit as st
import whisper
import ffmpeg
import pandas as pd
import pickle
import os
import numpy as np
from sentence_transformers import SentenceTransformer
from chromadb import PersistentClient

# Load the 'all-MiniLM-L6-v2' sentence-embedding model once at import time;
# embed_subtitles() and search_subtitles() both encode text with it.
embed_model = SentenceTransformer('all-MiniLM-L6-v2')

# Function to extract audio
def extract_audio(uploaded_file):
    """Persist an upload to disk and return a path that can be transcribed.

    Video uploads (``.mp4`` / ``.mkv``, matched case-insensitively) have
    their audio extracted with ffmpeg into ``temp_audio.wav``. Any other
    upload is used as-is.

    Args:
        uploaded_file: Object exposing ``name`` and ``getvalue()``, such as
            the value returned by ``st.file_uploader``.

    Returns:
        ``(audio_path, temp_file)`` on success. ``temp_file`` is the on-disk
        copy of the upload and ``audio_path`` is the file to transcribe (the
        same path for non-video uploads). ``(None, None)`` when writing or
        extraction fails; the error is reported through ``st.error``.
    """
    audio_path = "temp_audio.wav"
    # The upload name is client-supplied: keep only its final component so a
    # name containing directory separators cannot redirect or break the write.
    safe_name = os.path.basename(uploaded_file.name.replace("\\", "/"))
    temp_file = f"temp_{safe_name}"

    try:
        with open(temp_file, "wb") as f:
            f.write(uploaded_file.getvalue())

        if safe_name.lower().endswith(('.mp4', '.mkv')):
            ffmpeg.input(temp_file).output(audio_path).run(overwrite_output=True)
        else:
            audio_path = temp_file
        return audio_path, temp_file
    except Exception as e:
        # Do not leave a half-processed upload behind on failure.
        try:
            if os.path.exists(temp_file):
                os.remove(temp_file)
        except OSError:
            pass
        st.error(f"Error extracting audio: {str(e)}")
        return None, None

# Function to transcribe audio
def transcribe_audio(audio_path):
    """Transcribe an audio file with Whisper's "base" model.

    Args:
        audio_path: Path of the audio file handed to ``model.transcribe``.

    Returns:
        A list of subtitle entries, one per transcription segment, each of
        the form ``"<index>\\n<start> --> <end>\\n<text>\\n"`` with indices
        starting at 1. On any exception the error is shown with ``st.error``
        and an empty list is returned.
    """
    try:
        transcription = whisper.load_model("base").transcribe(audio_path)
        return [
            f"{number}\n{format_timestamp(seg['start'])} --> "
            f"{format_timestamp(seg['end'])}\n{seg['text']}\n"
            for number, seg in enumerate(transcription['segments'], start=1)
        ]
    except Exception as e:
        st.error(f"Error during transcription: {str(e)}")
        return []

# Timestamp formatting
def format_timestamp(seconds):
    """Format a duration in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to the nearest millisecond before being split into
    fields. Truncating the fractional part instead loses a millisecond to
    floating-point error (e.g. ``1.001`` would render as ``00:00:01,000``).

    Args:
        seconds: Non-negative duration in seconds (int or float). Negative
            values are clamped to zero.

    Returns:
        The formatted timestamp string.
    """
    total_millis = max(0, int(round(float(seconds) * 1000)))
    total_secs, millis = divmod(total_millis, 1000)
    total_minutes, secs = divmod(total_secs, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02}:{minutes:02}:{secs:02},{millis:03}"

# Embed subtitles
def embed_subtitles(subtitles):
    """Embed the text of each subtitle entry and pickle the result.

    Args:
        subtitles: Iterable of subtitle entries whose third newline-separated
            field is the spoken text. Whitespace-only entries are skipped.

    Returns:
        A DataFrame with a ``subtitle`` column (the extracted text) and an
        ``embedding`` column (one vector per text). The same frame is also
        written to ``subtitle_embeddings.pkl``.
    """
    raw_texts = []
    for entry in subtitles:
        if not entry.strip():
            continue
        # Field 0 is the index, field 1 the time range, field 2 the text.
        raw_texts.append(entry.split('\n')[2])

    vectors = embed_model.encode(raw_texts)
    frame = pd.DataFrame({'subtitle': raw_texts, 'embedding': list(vectors)})

    with open('subtitle_embeddings.pkl', 'wb') as out:
        pickle.dump(frame, out)

    return frame

# Save embeddings to ChromaDB
def save_to_chroma(embeddings):
    """Store subtitle texts and their vectors in a fresh Chroma collection.

    The persistent client keeps data in ``./chroma_db`` between runs, so a
    plain ``create_collection`` fails once the collection exists. The
    collection is therefore recreated on every call, which also drops rows
    left over from a previously processed file.

    Args:
        embeddings: DataFrame with a ``subtitle`` column (text) and an
            ``embedding`` column whose values expose ``.tolist()``.

    Returns:
        The ``subtitles`` collection containing one record per row, with the
        stringified DataFrame index as the record id.
    """
    client = PersistentClient(path="./chroma_db")

    # Make sure the collection exists so the delete below cannot fail, then
    # start from an empty collection.
    client.get_or_create_collection(name="subtitles")
    client.delete_collection(name="subtitles")
    collection = client.create_collection(name="subtitles")

    ids = [str(idx) for idx in embeddings.index]
    documents = list(embeddings['subtitle'])
    vectors = [vector.tolist() for vector in embeddings['embedding']]

    # Add in modest batches: far fewer round trips than one call per row
    # while keeping each call small.
    batch_size = 500
    for start in range(0, len(ids), batch_size):
        stop = start + batch_size
        collection.add(
            documents=documents[start:stop],
            ids=ids[start:stop],
            embeddings=vectors[start:stop],
        )
    return collection

# Search subtitles
def search_subtitles(query, collection):
    """Return the stored subtitles most similar to ``query``.

    Args:
        query: Free-text search string.
        collection: Chroma collection exposing ``count()`` and ``query()``.

    Returns:
        A flat list of up to five matching subtitle texts, best match first.
        An empty list is returned when the collection is empty or when the
        search fails; failures are reported through ``st.error``.
    """
    try:
        available = collection.count()
        if available == 0:
            return []

        query_embedding = embed_model.encode([query]).tolist()
        results = collection.query(
            query_embeddings=query_embedding,
            # Never ask for more neighbours than the collection holds.
            n_results=min(5, available),
        )
        # query() returns one list of documents per query embedding; exactly
        # one embedding was sent, so unwrap that single list.
        documents = results['documents']
        return documents[0] if documents else []
    except Exception as e:
        st.error(f"Error searching subtitles: {str(e)}")
        return []

# Main app
def main():
    """Run the Streamlit page: upload, transcribe, index, search, download.

    Flow for an uploaded file:
      1. Save it to disk and, for videos, extract the audio track.
      2. Transcribe the audio into SRT-style entries and display them.
      3. Embed the entries and store them in Chroma.
      4. If a query was entered, show the closest matching subtitles.
      5. If the download button was pressed, offer the entries as an SRT file.
    """
    st.set_page_config(page_title="Video/Audio Subtitle Generator", layout="wide")
    st.title("🎥🎵 Video/Audio Subtitle Generator")

    with st.sidebar:
        uploaded_file = st.file_uploader("Upload Video/Audio", type=["mp4", "mkv", "mp3", "wav"])
        query = st.text_input("Search Subtitles")
        download_btn = st.button("Download Subtitles")

    if not uploaded_file:
        return

    with st.spinner("Extracting audio..."):
        audio_path, temp_file = extract_audio(uploaded_file)

    if not audio_path:
        return

    try:
        with st.spinner("Generating subtitles..."):
            subtitles = transcribe_audio(audio_path)
    finally:
        # The on-disk copies are only needed for transcription; the media
        # player below reads from the upload object itself.
        for path in {audio_path, temp_file}:
            try:
                if path and os.path.exists(path):
                    os.remove(path)
            except OSError:
                pass

    # Only claim success when transcription actually produced something.
    if subtitles:
        st.success("Subtitles Generated!")

    if uploaded_file.name.lower().endswith(('.mp4', '.mkv')):
        st.video(uploaded_file)
    else:
        st.audio(uploaded_file)

    st.write("### Generated Subtitles:")
    for sub in subtitles:
        st.text(sub)

    # Stays None when there is nothing to index, so the search step below
    # can be skipped instead of hitting an unbound name.
    collection = None
    with st.spinner("Embedding and storing subtitles..."):
        if subtitles:
            embeddings = embed_subtitles(subtitles)
            if not embeddings.empty:
                collection = save_to_chroma(embeddings)

        if collection is None:
            st.warning("No subtitles generated.")

    if query and collection is not None:
        results = search_subtitles(query, collection)
        st.write("### Matching Subtitles:")
        if results:
            for idx, sub in enumerate(results, start=1):
                st.write(f"{idx}. {sub}")
        else:
            st.warning("No matching subtitles found.")

    if download_btn:
        # Each entry ends with a single newline; joining with another one
        # yields the blank line SRT requires between consecutive entries.
        srt_text = "\n".join(subtitles)
        with open("generated_subtitles.srt", "w", encoding="utf-8") as f:
            f.write(srt_text)

        st.download_button(
            "Download SRT",
            srt_text.encode("utf-8"),
            file_name="generated_subtitles.srt",
            mime="text/plain",
        )

# Run the app only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
    main()