"""HARP/Gradio endpoint exposing the Apollo audio-restoration model.

Apollo (JusperLee) restores low-bitrate, lossy MP3 audio toward
near-lossless quality using band-sequence modeling.  This module builds a
HARP-compatible Gradio app around a single ``process_fn`` endpoint.
"""
import os
import argparse

import torch
import torchaudio
import gradio as gr
from huggingface_hub import hf_hub_download

# NOTE(fix): the original also imported pyharp's load_audio/save_audio and
# then shadowed them with the local definitions below; import only what is
# actually used.
from pyharp import ModelCard, build_endpoint

# Model card displayed by HARP clients.
model_card = ModelCard(
    name="Apollo",
    description="High-quality audio restoration for lossy MP3 compressed audio. Converts low-bitrate MP3s to near-lossless quality using band-sequence modeling.",
    author="JusperLee",
    tags=["audio restoration", "music", "apollo", "mp3", "lossless"],
)


def load_audio(file_path):
    """Load *file_path* with torchaudio and prepend a batch dimension.

    Returns a CPU tensor of shape [1, channels, samples] (torchaudio.load
    yields [channels, samples]; the original comment claiming [1, 1, samples]
    was wrong for multi-channel input).
    """
    audio, samplerate = torchaudio.load(file_path)
    return audio.unsqueeze(0)


def save_audio(file_path, audio, samplerate=44100):
    """Write a [1, channels, samples] tensor to *file_path* at *samplerate*."""
    audio = audio.squeeze(0).cpu()
    torchaudio.save(file_path, audio, samplerate)


def _load_apollo_model(device):
    """Download the Apollo checkpoint from the HuggingFace Hub and build the model.

    Returns the instantiated model in eval mode on *device*.
    """
    print("Loading Apollo model...")
    model_path = hf_hub_download(
        repo_id="JusperLee/Apollo",
        filename="pytorch_model.bin",
        cache_dir="./checkpoints",
    )

    # weights_only=False: the checkpoint embeds OmegaConf config objects,
    # which the weights-only unpickler would reject.
    print(f"Loading checkpoint from {model_path}")
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    model_name = checkpoint['model_name']
    state_dict = checkpoint['state_dict']
    model_args = checkpoint.get('model_args', {})
    print(f"Model class: {model_name}")
    print(f"Model args: {model_args}")

    # Resolve the model class by name from the look2hear registry.
    from look2hear.models import get
    model_class = get(model_name)

    # OmegaConf configs must become plain dicts before **-expansion.
    if hasattr(model_args, 'to_container'):
        model_args = model_args.to_container(resolve=True)

    print(f"Instantiating {model_name}...")
    model = model_class(**model_args)

    print("Loading state dict...")
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    print("✓ Model loaded successfully")
    return model


@torch.inference_mode()
def process_fn(input_audio_path: str) -> str:
    """Restore *input_audio_path* with Apollo and return the output wav path.

    NOTE(review): the model is re-downloaded/re-instantiated on every call;
    caching it at module level would remove the per-request load cost.
    """
    # CPU-only inference so the endpoint does not require CUDA on the host.
    device = torch.device("cpu")
    print(f"Using device: {device}")

    model = _load_apollo_model(device)

    sig = load_audio(input_audio_path).to(device)

    # Apollo expects [batch, channels, samples]; load_audio already adds the
    # batch dim, so this guard only fires for raw 2-D tensors.
    if sig.dim() == 2:
        sig = sig.unsqueeze(0)

    # @torch.inference_mode() already disables autograd; the original's
    # inner torch.no_grad() was redundant and has been dropped.
    output = model(sig).squeeze(0)  # drop batch dim -> [channels, samples]

    output_audio_path = os.path.join("src", "_outputs", "output_restored.wav")
    os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
    torchaudio.save(output_audio_path, output, 44100)
    print(f"✓ Saved output to {output_audio_path}")
    return output_audio_path


# Build a HARP-compatible Gradio endpoint around process_fn.
with gr.Blocks() as demo:
    # Input components (harp_required / set_info are pyharp extensions on
    # gradio components — presumably monkey-patched; verify against pyharp).
    input_components = [
        gr.Audio(type="filepath", label="Input Audio A").harp_required(True),
    ]
    output_components = [
        gr.Audio(type="filepath", label="Output Audio").set_info("The restored audio."),
    ]
    app = build_endpoint(
        model_card=model_card,
        input_components=input_components,
        output_components=output_components,
        process_fn=process_fn,
    )

# Launch the app (side effect at import time, as in the original script).
demo.queue().launch(share=True, show_error=False, pwa=True)