nu-wave / app.py
arjunbroepic's picture
Update app.py
b7952aa verified
Raw
History Blame Contribute Delete
3.63 kB
import functools
import os
import subprocess
import sys
import uuid

import torch
# 1. Force PyTorch to use a single thread to prevent Python 3.13 async loop collisions
torch.set_num_threads(1)

# Fix for PyTorch 2.6+ strict weights_only default for legacy checkpoints.
# Keep a handle on the genuine loader so the wrapper below can delegate to it.
_original_torch_load = torch.load


@functools.wraps(_original_torch_load)
def trusted_torch_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` defaulting to False.

    An explicit ``weights_only=...`` supplied by the caller is passed through
    unchanged; only the default is altered.

    SECURITY NOTE: with ``weights_only=False`` the checkpoint is unpickled
    without restrictions, which can execute arbitrary code embedded in the
    file. Because this wrapper replaces ``torch.load`` process-wide, only load
    checkpoints from sources you trust.
    """
    kwargs.setdefault('weights_only', False)
    return _original_torch_load(*args, **kwargs)


# Install the wrapper globally so third-party code that calls torch.load
# internally also gets the relaxed default.
torch.load = trusted_torch_load
# 2. Clone the repository dynamically if it doesn't exist
if not os.path.exists('nuwave'):
    # Argument list (no shell) plus check=True: a failed clone aborts startup
    # right here with CalledProcessError instead of surfacing later as an
    # opaque import error for `lightning_model`.
    subprocess.run(
        ['git', 'clone', 'https://github.com/mindslab-ai/nuwave.git'],
        check=True,
    )

# Add the cloned directory to the system path so its modules can be imported
sys.path.append('./nuwave')
import gradio as gr
import librosa
import soundfile as sf
from huggingface_hub import hf_hub_download
from lightning_model import NuWave
# 3. Pin all inference work to the CPU
device = 'cpu'

# Fetch the pretrained checkpoint file from the Hugging Face Hub
ckpt_path = hf_hub_download(repo_id='nateraw/nu-wave-x2', filename='lit_model.ckpt')

# Deserialize the checkpoint directly onto the CPU, then place the module on
# `device` and switch it to evaluation mode.
model = NuWave.load_from_checkpoint(
    ckpt_path,
    train=False,
    map_location='cpu',
)
model = model.to(device)
model.eval()
# 4. Inference Function
def upsample_audio(audio_path, steps, noise_schedule):
    """Process an uploaded audio file and write the result to a new WAV file.

    The generation step is still a placeholder that passes the input samples
    through unchanged; ``steps`` and ``noise_schedule`` are accepted so the UI
    wiring is already in place, but they are not used yet.

    Args:
        audio_path: Path to the input audio file, or None/empty when nothing
            was provided.
        steps: Number of inference steps chosen in the UI (currently unused).
        noise_schedule: Name of the noise schedule chosen in the UI
            (currently unused).

    Returns:
        The filename of the written WAV file, or None when no input was given.
    """
    if not audio_path:
        return None

    # Load audio at its native sample rate (sr=None)
    wav, sr = librosa.load(audio_path, sr=None)

    # Process on CPU
    wav_tensor = torch.FloatTensor(wav).to(device)

    with torch.no_grad():
        # Replace this with your model's exact generation call when ready:
        # e.g., upsampled_wav = model.generate(wav_tensor, steps=steps, schedule=noise_schedule)
        upsampled_wav = wav_tensor  # Placeholder
        # The placeholder leaves the samples at the input rate, so the file must
        # be tagged with `sr`; writing a fixed 48000 here would make any
        # non-48 kHz input play back at the wrong speed and pitch.
        # NOTE: set this to the model's real output rate once generation is wired in.
        output_sr = sr

    output_wav = upsampled_wav.cpu().numpy()

    # Generate unique output file to prevent caching/file-locking errors between runs
    unique_filename = f"output_upsampled_{uuid.uuid4().hex[:8]}.wav"
    sf.write(unique_filename, output_wav, output_sr)
    return unique_filename
# 5. Build Gradio Web UI
# Layout: one row with two columns - inputs and controls on the left, the
# processed audio on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🌊 NU-Wave X2: Audio Bandwidth Extension (CPU Mode)
        Upload a low-resolution audio file to upsample and reconstruct missing high-frequency details.
        *Note: Running on CPU will take longer to process higher step counts than a GPU Space.*
        """
    )
    with gr.Row():
        with gr.Column():
            # type="filepath" hands the callback a path on disk
            input_audio = gr.Audio(
                type="filepath",
                label="Input Low-Res Audio"
            )
            gr.Markdown("### ⚙️ Inference Fine-Tuning")
            steps_slider = gr.Slider(
                minimum=4,
                maximum=50,  # Lowered maximum recommended steps for CPU performance
                value=6,
                step=1,
                label="Inference Steps",
                info="⚡ **Impact:** Higher steps improve clarity but increase CPU processing time significantly."
            )
            schedule_dropdown = gr.Dropdown(
                choices=["linear", "fibonacci", "cosine"],
                value="linear",
                label="Noise Schedule"
            )
            submit_btn = gr.Button("Upsample Audio", variant="primary")
        with gr.Column():
            output_audio = gr.Audio(
                label="Upsampled High-Res Audio (48kHz)"
            )
    # Wire the button to the inference function; the inputs are passed
    # positionally in the order listed here.
    submit_btn.click(
        fn=upsample_audio,
        inputs=[input_audio, steps_slider, schedule_dropdown],
        outputs=output_audio
    )

# Start the web server only when executed as a script, not when imported
if __name__ == "__main__":
    demo.launch()