import streamlit as st import whisper import ffmpeg import pandas as pd import pickle import os import numpy as np from sentence_transformers import SentenceTransformer from chromadb import PersistentClient # Initialize models embed_model = SentenceTransformer('all-MiniLM-L6-v2') # Function to extract audio def extract_audio(uploaded_file): audio_path = "temp_audio.wav" temp_file = f"temp_{uploaded_file.name}" with open(temp_file, "wb") as f: f.write(uploaded_file.getvalue()) try: if uploaded_file.name.endswith(('.mp4', '.mkv')): ffmpeg.input(temp_file).output(audio_path).run(overwrite_output=True) else: audio_path = temp_file return audio_path, temp_file except Exception as e: st.error(f"Error extracting audio: {str(e)}") return None, None # Function to transcribe audio def transcribe_audio(audio_path): try: model = whisper.load_model("base") result = model.transcribe(audio_path) subtitles = [] for i, segment in enumerate(result['segments']): start_time = format_timestamp(segment['start']) end_time = format_timestamp(segment['end']) text = segment['text'] subtitles.append(f"{i + 1}\n{start_time} --> {end_time}\n{text}\n") return subtitles except Exception as e: st.error(f"Error during transcription: {str(e)}") return [] # Timestamp formatting def format_timestamp(seconds): hours = int(seconds // 3600) minutes = int((seconds % 3600) // 60) secs = int(seconds % 60) millis = int((seconds % 1) * 1000) return f"{hours:02}:{minutes:02}:{secs:02},{millis:03}" # Embed subtitles def embed_subtitles(subtitles): raw_texts = [line.split('\n')[2] for line in subtitles if line.strip()] embeddings = embed_model.encode(raw_texts) df = pd.DataFrame({ 'subtitle': raw_texts, 'embedding': list(embeddings) }) with open('subtitle_embeddings.pkl', 'wb') as f: pickle.dump(df, f) return df # Save embeddings to ChromaDB def save_to_chroma(embeddings): client = PersistentClient(path="./chroma_db") collection = client.create_collection(name="subtitles") for idx, row in embeddings.iterrows(): collection.add( documents=[row['subtitle']], ids=[str(idx)], embeddings=[row['embedding'].tolist()] # Convert to list ) return collection # Search subtitles def search_subtitles(query, collection): try: query_embedding = embed_model.encode([query]).tolist() results = collection.query(query_embeddings=query_embedding, n_results=5) return results['documents'] except Exception as e: st.error(f"Error searching subtitles: {str(e)}") return [] # Main app def main(): st.set_page_config(page_title="Video/Audio Subtitle Generator", layout="wide") st.title("🎥🎵 Video/Audio Subtitle Generator") with st.sidebar: uploaded_file = st.file_uploader("Upload Video/Audio", type=["mp4", "mkv", "mp3", "wav"]) query = st.text_input("Search Subtitles") download_btn = st.button("Download Subtitles") if uploaded_file: with st.spinner("Extracting audio..."): audio_path, temp_file = extract_audio(uploaded_file) if audio_path: with st.spinner("Generating subtitles..."): subtitles = transcribe_audio(audio_path) st.success("Subtitles Generated!") if uploaded_file.name.endswith(('.mp4', '.mkv')): st.video(uploaded_file) else: st.audio(uploaded_file) st.write("### Generated Subtitles:") for sub in subtitles: st.text(sub) with st.spinner("Embedding and storing subtitles..."): embeddings = embed_subtitles(subtitles) if embeddings.empty: st.warning("No subtitles generated.") else: collection = save_to_chroma(embeddings) if query: results = search_subtitles(query, collection) st.write("### Matching Subtitles:") if results: for idx, sub in enumerate(results, start=1): st.write(f"{idx}. {sub}") else: st.warning("No matching subtitles found.") if download_btn: with open("generated_subtitles.srt", "w") as f: f.writelines(subtitles) with open("generated_subtitles.srt", "rb") as f: st.download_button("Download SRT", f, file_name="generated_subtitles.srt", mime="text/plain") if __name__ == '__main__': main()