| | import os |
| | import torch |
| | import torchaudio |
| | import argparse |
| | from huggingface_hub import hf_hub_download |
| |
|
| | |
| | from pyharp import ModelCard, build_endpoint, load_audio, save_audio |
| | import gradio as gr |
| |
|
| | |
# HARP model card shown in the plugin/UI; metadata describing the Apollo
# audio-restoration model (pyharp's ModelCard type).
model_card = ModelCard(
    name="Apollo",
    description="High-quality audio restoration for lossy MP3 compressed audio. Converts low-bitrate MP3s to near-lossless quality using band-sequence modeling.",
    author="JusperLee",
    tags=["audio restoration", "music", "apollo", "mp3", "lossless"],
)
| |
|
def load_audio(file_path: str) -> torch.Tensor:
    """Load an audio file and prepend a batch dimension.

    NOTE(review): this definition shadows the ``load_audio`` imported from
    pyharp at the top of the file — presumably intentional, but confirm.

    Args:
        file_path: Path to any audio file readable by torchaudio.

    Returns:
        A float tensor of shape (1, channels, samples).

    The file's native sample rate is discarded here, and downstream code
    writes output assuming 44.1 kHz — TODO confirm inputs are 44.1 kHz,
    or propagate the rate to the caller.
    """
    # Underscore-name the sample rate: it is deliberately unused here.
    audio, _samplerate = torchaudio.load(file_path)
    return audio.unsqueeze(0)
| |
|
def save_audio(file_path, audio, samplerate=44100):
    """Write a batched audio tensor to disk.

    Drops the leading batch dimension, moves the tensor to the CPU, and
    saves it via torchaudio at ``samplerate`` (defaults to 44.1 kHz).
    """
    unbatched = audio.squeeze(0).cpu()
    torchaudio.save(file_path, unbatched, samplerate)
| |
|
| | |
@torch.inference_mode()
def process_fn(
    input_audio_path: str
) -> str:
    """Restore a lossy audio file with the Apollo model.

    The checkpoint is fetched from the Hugging Face Hub and memoized
    in-process (see ``_load_apollo_model``), so repeated calls skip the
    expensive download/deserialization the original code performed on
    every invocation.

    Args:
        input_audio_path: Path to the input audio file (e.g. a
            low-bitrate MP3).

    Returns:
        Path to the restored WAV written under ``src/_outputs``.
    """
    device = torch.device("cpu")
    print(f"Using device: {device}")

    model = _load_apollo_model(device)

    # Load the input directly so we can remember its sample rate: the
    # original always saved at 44100 Hz, which pitch-shifts any input
    # recorded at a different rate.
    audio, samplerate = torchaudio.load(input_audio_path)
    sig = audio.unsqueeze(0).to(device)

    # Ensure a (batch, channels, samples) layout before inference.
    if sig.dim() == 2:
        sig = sig.unsqueeze(0)

    # @torch.inference_mode() already disables autograd; the original's
    # nested torch.no_grad() context was redundant.
    output = model(sig)

    # Drop the batch dimension for saving.
    output = output.squeeze(0)

    output_audio_path = os.path.join("src", "_outputs", "output_restored.wav")
    os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
    # Save at the input's own sample rate instead of a hard-coded 44100.
    torchaudio.save(output_audio_path, output, samplerate)
    print(f"✓ Saved output to {output_audio_path}")

    return output_audio_path


def _load_apollo_model(device):
    """Download (if needed), instantiate, and memoize the Apollo model.

    The loaded model is cached on the function object so each
    ``process_fn`` call after the first reuses it.
    """
    cached = getattr(_load_apollo_model, "_model", None)
    if cached is not None:
        return cached

    print("Loading Apollo model...")
    model_path = hf_hub_download(
        repo_id="JusperLee/Apollo",
        filename="pytorch_model.bin",
        cache_dir="./checkpoints"
    )

    print(f"Loading checkpoint from {model_path}")
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects from
    # the checkpoint — acceptable only because the repo is trusted.
    # Prefer weights_only=True if the checkpoint format allows it.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    model_name = checkpoint['model_name']
    state_dict = checkpoint['state_dict']
    model_args = checkpoint.get('model_args', {})

    print(f"Model class: {model_name}")
    print(f"Model args: {model_args}")

    # look2hear resolves the concrete model class from its registry by name.
    from look2hear.models import get
    model_class = get(model_name)

    # OmegaConf configs expose to_container; convert to a plain dict so the
    # constructor receives ordinary keyword arguments.
    if hasattr(model_args, 'to_container'):
        model_args = model_args.to_container(resolve=True)

    print(f"Instantiating {model_name}...")
    model = model_class(**model_args)

    print("Loading state dict...")
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()
    print("✓ Model loaded successfully")

    _load_apollo_model._model = model
    return model
| |
|
| | |
| | |
| | |
| |
|
| | |
# Build the Gradio/HARP UI. NOTE(review): .harp_required() and .set_info()
# are not stock gradio methods — presumably added to gr.Audio by pyharp's
# gradio integration; confirm against the installed pyharp version.
with gr.Blocks() as demo:
    # A single required audio input, handed to process_fn as a file path.
    input_components = [
        gr.Audio(type="filepath",
                 label="Input Audio A")
        .harp_required(True),
    ]

    # The restored audio returned by process_fn, exposed as a file path.
    output_components = [
        gr.Audio(type="filepath",
                 label="Output Audio")
        .set_info("The restored audio."),
    ]

    # Wire the model card, components, and processing function into a
    # HARP-compatible endpoint.
    app = build_endpoint(
        model_card=model_card,
        input_components=input_components,
        output_components=output_components,
        process_fn=process_fn,
    )

# share=True exposes a public Gradio link; show_error=False hides
# server-side tracebacks from the UI (consider True while debugging).
demo.queue().launch(share=True, show_error=False, pwa=True)
| |
|
| | |
# NOTE(review): dead code kept as a string literal — this CLI entry point
# calls a main() function that is not defined anywhere in this file.
# Either restore main() or delete this block.
'''
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Audio Inference Script")
    parser.add_argument("--in_wav", type=str, required=True, help="Path to input wav file")
    parser.add_argument("--out_wav", type=str, required=True, help="Path to output wav file")
    args = parser.parse_args()

    main(args.in_wav, args.out_wav)
'''
| |
|