# primepake — "add dac codec" (commit 541c6b7)
from pathlib import Path
import argbind
import torch
from audiotools import AudioSignal
from audiotools.core import util
from audiotools.ml.decorators import Tracker
from train import Accelerator
from train import DAC
from dac.compare.encodec import Encodec
# Register Encodec with argbind so it can be configured the same way as the
# other bound entry points in this script; presumably this exposes its
# constructor arguments (e.g. ``bandwidth``) to CLI/config — confirm.
Encodec = argbind.bind(Encodec)
def load_state(
    accel: Accelerator,
    tracker: Tracker,
    save_path: str,
    tag: str = "latest",
    load_weights: bool = False,
    model_type: str = "dac",
    bandwidth: float = 24.0,
):
    """Build the generator to evaluate and prepare it with ``accel``.

    Args:
        accel: Accelerator whose ``prepare_model`` wraps the returned model.
        tracker: Tracker used to print the folder the model is resumed from.
        save_path: Checkpoint root; DAC is loaded from ``{save_path}/{tag}``.
        tag: Checkpoint sub-folder name (default ``"latest"``).
        load_weights: Forwarded to ``DAC.load_from_folder`` as
            ``package=not load_weights`` (presumably selects weights-only vs
            packaged loading — confirm against ``DAC.load_from_folder``).
        model_type: ``"dac"`` to load a trained DAC checkpoint, or
            ``"encodec"`` to construct an Encodec baseline.
        bandwidth: Passed to ``Encodec``; unused for DAC.

    Returns:
        The generator returned by ``accel.prepare_model``.

    Raises:
        ValueError: If ``model_type`` is not ``"dac"`` or ``"encodec"``.
    """
    kwargs = {
        "folder": f"{save_path}/{tag}",
        "map_location": "cpu",
        "package": not load_weights,
    }
    tracker.print(f"Resuming from {str(Path('.').absolute())}/{kwargs['folder']}")

    if model_type == "dac":
        # load_from_folder returns a pair; only the model is needed here.
        generator, _ = DAC.load_from_folder(**kwargs)
    elif model_type == "encodec":
        generator = Encodec(bandwidth=bandwidth)
    else:
        # Without this branch `generator` would be unbound and the caller
        # would see an UnboundLocalError instead of a clear message.
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected 'dac' or 'encodec'"
        )

    generator = accel.prepare_model(generator)
    return generator
@torch.no_grad()
def process(signal, accel, generator, **kwargs):
    """Run ``generator`` on ``signal`` and return a loudness-matched copy on CPU.

    The input is moved to ``accel.device`` and its ``audio_data`` and
    ``sample_rate`` are passed to ``generator`` together with any extra
    ``kwargs``. The ``"audio"`` entry of the result is wrapped in an
    ``AudioSignal`` at the input's sample rate. It is then normalized to the
    input's loudness and moved back to the CPU.
    """
    on_device = signal.to(accel.device)
    outputs = generator(on_device.audio_data, on_device.sample_rate, **kwargs)
    reconstruction = AudioSignal(outputs["audio"], on_device.sample_rate)
    reference_loudness = on_device.loudness()
    return reconstruction.normalize(reference_loudness).cpu()
@argbind.bind(without_prefix=True)
@torch.no_grad()
def get_samples(
    accel,
    path: str = "ckpt",
    input: str = "samples/input",
    output: str = "samples/output",
    model_type: str = "dac",
    model_tag: str = "latest",
    bandwidth: float = 24.0,
    n_quantizers: int = None,
):
    """Reconstruct every audio file under ``input`` and write it to ``output``.

    Parameter names are exposed on the command line by argbind, so they are
    kept as-is (including ``input``, which shadows the builtin).

    Args:
        accel: Accelerator providing ``local_rank``, ``device`` and model prep.
        path: Checkpoint root; also where ``eval.txt`` is logged.
        input: Directory searched with ``util.find_audio``.
        output: Directory for reconstructions (created if missing). Files are
            written by basename, so same-named inputs overwrite each other.
        model_type: ``"dac"`` or ``"encodec"``; see ``load_state``.
        model_tag: Checkpoint sub-folder passed to ``load_state`` as ``tag``.
        bandwidth: Forwarded to ``load_state`` (used for Encodec).
        n_quantizers: Forwarded to the generator only when ``model_type`` is
            ``"dac"``.
    """
    tracker = Tracker(log_file=f"{path}/eval.txt", rank=accel.local_rank)
    generator = load_state(
        accel,
        tracker,
        save_path=path,
        model_type=model_type,
        bandwidth=bandwidth,
        tag=model_tag,
    )
    generator.eval()

    # Only the DAC generator is given n_quantizers; other models get no extras.
    process_kwargs = {"n_quantizers": n_quantizers} if model_type == "dac" else {}

    audio_files = util.find_audio(input)
    # Wrap into a local rather than rebinding the module-level `process` via
    # `global`, so repeated calls don't stack tracker wrappers on top of each
    # other.
    tracked_process = tracker.track("process", len(audio_files))(process)

    output_dir = Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)

    with tracker.live:
        for audio_file in audio_files:
            signal = AudioSignal(audio_file)
            recons = tracked_process(signal, accel, generator, **process_kwargs)
            recons.write(output_dir / audio_file.name)

    tracker.done("test", f"N={len(audio_files)}")
if __name__ == "__main__":
    # Parse arguments, enter argbind.scope with them, and run sample
    # generation inside an Accelerator context (entered after the scope,
    # exited before it — same order as nesting the two `with` blocks).
    parsed_args = argbind.parse_args()
    with argbind.scope(parsed_args), Accelerator() as accel:
        get_samples(accel)