import os
import torch
import torchaudio
import argparse
from huggingface_hub import hf_hub_download

# For the PyHARP wrapper; this script defines its own load_audio/save_audio
# below, so only ModelCard and build_endpoint are imported
from pyharp import ModelCard, build_endpoint
import gradio as gr

# Create a ModelCard
model_card = ModelCard(
    name="Apollo",
    description="High-quality audio restoration for lossy MP3 compressed audio. Converts low-bitrate MP3s to near-lossless quality using band-sequence modeling.",
    author="JusperLee",
    tags=["audio restoration", "music", "apollo", "mp3", "lossless"],
)

def load_audio(file_path, target_sr=44100):
    audio, sample_rate = torchaudio.load(file_path)
    # Apollo operates on 44.1 kHz audio; resample if the input differs
    if sample_rate != target_sr:
        audio = torchaudio.functional.resample(audio, sample_rate, target_sr)
    return audio.unsqueeze(0)  # [1, channels, samples], kept on the CPU (no .cuda())

def save_audio(file_path, audio, samplerate=44100):
    # Kept for the original export path (see the commented-out lines at the
    # end of process_fn)
    audio = audio.squeeze(0).cpu()
    torchaudio.save(file_path, audio, samplerate)
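
# Example (hypothetical file names): round-trip a clip through the helpers
#   sig = load_audio("example.mp3")    # -> [1, channels, samples] at 44.1 kHz
#   save_audio("restored.wav", sig)    # writes a 44.1 kHz WAV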

# Define the HARP process function
@torch.inference_mode()
def process_fn(
    input_audio_path: str
) -> str:
    # Run on the CPU so the endpoint works on machines without a GPU
    device = torch.device("cpu")
    
    print(f"Using device: {device}")
    print("Loading Apollo model...")
    
    # Download the model weights from the Hugging Face Hub (cached on disk).
    # Note that this runs on every request; a module-level caching sketch
    # follows process_fn below.
    model_path = hf_hub_download(
        repo_id="JusperLee/Apollo",
        filename="pytorch_model.bin",
        cache_dir="./checkpoints"
    )
    
    # Load the full checkpoint (weights_only=False lets the pickled OmegaConf config load)
    print(f"Loading checkpoint from {model_path}")
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    
    # Extract model info
    model_name = checkpoint['model_name']
    state_dict = checkpoint['state_dict']
    model_args = checkpoint.get('model_args', {})
    
    print(f"Model class: {model_name}")
    print(f"Model args: {model_args}")
    
    # Import the correct model class
    from look2hear.models import get
    model_class = get(model_name)
    
    # Create the model instance with model_args, converting an OmegaConf
    # container to a plain dict if needed
    from omegaconf import OmegaConf
    if OmegaConf.is_config(model_args):
        model_args = OmegaConf.to_container(model_args, resolve=True)
    
    print(f"Instantiating {model_name}...")
    model = model_class(**model_args)
    
    # Load state dict
    print("Loading state dict...")
    model.load_state_dict(state_dict)
    
    model = model.to(device)
    model.eval()
    print("✓ Model loaded successfully")
    
    print(f"Processing audio: {input_audio_path}")
    sig = load_audio(input_audio_path)

    # Move audio data to device
    sig = sig.to(device)

    # Apollo expects [batch, channels, samples]; load_audio already adds the
    # batch dimension, but keep this guard for safety
    if sig.dim() == 2:
        sig = sig.unsqueeze(0)
    
    # @torch.inference_mode() on process_fn already disables autograd
    output = model(sig)

    # Remove batch dimension
    output = output.squeeze(0)
    
    output_audio_path = os.path.join("src", "_outputs", "output_restored.wav")
    os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
    torchaudio.save(output_audio_path, output, 44100)
    print(f"✓ Saved output to {output_audio_path}")

    return output_audio_path

    # original export method
    # save_audio(output_wav, out)
    # print(f"✓ Saved output to {output_wav}")

# Build Gradio endpoint
with gr.Blocks() as demo:
    # Define input Gradio Components
    input_components = [
        gr.Audio(type="filepath",
                 label="Input Audio A")
        .harp_required(True),
    ]

    # Define output Gradio Components
    output_components = [
        gr.Audio(type="filepath",
                 label="Output Audio")
        .set_info("The restored audio."),
    ]

    # Build a HARP-compatible endpoint
    app = build_endpoint(
        model_card=model_card,
        input_components=input_components,
        output_components=output_components,
        process_fn=process_fn,
    )

# Launch the app; HARP connects via the public share link
demo.queue().launch(share=True, show_error=False, pwa=True)

# Original command-line entry point (main() comes from the upstream inference script)
'''
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Audio Inference Script")
    parser.add_argument("--in_wav", type=str, required=True, help="Path to input wav file")
    parser.add_argument("--out_wav", type=str, required=True, help="Path to output wav file")
    args = parser.parse_args()

    main(args.in_wav, args.out_wav)
'''
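
# Example usage (hypothetical file names; adjust the script name to match
# this file):
#   python app.py        # launches the Gradio endpoint; paste the printed
#                        # share URL into HARP
# With the block above restored (and a main() defined), the original CLI was:
#   python inference.py --in_wav input.mp3 --out_wav restored.wav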