import os
import torch
import torchaudio
import argparse
from huggingface_hub import hf_hub_download
# For PyHARP wrapper
from pyharp import ModelCard, build_endpoint, load_audio, save_audio
import gradio as gr
# HARP model card: metadata shown in the plugin UI to identify this model.
_CARD_FIELDS = {
    "name": "Apollo",
    "description": "High-quality audio restoration for lossy MP3 compressed audio. Converts low-bitrate MP3s to near-lossless quality using band-sequence modeling.",
    "author": "JusperLee",
    "tags": ["audio restoration", "music", "apollo", "mp3", "lossless"],
}
model_card = ModelCard(**_CARD_FIELDS)
def load_audio(file_path):
    """Read an audio file and return its waveform with a leading batch axis.

    Deliberately shadows the ``load_audio`` imported from pyharp: this
    variant keeps the tensor on CPU (no ``.cuda()``) and discards the
    file's sample rate.

    Args:
        file_path: Path to an audio file readable by torchaudio.

    Returns:
        CPU tensor shaped [1, channels, samples].
    """
    # NOTE(review): the sample rate is dropped here while the saver
    # hardcodes 44100 — assumes 44.1 kHz inputs; confirm.
    waveform, _sample_rate = torchaudio.load(file_path)
    return waveform.unsqueeze(0)
def save_audio(file_path, audio, samplerate=44100):
    """Write a batched waveform tensor to disk as an audio file.

    Deliberately shadows the ``save_audio`` imported from pyharp. The
    batch axis added by this module's ``load_audio`` is stripped and the
    tensor is moved to CPU before writing.

    Args:
        file_path: Destination path.
        audio: Tensor shaped [1, channels, samples].
        samplerate: Sample rate written to the file (default 44100,
            regardless of the source file's actual rate).
    """
    waveform = audio.squeeze(0).cpu()
    torchaudio.save(file_path, waveform, samplerate)
# Defining the process function

# Cache keyed by device string so the checkpoint is downloaded and the
# model instantiated only once per process, not on every HARP request.
_MODEL_CACHE = {}


def _load_apollo_model(device):
    """Download the Apollo checkpoint from HuggingFace and build the model.

    Results are memoized in ``_MODEL_CACHE`` so repeated calls are cheap.

    Args:
        device: torch.device the model should be placed on.

    Returns:
        The Apollo model in eval mode on ``device``.
    """
    cache_key = str(device)
    if cache_key in _MODEL_CACHE:
        return _MODEL_CACHE[cache_key]

    print("Loading Apollo model...")
    # Download model weights from HuggingFace (cached under ./checkpoints)
    model_path = hf_hub_download(
        repo_id="JusperLee/Apollo",
        filename="pytorch_model.bin",
        cache_dir="./checkpoints",
    )

    # weights_only=False because the checkpoint embeds OmegaConf objects,
    # not just tensors.
    print(f"Loading checkpoint from {model_path}")
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    # Extract model info from the checkpoint dict
    model_name = checkpoint['model_name']
    state_dict = checkpoint['state_dict']
    model_args = checkpoint.get('model_args', {})
    print(f"Model class: {model_name}")
    print(f"Model args: {model_args}")

    # Resolve the concrete model class by name from look2hear's registry
    from look2hear.models import get
    model_class = get(model_name)

    # Convert OmegaConf config to a plain dict if needed
    if hasattr(model_args, 'to_container'):
        model_args = model_args.to_container(resolve=True)

    print(f"Instantiating {model_name}...")
    model = model_class(**model_args)

    print("Loading state dict...")
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    print("✓ Model loaded successfully")

    _MODEL_CACHE[cache_key] = model
    return model


@torch.inference_mode()
def process_fn(
    input_audio_path: str
) -> str:
    """Restore a lossy audio file with Apollo and return the output path.

    Args:
        input_audio_path: Path to the (MP3-degraded) input audio file.

    Returns:
        Path to the restored WAV written under src/_outputs/.
    """
    # CPU only — don't set a CUDA device; the checkpoint is loaded with
    # map_location matching this.
    device = torch.device("cpu")
    print(f"Using device: {device}")

    model = _load_apollo_model(device)

    sig = load_audio(input_audio_path)
    sig = sig.to(device)

    # Apollo expects [batch, channels, samples]; this module's load_audio
    # already adds the batch axis, so this guard only fires for 2-D
    # tensors from other callers.
    if sig.dim() == 2:
        sig = sig.unsqueeze(0)

    # @torch.inference_mode() already disables autograd — no extra
    # torch.no_grad() context needed here.
    output = model(sig)
    output = output.squeeze(0)  # drop batch axis -> [channels, samples]

    output_audio_path = os.path.join("src", "_outputs", "output_restored.wav")
    os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
    # NOTE(review): writes at a hardcoded 44100 Hz while load_audio
    # discards the input's real rate — assumes 44.1 kHz sources; confirm.
    torchaudio.save(output_audio_path, output, 44100)
    print(f"✓ Saved output to {output_audio_path}")
    return output_audio_path
# Build Gradio endpoint
with gr.Blocks() as demo:
    # Input components: pyharp-extended gr.Audio taking a file path,
    # marked required for HARP hosts.
    input_components = [
        gr.Audio(
            type="filepath",
            label="Input Audio A",
        ).harp_required(True),
    ]

    # Output components: restored audio returned as a file path.
    output_components = [
        gr.Audio(
            type="filepath",
            label="Output Audio",
        ).set_info("The restored audio."),
    ]

    # Wire card, components, and process_fn into a HARP-compatible endpoint
    app = build_endpoint(
        model_card=model_card,
        input_components=input_components,
        output_components=output_components,
        process_fn=process_fn,
    )

# Run the app; queued so long restorations aren't cut off by request timeouts
demo.queue().launch(share=True, show_error=False, pwa=True)
# original inference function run
# NOTE(review): dead code kept for reference as a bare string literal —
# it calls main(), which is not defined anywhere in this file.
'''
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Audio Inference Script")
parser.add_argument("--in_wav", type=str, required=True, help="Path to input wav file")
parser.add_argument("--out_wav", type=str, required=True, help="Path to output wav file")
args = parser.parse_args()
main(args.in_wav, args.out_wav)
'''