# interfaces/file_to_music.py
"""
File-to-Music generation interface
"""
import gradio as gr
import numpy as np
from typing import Optional
import logging
from models.model_manager import ModelManager
from utils.ui_components import UIComponents
from utils.audio_processor import AudioProcessor
from utils.file_handler import FileHandler
logger = logging.getLogger(__name__)
class FileToMusicInterface:
def __init__(self, model_manager: ModelManager) -> None:
    """Bind the shared model manager and create per-interface helpers.

    Args:
        model_manager: Shared registry used to look up generation models
            by name at request time.
    """
    self.model_manager = model_manager
    self.file_handler = FileHandler()
    self.audio_processor = AudioProcessor()
def create_interface(self) -> gr.Audio:
    """Build the file-to-music generation UI and wire its event handlers.

    Must be called inside an active Gradio Blocks context (it only opens
    `with gr.Group()` / `gr.Row()` layout managers, it does not create a
    Blocks app itself).

    Returns:
        gr.Audio: the file-upload component (``file_input``).
        NOTE(review): the original annotation said ``gr.Interface``, but the
        method actually returns the upload component; annotation corrected.
    """
    with gr.Group():
        gr.Markdown("## ๐ŸŽน File-to-Music Generation")
        gr.Markdown("Upload an audio file to use as inspiration or conditioning")
        with gr.Row():
            # Left column: upload widget and all generation controls.
            with gr.Column(scale=2):
                # NOTE(review): ``source=`` was removed in Gradio 4.x in
                # favour of ``sources=["upload"]`` — confirm the pinned
                # gradio version before upgrading.
                file_input = gr.Audio(
                    label="Upload Audio File",
                    type="filepath",
                    source="upload"
                )
                with gr.Row():
                    model_dropdown = UIComponents.create_model_dropdown()
                    # (label, value) pairs: the label is shown, the value is
                    # what the handler receives.
                    style_dropdown = gr.Dropdown(
                        choices=[
                            ("Similar Style", "similar"),
                            ("Different Style", "different"),
                            ("Enhanced", "enhanced"),
                            ("Remix", "remix")
                        ],
                        value="similar",
                        label="Processing Style"
                    )
                with gr.Row():
                    duration_slider = UIComponents.create_duration_slider()
                    guidance_slider = UIComponents.create_guidance_slider()
                # Conditioning options
                with gr.Accordion("Conditioning Options", open=True):
                    # How strongly the uploaded audio steers generation
                    # (passed to generate_music_from_file).
                    conditioning_strength = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        label="Conditioning Strength"
                    )
                    # Whole-semitone shift applied to the source audio
                    # before it is fed to the model.
                    pitch_shift = gr.Slider(
                        minimum=-12,
                        maximum=12,
                        value=0,
                        step=1,
                        label="Pitch Shift (semitones)"
                    )
                    # 1.0 = original speed; applied before generation.
                    tempo_change = gr.Slider(
                        minimum=0.5,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="Tempo Multiplier"
                    )
                with gr.Row():
                    generate_btn = gr.Button("๐ŸŽต Generate Music", variant="primary", scale=2)
                    analyze_btn = gr.Button("๐Ÿ” Analyze Audio", variant="secondary")
            # Right column: playback outputs and analysis read-outs.
            with gr.Column(scale=3):
                audio_output = UIComponents.create_audio_player("Generated Music")
                original_player = UIComponents.create_audio_player("Original Audio")
                with gr.Row():
                    # NOTE(review): neither button is wired to a handler in
                    # this method — presumably connected elsewhere or not
                    # yet implemented; confirm before shipping.
                    download_btn = gr.DownloadButton("๐Ÿ’พ Download", variant="secondary")
                    compare_btn = gr.Button("โš–๏ธ Compare", variant="secondary")
                # Analysis results (filled by analyze_uploaded_audio).
                with gr.Accordion("Audio Analysis", open=False):
                    with gr.Row():
                        tempo_text = gr.Textbox(label="Tempo (BPM)", interactive=False)
                        key_text = gr.Textbox(label="Estimated Key", interactive=False)
                    with gr.Row():
                        energy_text = gr.Textbox(label="Energy Level", interactive=False)
                        mood_text = gr.Textbox(label="Mood", interactive=False)
        # Comparison tab: waveform + spectrogram plots, original vs generated.
        with gr.Tab("Compare Original vs Generated"):
            with gr.Row():
                original_plot = gr.Plot(label="Original Waveform")
                generated_plot = gr.Plot(label="Generated Waveform")
            with gr.Row():
                original_spec = gr.Plot(label="Original Spectrogram")
                generated_spec = gr.Plot(label="Generated Spectrogram")
        # Examples — clicking one loads the path into file_input.
        gr.Examples(
            examples=[
                ["demo_files/example1.wav"],
                ["demo_files/example2.wav"],
                ["demo_files/example3.wav"]
            ],
            inputs=file_input,
            label="Example Audio Files"
        )
        # Event handlers
        # Re-analyze automatically whenever a new file is uploaded/cleared.
        file_input.change(
            fn=self.analyze_uploaded_audio,
            inputs=file_input,
            outputs=[original_player, tempo_text, key_text, energy_text, mood_text, original_plot, original_spec]
        )
        generate_btn.click(
            fn=self.generate_music_from_file,
            inputs=[
                file_input, model_dropdown, duration_slider, guidance_slider,
                conditioning_strength, pitch_shift, tempo_change
            ],
            outputs=[audio_output, generated_plot, generated_spec]
        )
        # Same analysis path as the change handler, triggered on demand.
        analyze_btn.click(
            fn=self.analyze_uploaded_audio,
            inputs=file_input,
            outputs=[original_player, tempo_text, key_text, energy_text, mood_text, original_plot, original_spec]
        )
    return file_input
def analyze_uploaded_audio(self, file_path: str):
    """Analyze an uploaded audio file for playback, stats, and plots.

    Args:
        file_path: Filesystem path of the uploaded audio (Gradio passes
            None/"" when the component is cleared).

    Returns:
        Tuple matching the handler's output components:
        (file_path, tempo string, key, energy string, mood,
         waveform figure, spectrogram figure).

    Raises:
        gr.Error: if no file was uploaded, or if any analysis step fails.
    """
    # Validate BEFORE the try block. In the original, this gr.Error was
    # caught by the blanket handler below and re-raised double-wrapped as
    # "Analysis failed: Please upload an audio file".
    if not file_path:
        raise gr.Error("Please upload an audio file")
    try:
        # Load and normalize the audio signal.
        audio_array, sr = self.audio_processor.load_audio(file_path)
        audio_array = self.audio_processor.normalize_audio(audio_array)
        # Create visualizations for the "original" plot slots.
        waveform_fig = UIComponents.create_audio_visualization(audio_array)
        spectrogram_fig = UIComponents.create_spectrogram_visualization(audio_array, sr)
        # Derive the summary statistics shown in the analysis accordion.
        tempo = self.audio_processor.get_tempo(audio_array)
        key = self._estimate_key(audio_array)
        energy = self._calculate_energy(audio_array)
        mood = self._estimate_mood(audio_array)
    except Exception as e:
        # logger.exception keeps the traceback; chain the cause so the
        # original error is preserved for debugging.
        logger.exception("Audio analysis failed")
        raise gr.Error(f"Analysis failed: {str(e)}") from e
    return (
        file_path,
        f"{tempo:.1f}",
        key,
        f"{energy:.2f}",
        mood,
        waveform_fig,
        spectrogram_fig
    )
def generate_music_from_file(
self,
file_path: str,
model_name: str,
duration: int,
guidance_scale: float,
conditioning_strength: float,
pitch_shift: int,
tempo_multiplier: float
):
"""Generate music from uploaded file"""
try:
if not file_path:
raise gr.Error("Please upload an audio file")
# Load and process audio
audio_array, sr = self.audio_processor.load_audio(file_path)
# Apply modifications
if pitch_shift != 0:
audio_array = self.audio_processor.change_pitch(audio_array, pitch_shift)
if tempo_multiplier != 1.0:
audio_array = self.audio_processor.change_speed(audio_array, tempo_multiplier)
audio_array = self.audio_processor.normalize_audio(audio_array)
# Get model
model = self.model_manager.get_model(model_name)
if not model:
raise gr.Error(f"Model {model_name} not available")
# Generate music
logger.info("Generating music from audio file...")
generated_audio = model.generate_from_audio(
audio_array=audio_array,
duration=duration,
guidance_scale=guidance_scale
)
# Apply conditioning strength
if conditioning_strength