# Source: LavaSR / app.py (Hugging Face Space by YatharthS), commit 7298c53 "Update app.py" (verified).
import gradio as gr
import torch
import soundfile as sf
import numpy as np
from LavaSR.model import LavaEnhance2
# --- Setup Model ---
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Created once at import time so every request reuses the same model object.
lava_model = LavaEnhance2("YatharthS/LavaSR", device)
# Sample rates reported to Gradio alongside the two returned waveforms.
_PREVIEW_SAMPLE_RATE = 16000   # rate the UI labels the resampled input with
_ENHANCED_SAMPLE_RATE = 48000  # LavaSR outputs at 48kHz by default


def _to_int16_pcm(audio_tensor):
    """Convert a float audio tensor into a squeezed int16 NumPy array.

    Samples are clipped to [-1.0, 1.0] before being scaled by 32767 so that
    out-of-range values saturate instead of wrapping around in int16.
    """
    # detach() keeps the conversion working even if the tensor still carries
    # autograd history; Tensor.numpy() raises on tensors that require grad.
    samples = audio_tensor.detach().cpu().numpy().squeeze()
    return (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)


def process_audio(input_file, input_sr, denoise, batch):
    """Enhance an uploaded audio file with the module-level LavaSR model.

    Args:
        input_file: Path of the uploaded audio file, or None when nothing
            has been uploaded.
        input_sr: Sampling rate (Hz) forwarded to ``lava_model.load_audio``.
        denoise: Forwarded to ``lava_model.enhance`` as ``denoise``.
        batch: Forwarded to ``lava_model.enhance`` as ``batch``.

    Returns:
        A pair ``((16000, input_pcm), (48000, output_pcm))`` where both
        waveforms are int16 NumPy arrays, or ``(None, None)`` when
        ``input_file`` is None.
    """
    if input_file is None:
        return None, None

    # 1. Load Audio
    # The slider value may arrive as a float; sample rates are whole numbers.
    # NOTE(review): load_audio also returns a sample rate, which is ignored
    # here; the preview is reported at 16 kHz to match the UI label. Confirm
    # that load_audio always resamples to 16 kHz.
    input_audio_tensor, _ = lava_model.load_audio(input_file, input_sr=int(input_sr))

    # 2. Enhance Audio
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        output_audio_tensor = lava_model.enhance(
            input_audio_tensor,
            denoise=denoise,
            batch=batch,
        )

    # 3. Prepare for Gradio Output as (sample_rate, int16 array) pairs.
    return (
        (_PREVIEW_SAMPLE_RATE, _to_int16_pcm(input_audio_tensor)),
        (_ENHANCED_SAMPLE_RATE, _to_int16_pcm(output_audio_tensor)),
    )
# --- Gradio UI ---
with gr.Blocks(title="LavaSR Audio Super-Resolution") as demo:
    # Page header.
    gr.Markdown("# 🌋 LavaSR Audio Enhancement")
    gr.Markdown("Upload low-quality audio to enhance it using the LavaSR model. Running on 2 core CPU.")

    with gr.Row():
        # Left column: upload widget, processing options and the trigger button.
        with gr.Column():
            input_audio_ui = gr.Audio(label="Upload Audio", type="filepath")

            with gr.Accordion("Advanced Settings", open=True):
                sr_slider = gr.Slider(
                    label="Input Sampling Rate (Hz)",
                    info="Match this to your source audio's quality.",
                    minimum=8000,
                    maximum=48000,
                    step=1000,
                    value=16000,
                )
                denoise_toggle = gr.Checkbox(value=False, label="Enable Denoising")
                batch_toggle = gr.Checkbox(
                    value=False,
                    label="Enable Batching",
                    info="Use for very long audio files.",
                )

            submit_btn = gr.Button("Enhance Audio", variant="primary")

        # Right column: the two playable results.
        with gr.Column():
            # Shows the audio after initial loading/resampling
            resampled_output = gr.Audio(label="Input (Resampled to 16kHz)")
            # Shows the final 48kHz enhanced output
            enhanced_output = gr.Audio(label="Enhanced Audio (48kHz)")

    # The input order must match process_audio's positional parameters
    # (input_file, input_sr, denoise, batch); the outputs receive its two
    # return values in order.
    submit_btn.click(
        process_audio,
        inputs=[input_audio_ui, sr_slider, denoise_toggle, batch_toggle],
        outputs=[resampled_output, enhanced_output],
    )
# Start the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()