Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from engine import DescribeVideo, GenerateAudio | |
| import os | |
| from moviepy.editor import VideoFileClip, AudioFileClip, CompositeAudioClip | |
| from moviepy.audio.fx.volumex import volumex | |
| from streamlit.runtime.scriptrunner import get_script_run_ctx | |
def get_session_id():
    """Build the identifier string for the current Streamlit session.

    Takes ``session_id`` from the object returned by ``get_script_run_ctx()``,
    swaps every ``-`` for ``_`` and prepends ``_id_``.
    """
    raw_id = get_script_run_ctx().session_id
    return "_id_" + raw_id.replace("-", "_")
# Working directory for this session: the uploaded video and the rendered
# output are written underneath it.
user_session_id = get_session_id()
os.makedirs(name=user_session_id, exist_ok=True)
# Define model maps
# Sidebar label -> model name passed to DescribeVideo.
video_model_map = dict(Fast="flash", Quality="pro")

# Sidebar label -> model name passed to GenerateAudio.
music_model_map = dict(
    Fast="musicgen-stereo-small",
    Balanced="musicgen-stereo-medium",
    Quality="musicgen-stereo-large",
)
# Disabled alternative mapping, kept for reference:
#   "Fast" -> "facebook/musicgen-melody"
#   "Quality" -> "facebook/musicgen-melody-large"

# Sidebar label -> genre value. Insertion order is the order shown in the
# genre selectbox, so the pairs are listed explicitly.
genre_map = dict(
    [
        ("None", None),
        ("Pop", "Pop"),
        ("Rock", "Rock"),
        ("Hip Hop", "Hip-Hop/Rap"),
        ("Jazz", "Jazz"),
        ("Classical", "Classical"),
        ("Blues", "Blues"),
        ("Country", "Country"),
        ("EDM", "Electronic/Dance"),
        ("Metal", "Metal"),
        ("Disco", "Disco"),
        ("Lo-Fi", "Lo-Fi"),
    ]
)
# Streamlit page configuration
st.set_page_config(
    layout="centered",
    page_icon="assets/favicon.png",
    page_title="VidTune: Where Videos Find Their Melody",
)

# The logo is rendered in the middle one of three columns.
left_co, cent_co, last_co = st.columns(3)
cent_co.image("assets/VidTune-Logo-Without-BG.png", use_column_width=False, width=200)

# Title and Description
# Raw HTML/CSS: centres text and renders the headline plus the tagline.
page_header_html = """
<style>
h2, p, div, img {
text-align: center;
}
</style>
<div style="font-size: 35px; font-weight: bold;">VidTune: Where Videos Find Their Melody</div>
<p>VidTune is a web application to effortlessly tailor perfect soundtracks for your videos with AI.</p>
"""
st.markdown(page_header_html, unsafe_allow_html=True)
# Initialize session state for advanced settings and other inputs
# A default is written only when its key is still missing, so values set
# during earlier reruns of the same session are left untouched.
session_defaults = {
    "show_advanced": False,
    "video_model": "Fast",
    "music_model": "Fast",
    "num_samples": 3,
    "music_genre": None,
    "music_bpm": 100,
    "user_keywords": None,
    # Integer position in the audio selectbox options (0 is the "None"
    # entry). The selectbox keyed "selected_audio" uses integer options and
    # its on_change callback compares this value with `> 0`, so the default
    # must be an int rather than the string "None".
    "selected_audio": 0,
    "audio_paths": [],
    "selected_audio_path": None,
    "orig_audio_vol": 100,
    "generated_audio_vol": 100,
    "generate_button_flag": False,
    "video_description_content": "",
    "music_prompt": "",
    "audio_mix_flag": False,
    "google_api_key": "",
    # Read by the export section; initialised so that reading it before any
    # mix has been rendered cannot raise.
    "final_video_path": None,
}
for state_key, default_value in session_defaults.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = default_value
# Sidebar
st.sidebar.title("Configuration")

# Google API Key
# Nothing below this point runs until a key has been entered.
api_key_value = st.sidebar.text_input(
    "Enter your [Google API Key](https://ai.google.dev/gemini-api/docs/api-key) to get started :",
    st.session_state.google_api_key,
    type="password",
)
st.session_state.google_api_key = api_key_value
if not api_key_value:
    st.warning("Please enter your Google API Key to proceed.")
    st.stop()

# Basic Settings
# Each selectbox is pre-selected from the value stored in session state.
video_choices = ["Fast", "Quality"]
st.session_state.video_model = st.sidebar.selectbox(
    "Select Video Descriptor",
    video_choices,
    index=video_choices.index(st.session_state.video_model),
)

music_choices = ["Fast", "Balanced", "Quality"]
st.session_state.music_model = st.sidebar.selectbox(
    "Select Music Generator",
    music_choices,
    index=music_choices.index(st.session_state.music_model),
)

st.session_state.num_samples = st.sidebar.slider(
    "Number of samples", 1, 5, st.session_state.num_samples
)

# Sidebar for advanced settings
# The toggle button lives in a placeholder created inside the sidebar.
placeholder = st.sidebar.empty()
if placeholder.button("Advanced"):
    st.session_state.show_advanced = not st.session_state.show_advanced
    st.rerun()  # Refresh the layout after button click

# Display advanced settings if enabled; otherwise clear them
if not st.session_state.show_advanced:
    st.session_state.music_genre = None
    st.session_state.music_bpm = None
    st.session_state.user_keywords = None
else:
    st.session_state.music_bpm = st.sidebar.slider("Beats Per Minute", 35, 180, 100)

    genre_labels = list(genre_map)
    stored_genre = st.session_state.music_genre
    # Fall back to the first entry when no known genre is stored.
    if stored_genre in genre_map:
        genre_index = genre_labels.index(stored_genre)
    else:
        genre_index = 0
    st.session_state.music_genre = st.sidebar.selectbox(
        "Select Music Genre",
        genre_labels,
        index=genre_index,
    )

    st.session_state.user_keywords = st.sidebar.text_input(
        "User Keywords",
        value=st.session_state.user_keywords,
        help="Enter keywords separated by commas.",
    )

# Generate Button
generate_button = st.sidebar.button("Generate Music")
# Cache the model loading
@st.cache_resource
def _load_audio_generator(music_model_key):
    """Return the GenerateAudio instance for ``music_model_key``.

    Wrapped in ``st.cache_resource`` so the generator is constructed once per
    distinct key instead of on every script rerun. Only the music model key is
    part of the cache key; the Google API key is deliberately not, so that
    entering a different API key does not build another generator.

    NOTE(review): the cached instance is shared by all reruns and sessions of
    this process — confirm GenerateAudio tolerates being reused that way.
    """
    return GenerateAudio(model=music_model_map[music_model_key])


def load_models(video_model_key, music_model_key, google_api_key):
    """Create the video descriptor and fetch the (cached) audio generator.

    Args:
        video_model_key: Key into ``video_model_map`` ("Fast" or "Quality").
        music_model_key: Key into ``music_model_map``.
        google_api_key: API key forwarded to ``DescribeVideo``.

    Returns:
        A ``(video_descriptor, audio_generator)`` tuple.

    Raises:
        KeyError: If either key is missing from its map.
    """
    # Built on every call, exactly as before: it depends on the user's key.
    video_descriptor = DescribeVideo(
        model=video_model_map[video_model_key], google_api_key=google_api_key
    )
    audio_generator = _load_audio_generator(music_model_key)
    # Kept outside the cached helper so the warning is emitted on every run.
    if audio_generator.device == "cpu":
        st.warning(
            "The music generator model is running on CPU. For faster results, consider using a GPU."
        )
    return video_descriptor, audio_generator
# Load models
video_descriptor, audio_generator = load_models(
    video_model_key=st.session_state.video_model,
    music_model_key=st.session_state.music_model,
    google_api_key=st.session_state.google_api_key,
)

# Video Uploader
# The upload is copied to a fixed file name inside the session directory.
uploaded_video_path = f"{user_session_id}/temp.mp4"
uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
has_upload = uploaded_video is not None
if has_upload:
    st.session_state.uploaded_video = uploaded_video
    with open(uploaded_video_path, mode="wb") as video_out:
        video_out.write(uploaded_video.getvalue())

# Video Player
if has_upload and os.path.exists(uploaded_video_path):
    st.video(uploaded_video)
# Submit button if video is not uploaded
if generate_button:
    if uploaded_video is None:
        st.error("Please upload a video before generating music.")
        st.stop()
    with st.spinner("Analyzing video..."):
        video_description = video_descriptor.describe_video(
            f"{user_session_id}/temp.mp4",
            # The sidebar stores the display label; translate it through
            # genre_map so the "None" label becomes a real None and labels
            # such as "Hip Hop" become their mapped value. A stored None
            # (advanced settings hidden) also yields None.
            genre=genre_map.get(st.session_state.music_genre),
            bpm=st.session_state.music_bpm,
            user_keywords=st.session_state.user_keywords,
        )
        # Open the clip only to read its duration and close it right away,
        # so the underlying reader is not left open.
        with VideoFileClip(f"{user_session_id}/temp.mp4") as probe_clip:
            video_duration = probe_clip.duration
        st.session_state.video_description_content = video_description[
            "Content Description"
        ]
        st.session_state.music_prompt = video_description["Music Prompt"]
    st.success("Video description generated successfully.")
    st.session_state.generate_button_flag = True

# Display Video Description and Music Prompt
# Shown on every rerun once a description has been generated.
if st.session_state.generate_button_flag:
    st.text_area(
        "Video Description",
        st.session_state.video_description_content,
        disabled=True,
        height=120,
    )
    music_prompt = st.text_area(
        "Music Prompt",
        st.session_state.music_prompt,
        disabled=True,
        height=120,
    )

if generate_button:
    # Generate Music
    with st.spinner("Generating music..."):
        if video_duration > 30:
            st.warning(
                "Due to hardware limitations, the maximum music length is capped at 30 seconds."
            )
        # One copy of the prompt per requested sample.
        # NOTE(review): the uncapped video duration is passed on; presumably
        # the generator enforces the 30 second limit itself — confirm.
        music_prompt = [st.session_state.music_prompt] * st.session_state.num_samples
        audio_generator.generate_audio(music_prompt, duration=video_duration)
        st.session_state.audio_paths = audio_generator.save_audio()
    st.success("Music generated successfully.")
    st.balloons()
# Callback function for radio button selection change
def on_audio_selection_change():
    """Update the selected audio path after the selection widget changes.

    Clears ``audio_mix_flag``, then reads the index stored under
    ``selected_audio``: a positive index selects ``audio_paths[index - 1]``,
    anything else resets ``selected_audio_path`` to ``None``.
    """
    state = st.session_state
    state.audio_mix_flag = False
    choice = state.selected_audio
    state.selected_audio_path = state.audio_paths[choice - 1] if choice > 0 else None
if st.session_state.audio_paths:
    # Dropdown to select one of the generated audio files
    # Option 0 is "None"; option k (k >= 1) refers to audio_paths[k - 1].
    audio_options = ["None"] + [
        f"Generated Music {number}"
        for number in range(1, len(st.session_state.audio_paths) + 1)
    ]

    # The keyed selectbox below works on integer positions. Reset a stored
    # value that is not a valid position (wrong type, or out of range for the
    # current number of tracks) before the widget is created, and drop the
    # selection that went with it.
    stored_choice = st.session_state.get("selected_audio")
    if not isinstance(stored_choice, int) or not (
        0 <= stored_choice < len(audio_options)
    ):
        st.session_state.selected_audio = 0
        st.session_state.selected_audio_path = None

    # Display the audio files
    for audio_path in st.session_state.audio_paths:
        st.audio(audio_path, format="audio/wav")

    # The return value is not needed: the chosen index is read from session
    # state (key "selected_audio") inside the on_change callback.
    st.selectbox(
        "Select one of the generated audio files for further processing:",
        range(len(audio_options)),
        format_func=lambda x: audio_options[x],
        index=0,
        key="selected_audio",
        on_change=on_audio_selection_change,
    )

    # Button to confirm the selection
    if st.button("Add Generated Music to Video"):
        st.session_state.audio_mix_flag = True

# Handle Audio Mixing and Export
if st.session_state.selected_audio_path is not None and st.session_state.audio_mix_flag:
    final_video_path = f"{user_session_id}/out_tmp.mp4"
    with st.spinner("Mixing Audio..."):
        # Volume controls, expressed in percent of the original level.
        st.session_state.orig_audio_vol = st.slider(
            "Original Audio Volume",
            0,
            200,
            st.session_state.orig_audio_vol,
            format="%d%%",
        )
        st.session_state.generated_audio_vol = st.slider(
            "Generated Music Volume",
            0,
            200,
            st.session_state.generated_audio_vol,
            format="%d%%",
        )

        # Both clips are context managers, so they are closed even when the
        # mix or the export raises.
        with VideoFileClip(f"{user_session_id}/temp.mp4") as orig_clip, AudioFileClip(
            st.session_state.selected_audio_path
        ) as generated_audio:
            # None when the uploaded video has no audio track.
            source_audio = orig_clip.audio
            try:
                tracks = []
                if source_audio is not None:
                    tracks.append(
                        volumex(source_audio, st.session_state.orig_audio_vol / 100)
                    )
                tracks.append(
                    volumex(generated_audio, st.session_state.generated_audio_vol / 100)
                )
                orig_clip.audio = CompositeAudioClip(tracks)
                orig_clip.write_videofile(final_video_path)
            finally:
                # orig_clip.audio now points at the composite, so the
                # video's own audio clip has to be closed explicitly.
                if source_audio is not None:
                    source_audio.close()

    st.session_state.final_video_path = final_video_path
    st.video(final_video_path)
    with open(final_video_path, "rb") as video_file:
        st.download_button(
            label="Download final video",
            data=video_file,
            file_name="final_video.mp4",
            mime="video/mp4",
        )