File size: 3,396 Bytes
e617857
 
 
2e943b4
e617857
 
 
a99daf7
e617857
 
 
 
 
 
 
 
 
 
 
 
d4f3fbb
e617857
1abb01c
95b67fb
1abb01c
95b67fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1abb01c
 
2e943b4
e617857
 
 
 
 
d4f3fbb
e617857
 
 
 
 
 
 
 
 
d4f3fbb
e617857
 
928bb51
e617857
 
 
928bb51
e617857
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a99daf7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import torch
import torchaudio
import spaces
from huggingface_hub import hf_hub_download

# For PyHARP wrapper
from pyharp import ModelCard, build_endpoint
import gradio as gr

# Metadata describing this endpoint (name, summary, author, tags) for PyHARP.
model_card = ModelCard(
    name="Apollo",
    author="JusperLee",
    description=(
        "High-quality audio restoration for lossy MP3 compressed audio. "
        "Converts low-bitrate MP3s to near-lossless quality using band-sequence modeling."
    ),
    tags=["audio restoration", "music", "apollo", "mp3", "lossless"],
)

def load_audio(file_path, target_sr=44100):
    """Load an audio file as a batched tensor, resampled to ``target_sr``.

    Args:
        file_path: Path to an audio file that ``torchaudio.load`` can decode.
        target_sr: Sample rate (Hz) to resample to. Defaults to 44100, the
            rate ``process_fn`` writes its output at. Pass ``None`` to keep
            the file's native sample rate.

    Returns:
        Tensor of shape ``(1, channels, samples)``.
    """
    audio, samplerate = torchaudio.load(file_path)  # (channels, samples)
    # The restored signal is always saved at a fixed 44.1 kHz; without
    # resampling here, inputs at any other rate would come back
    # pitch/speed-shifted.
    if target_sr is not None and samplerate != target_sr:
        audio = torchaudio.functional.resample(audio, samplerate, target_sr)
    # Add a leading batch axis: (channels, samples) -> (1, channels, samples).
    return audio.unsqueeze(0)

# Load the model outside of the process function so that it only has to happen once
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")
print("Loading Apollo model...")

# Download model weights from HuggingFace (cached under ./checkpoints)
model_path = hf_hub_download(
    repo_id="JusperLee/Apollo",
    filename="pytorch_model.bin",
    cache_dir="./checkpoints"
)

# The checkpoint may hold non-tensor objects (e.g. an OmegaConf config for
# model_args), so the full pickle-based loader is used.
# SECURITY: weights_only=False can execute arbitrary pickled code. This is only
# tolerable because the repo id is fixed above (consider also pinning
# `revision=`); never point this at user-supplied files.
print(f"Loading checkpoint from {model_path}")
checkpoint = torch.load(model_path, map_location=device, weights_only=False)

# Extract model info
model_name = checkpoint['model_name']
state_dict = checkpoint['state_dict']
model_args = checkpoint.get('model_args', {})

print(f"Model class: {model_name}")
print(f"Model args: {model_args}")

# Import the correct model class
from look2hear.models import get
model_class = get(model_name)

# model_args may be a plain dict or an OmegaConf DictConfig. DictConfig has no
# `to_container` method (that is the static `OmegaConf.to_container`), so a
# hasattr check on it never fires. dict() performs a shallow conversion for any
# Mapping without importing omegaconf here; `or {}` tolerates a stored None.
model_args = dict(model_args or {})

print(f"Instantiating {model_name}...")
model = model_class(**model_args)

# Load state dict
print("Loading state dict...")
model.load_state_dict(state_dict)

model = model.to(device)
model.eval()
print("✓ Model loaded successfully")

# Defining the process function
@spaces.GPU
@torch.inference_mode()
def process_fn(
    input_audio_path: str
) -> str:
    """Restore a lossy audio file with Apollo and write the result to disk.

    Args:
        input_audio_path: Path to the input audio file, as delivered by the
            Gradio ``type="filepath"`` audio component.

    Returns:
        Path to the restored WAV file, written at 44100 Hz.
    """
    sig = load_audio(input_audio_path)

    # Put the input on the same device the model was moved to at load time
    # (module-level `device`). A hard-coded torch.device("cuda") here would
    # shadow it and crash on hosts where the model was loaded on CPU.
    sig = sig.to(device)

    # Add batch dimension if needed (Apollo expects [batch, channels, samples]).
    # load_audio already returns a batched tensor; this only guards 2-D input.
    if sig.dim() == 2:
        sig = sig.unsqueeze(0)

    result = model(sig)

    # Remove batch dimension; assumes the model returns (1, channels, samples)
    # so the result matches the (channels, samples) layout torchaudio.save takes.
    result = result.squeeze(0)

    # NOTE(review): fixed filename — concurrent jobs would overwrite each
    # other's output if the queue is ever configured to run more than one.
    output_audio_path = os.path.join("src", "_outputs", "output_restored.wav")
    os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
    torchaudio.save(output_audio_path, result.cpu(), 44100)
    print(f"✓ Saved output to {output_audio_path}")

    return output_audio_path

# Build the Gradio UI and expose process_fn as a HARP-compatible endpoint.
with gr.Blocks() as demo:
    # One required audio input, handed to process_fn as a file path.
    audio_in = gr.Audio(type="filepath", label="Input Audio A").harp_required(True)
    input_components = [audio_in]

    # One audio output carrying the restored file.
    audio_out = gr.Audio(type="filepath", label="Output Audio").set_info(
        "The restored audio."
    )
    output_components = [audio_out]

    app = build_endpoint(
        model_card=model_card,
        process_fn=process_fn,
        input_components=input_components,
        output_components=output_components,
    )

# Enable request queuing, then start the server.
demo.queue()
demo.launch(show_error=True, pwa=True)