import tempfile
from pathlib import Path

import argbind
from audiotools import ml

import dac
|
|
|
# Module-level aliases for the model class and the audiotools accelerator class.
DAC = dac.model.DAC
Accelerator = ml.Accelerator
|
|
|
# Maps (model_type, model_bitrate) -> the release tag that "latest" resolves to.
__MODEL_LATEST_TAGS__ = {
    ("44khz", "8kbps"): "0.0.1",
    ("24khz", "8kbps"): "0.0.4",
    ("16khz", "8kbps"): "0.0.5",
    ("44khz", "16kbps"): "1.0.0",
}
|
|
|
# Maps (model_type, tag, model_bitrate) -> download URL of the weights file.
# Note the key order differs from __MODEL_LATEST_TAGS__: the tag sits in the middle.
__MODEL_URLS__ = {
    ("44khz", "0.0.1", "8kbps"): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
    ("24khz", "0.0.4", "8kbps"): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
    ("16khz", "0.0.5", "8kbps"): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
    ("44khz", "1.0.0", "16kbps"): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
}
|
|
|
|
|
@argbind.bind(group="download", positional=True, without_prefix=True)
def download(
    model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
):
    """
    Download the weights file for a model unless a local cached copy exists.

    Parameters
    ----------
    model_type : str
        The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
    model_bitrate: str
        Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
        Only 44khz model supports 16kbps.
    tag : str
        The tag of the model to download. Defaults to "latest".

    Returns
    -------
    Path
        Path of the cached weights file, located under
        ``~/.cache/descript/dac/``.

    Raises
    ------
    AssertionError
        If ``model_type`` or ``model_bitrate`` is not one of the accepted values.
    ValueError
        If no weights are registered for the requested combination, or the
        server answers with a status code other than 200.
    """
    # All three selectors are matched case-insensitively.
    model_type = model_type.lower()
    model_bitrate = model_bitrate.lower()
    tag = tag.lower()

    # Kept as asserts so callers catching AssertionError keep working. Under
    # ``python -O`` they are stripped, and the lookups below still reject
    # unknown values with a ValueError.
    assert model_type in [
        "44khz",
        "24khz",
        "16khz",
    ], "model_type must be one of '44khz', '24khz', or '16khz'"

    assert model_bitrate in [
        "8kbps",
        "16kbps",
    ], "model_bitrate must be one of '8kbps', or '16kbps'"

    if tag == "latest":
        # Not every (type, bitrate) pair has a release, so avoid a bare KeyError.
        latest_tag = __MODEL_LATEST_TAGS__.get((model_type, model_bitrate))
        if latest_tag is None:
            raise ValueError(
                f"Could not find a latest tag for model type {model_type} "
                f"with bitrate {model_bitrate}"
            )
        tag = latest_tag

    download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)

    if download_link is None:
        raise ValueError(
            f"Could not find model with tag {tag} and model type {model_type}"
        )

    local_path = (
        Path.home()
        / ".cache"
        / "descript"
        / "dac"
        / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
    )
    if not local_path.exists():
        local_path.parent.mkdir(parents=True, exist_ok=True)

        # Imported here so `requests` is only needed when a download happens.
        import requests

        # (connect, read) timeouts: fail instead of hanging on a stalled
        # connection. The read timeout bounds idle time, not the total transfer.
        response = requests.get(download_link, timeout=(10, 60))

        if response.status_code != 200:
            raise ValueError(
                f"Could not download model. Received response code {response.status_code}"
            )

        # Write to a temporary file in the same directory and rename it into
        # place, so an interrupted write never leaves a truncated file that a
        # later call would mistake for a valid cache entry.
        fd, tmp_name = tempfile.mkstemp(
            dir=str(local_path.parent),
            prefix=f"{local_path.name}.",
            suffix=".tmp",
        )
        tmp_path = Path(tmp_name)
        try:
            with open(fd, "wb") as tmp_file:
                tmp_file.write(response.content)
            tmp_path.replace(local_path)
        finally:
            # Only still present when the write or the rename failed.
            if tmp_path.exists():
                tmp_path.unlink()

    return local_path
|
|
|
|
|
def load_model(
    model_type: str = "44khz",
    model_bitrate: str = "8kbps",
    tag: str = "latest",
    load_path: str = None,
):
    """
    Load a DAC model, fetching weights with ``download`` when no path is given.

    Parameters
    ----------
    model_type : str
        Passed to ``download`` when ``load_path`` is not given. Defaults to "44khz".
    model_bitrate : str
        Passed to ``download`` when ``load_path`` is not given. Defaults to "8kbps".
    tag : str
        Passed to ``download`` when ``load_path`` is not given. Defaults to "latest".
    load_path : str
        Weights location handed to ``DAC.load``. When falsy (``None`` or an
        empty string), the path returned by ``download`` is used instead.

    Returns
    -------
    The object returned by ``DAC.load`` for the resolved path.
    """
    # An explicit path wins; otherwise fall back to the cached or downloaded weights.
    weights_path = load_path or download(
        model_type=model_type, model_bitrate=model_bitrate, tag=tag
    )
    return DAC.load(weights_path)
|
|
|