Spaces:
Paused
Paused
| import os | |
| import sys | |
| import torch | |
| import uuid | |
# 1. Limit PyTorch to a single intra-op CPU thread.
# NOTE(review): the stated motivation was avoiding "Python 3.13 async loop
# collisions"; that link is not evident from this file -- confirm it still holds.
import functools

torch.set_num_threads(1)

# PyTorch 2.6 switched torch.load's default to weights_only=True, which refuses
# to unpickle the arbitrary objects stored in legacy (e.g. Lightning)
# checkpoints. Restore the permissive default for callers that don't choose.
# SECURITY: this patch is process-wide -- every later `torch.load(...)` made
# through the `torch` module performs full pickle deserialization, which can
# execute arbitrary code. Only load checkpoints from trusted sources.
_original_torch_load = torch.load


@functools.wraps(_original_torch_load)
def trusted_torch_load(*args, **kwargs):
    """Call the original ``torch.load`` defaulting ``weights_only`` to ``False``.

    An explicit ``weights_only=True``/``False`` from the caller is respected.
    An absent value or an explicit ``None`` (torch's own "use the library
    default" sentinel) is replaced with ``False``. Passing ``None`` through
    would re-enable the strict PyTorch 2.6+ default this patch exists to avoid.
    """
    if kwargs.get('weights_only') is None:
        kwargs['weights_only'] = False
    return _original_torch_load(*args, **kwargs)


torch.load = trusted_torch_load
# 2. Clone the NU-Wave repository on first run so `lightning_model` can be imported.
import subprocess

_NUWAVE_DIR = os.path.abspath('nuwave')
if not os.path.exists(_NUWAVE_DIR):
    # Argument list (no shell) plus check=True: a failed clone, or a missing
    # git binary, raises right here. It no longer surfaces later as a confusing
    # ImportError for `lightning_model`. --depth 1 skips the git history,
    # which nothing in this app uses.
    subprocess.run(
        [
            'git', 'clone', '--depth', '1',
            'https://github.com/mindslab-ai/nuwave.git', _NUWAVE_DIR,
        ],
        check=True,
    )

# Use an absolute path. A relative './nuwave' entry on sys.path is resolved
# against the working directory at import time, so it breaks if the cwd changes.
sys.path.append(_NUWAVE_DIR)
| import gradio as gr | |
| import librosa | |
| import soundfile as sf | |
| from huggingface_hub import hf_hub_download | |
| from lightning_model import NuWave | |
# 3. Pin inference to the CPU.
device = 'cpu'

# Fetch the pretrained x2 checkpoint from the Hugging Face Hub
# (huggingface_hub caches it locally, so restarts reuse the file).
ckpt_path = hf_hub_download(
    repo_id='nateraw/nu-wave-x2',
    filename='lit_model.ckpt',
)

# map_location='cpu' remaps every stored tensor onto the CPU while
# deserializing, so a checkpoint saved from a GPU loads on a CPU-only host.
# `train=False` is an extra keyword forwarded by load_from_checkpoint
# (presumably to NuWave's constructor -- confirm against lightning_model).
model = NuWave.load_from_checkpoint(
    ckpt_path,
    train=False,
    map_location='cpu',
)
model = model.to(device)

# Switch layers such as dropout/batch-norm to inference behaviour.
model.eval()
# 4. Inference Function
def upsample_audio(audio_path, steps, noise_schedule, target_sr=48000):
    """Upsample an audio file and write the result as a WAV at ``target_sr``.

    The NuWave generation call is not wired in yet (see the placeholder
    below). For now the audio is only resampled to ``target_sr``; no new
    high-frequency content is synthesized.

    Args:
        audio_path: Path to the uploaded audio file, or ``None`` when the
            Gradio input is empty.
        steps: Number of inference steps chosen in the UI. Currently unused
            until the model call is added.
        noise_schedule: Noise-schedule name chosen in the UI. Currently
            unused until the model call is added.
        target_sr: Sample rate, in Hz, of the samples written to the output
            file. Defaults to 48000, matching the "48kHz" output label in
            the UI.

    Returns:
        Path of the newly written WAV file, or ``None`` if ``audio_path``
        is ``None``.
    """
    if audio_path is None:
        return None

    # Load at the file's native sample rate (sr=None disables resampling).
    wav, sr = librosa.load(audio_path, sr=None)

    # sf.write below labels the file as ``target_sr``, so the samples must
    # really be at that rate. Otherwise a 16 kHz upload written with a 48 kHz
    # header plays back 3x too fast and pitched up.
    # NOTE(review): once the model call is wired in, confirm which input
    # rate NuWave expects and adjust this step accordingly.
    if sr != target_sr and wav.size > 0:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)

    # float32 tensor on the configured device (no copy if already float32).
    wav_tensor = torch.as_tensor(wav, dtype=torch.float32).to(device)
    with torch.no_grad():
        # TODO: replace the pass-through with the model's generation call, e.g.
        #   upsampled_wav = model.generate(wav_tensor, steps=steps, schedule=noise_schedule)
        # (method name and signature not yet confirmed against NuWave).
        upsampled_wav = wav_tensor  # Placeholder: pass-through
    output_wav = upsampled_wav.cpu().numpy()

    # A fresh random name per request keeps repeated or concurrent runs from
    # overwriting each other's output or being served a cached file.
    unique_filename = f"output_upsampled_{uuid.uuid4().hex[:8]}.wav"
    sf.write(unique_filename, output_wav, target_sr)
    return unique_filename
# 5. Build the Gradio web UI.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🔊 NU-Wave X2: Audio Bandwidth Extension (CPU Mode)
        Upload a low-resolution audio file to upsample and reconstruct missing high-frequency details.
        *Note: Running on CPU will take longer to process higher step counts than a GPU Space.*
        """
    )

    with gr.Row():
        # Left column: input audio and inference controls.
        with gr.Column():
            input_audio = gr.Audio(
                type="filepath",  # upsample_audio loads the upload from a path on disk
                label="Input Low-Res Audio",
            )
            gr.Markdown("### ⚙️ Inference Fine-Tuning")
            steps_slider = gr.Slider(
                minimum=4,
                maximum=50,  # Lowered maximum recommended steps for CPU performance
                value=6,
                step=1,
                label="Inference Steps",
                info="⚡ **Impact:** Higher steps improve clarity but increase CPU processing time significantly.",
            )
            schedule_dropdown = gr.Dropdown(
                choices=["linear", "fibonacci", "cosine"],
                value="linear",
                label="Noise Schedule",
            )
            submit_btn = gr.Button("Upsample Audio", variant="primary")

        # Right column: the WAV file path returned by upsample_audio.
        with gr.Column():
            output_audio = gr.Audio(
                label="Upsampled High-Res Audio (48kHz)"
            )

    # Inputs map positionally onto upsample_audio(audio_path, steps, noise_schedule).
    submit_btn.click(
        fn=upsample_audio,
        inputs=[input_audio, steps_slider, schedule_dropdown],
        outputs=output_audio,
    )

if __name__ == "__main__":
    demo.launch()