| """ | |
| File-to-Music generation interface | |
| """ | |
| import gradio as gr | |
| import numpy as np | |
| from typing import Optional | |
| import logging | |
| from models.model_manager import ModelManager | |
| from utils.ui_components import UIComponents | |
| from utils.audio_processor import AudioProcessor | |
| from utils.file_handler import FileHandler | |
| logger = logging.getLogger(__name__) | |
| class FileToMusicInterface: | |
| def __init__(self, model_manager: ModelManager): | |
| self.model_manager = model_manager | |
| self.audio_processor = AudioProcessor() | |
| self.file_handler = FileHandler() | |
    def create_interface(self) -> gr.Audio:
        """Build the file-to-music tab UI and wire its event handlers.

        Returns:
            The file-upload ``gr.Audio`` component (not a ``gr.Interface``),
            so the caller can reference it when composing the parent layout.
        """
        with gr.Group():
            gr.Markdown("## ๐น File-to-Music Generation")
            gr.Markdown("Upload an audio file to use as inspiration or conditioning")
            with gr.Row():
                # Left column: upload widget plus all generation controls.
                with gr.Column(scale=2):
                    # NOTE(review): `source=` is the Gradio 3.x keyword; Gradio 4
                    # renamed it to `sources=[...]` — confirm the pinned version.
                    file_input = gr.Audio(
                        label="Upload Audio File",
                        type="filepath",
                        source="upload"
                    )
                    with gr.Row():
                        model_dropdown = UIComponents.create_model_dropdown()
                        # (label, value) pairs — handlers receive the second item.
                        style_dropdown = gr.Dropdown(
                            choices=[
                                ("Similar Style", "similar"),
                                ("Different Style", "different"),
                                ("Enhanced", "enhanced"),
                                ("Remix", "remix")
                            ],
                            value="similar",
                            label="Processing Style"
                        )
                    with gr.Row():
                        duration_slider = UIComponents.create_duration_slider()
                        guidance_slider = UIComponents.create_guidance_slider()
                    # Conditioning options
                    with gr.Accordion("Conditioning Options", open=True):
                        conditioning_strength = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=0.7,
                            label="Conditioning Strength"
                        )
                        pitch_shift = gr.Slider(
                            minimum=-12,
                            maximum=12,
                            value=0,
                            step=1,  # whole semitones only
                            label="Pitch Shift (semitones)"
                        )
                        tempo_change = gr.Slider(
                            minimum=0.5,
                            maximum=2.0,
                            value=1.0,
                            step=0.1,
                            label="Tempo Multiplier"
                        )
                    with gr.Row():
                        generate_btn = gr.Button("๐ต Generate Music", variant="primary", scale=2)
                        analyze_btn = gr.Button("๐ Analyze Audio", variant="secondary")
                # Right column: playback, analysis readouts, comparison plots.
                with gr.Column(scale=3):
                    audio_output = UIComponents.create_audio_player("Generated Music")
                    original_player = UIComponents.create_audio_player("Original Audio")
                    # NOTE(review): download_btn and compare_btn have no event
                    # handlers wired below — confirm whether that is intentional.
                    with gr.Row():
                        download_btn = gr.DownloadButton("๐พ Download", variant="secondary")
                        compare_btn = gr.Button("โ๏ธ Compare", variant="secondary")
                    # Analysis results
                    with gr.Accordion("Audio Analysis", open=False):
                        with gr.Row():
                            tempo_text = gr.Textbox(label="Tempo (BPM)", interactive=False)
                            key_text = gr.Textbox(label="Estimated Key", interactive=False)
                        with gr.Row():
                            energy_text = gr.Textbox(label="Energy Level", interactive=False)
                            mood_text = gr.Textbox(label="Mood", interactive=False)
                    # Comparison tab
                    # NOTE(review): gr.Tab normally lives inside gr.Tabs — confirm
                    # this renders as intended without an explicit Tabs container.
                    with gr.Tab("Compare Original vs Generated"):
                        with gr.Row():
                            original_plot = gr.Plot(label="Original Waveform")
                            generated_plot = gr.Plot(label="Generated Waveform")
                        with gr.Row():
                            original_spec = gr.Plot(label="Original Spectrogram")
                            generated_spec = gr.Plot(label="Generated Spectrogram")
            # Examples
            gr.Examples(
                examples=[
                    ["demo_files/example1.wav"],
                    ["demo_files/example2.wav"],
                    ["demo_files/example3.wav"]
                ],
                inputs=file_input,
                label="Example Audio Files"
            )
            # Event handlers
            # Auto-analyze on upload; the Analyze button below reuses the same
            # handler and output list.
            file_input.change(
                fn=self.analyze_uploaded_audio,
                inputs=file_input,
                outputs=[original_player, tempo_text, key_text, energy_text, mood_text, original_plot, original_spec]
            )
            generate_btn.click(
                fn=self.generate_music_from_file,
                inputs=[
                    file_input, model_dropdown, duration_slider, guidance_slider,
                    conditioning_strength, pitch_shift, tempo_change
                ],
                outputs=[audio_output, generated_plot, generated_spec]
            )
            analyze_btn.click(
                fn=self.analyze_uploaded_audio,
                inputs=file_input,
                outputs=[original_player, tempo_text, key_text, energy_text, mood_text, original_plot, original_spec]
            )
        return file_input
| def analyze_uploaded_audio(self, file_path: str): | |
| """Analyze uploaded audio file""" | |
| try: | |
| if not file_path: | |
| raise gr.Error("Please upload an audio file") | |
| # Load and process audio | |
| audio_array, sr = self.audio_processor.load_audio(file_path) | |
| audio_array = self.audio_processor.normalize_audio(audio_array) | |
| # Create visualizations | |
| waveform_fig = UIComponents.create_audio_visualization(audio_array) | |
| spectrogram_fig = UIComponents.create_spectrogram_visualization(audio_array, sr) | |
| # Analyze audio | |
| tempo = self.audio_processor.get_tempo(audio_array) | |
| key = self._estimate_key(audio_array) | |
| energy = self._calculate_energy(audio_array) | |
| mood = self._estimate_mood(audio_array) | |
| return ( | |
| file_path, | |
| f"{tempo:.1f}", | |
| key, | |
| f"{energy:.2f}", | |
| mood, | |
| waveform_fig, | |
| spectrogram_fig | |
| ) | |
| except Exception as e: | |
| logger.error(f"Audio analysis failed: {str(e)}") | |
| raise gr.Error(f"Analysis failed: {str(e)}") | |
| def generate_music_from_file( | |
| self, | |
| file_path: str, | |
| model_name: str, | |
| duration: int, | |
| guidance_scale: float, | |
| conditioning_strength: float, | |
| pitch_shift: int, | |
| tempo_multiplier: float | |
| ): | |
| """Generate music from uploaded file""" | |
| try: | |
| if not file_path: | |
| raise gr.Error("Please upload an audio file") | |
| # Load and process audio | |
| audio_array, sr = self.audio_processor.load_audio(file_path) | |
| # Apply modifications | |
| if pitch_shift != 0: | |
| audio_array = self.audio_processor.change_pitch(audio_array, pitch_shift) | |
| if tempo_multiplier != 1.0: | |
| audio_array = self.audio_processor.change_speed(audio_array, tempo_multiplier) | |
| audio_array = self.audio_processor.normalize_audio(audio_array) | |
| # Get model | |
| model = self.model_manager.get_model(model_name) | |
| if not model: | |
| raise gr.Error(f"Model {model_name} not available") | |
| # Generate music | |
| logger.info("Generating music from audio file...") | |
| generated_audio = model.generate_from_audio( | |
| audio_array=audio_array, | |
| duration=duration, | |
| guidance_scale=guidance_scale | |
| ) | |
| # Apply conditioning strength | |
| if conditioning_strength |