| | from pathlib import Path |
| |
|
| | import argbind |
| | from audiotools import ml |
| |
|
| | import dac |
| |
|
# Short module-level aliases for classes defined in the imported packages.
# `DAC` is used by `load_model` further down in this module.
DAC = dac.model.DAC
Accelerator = ml.Accelerator
| |
|
| | __MODEL_LATEST_TAGS__ = { |
| | ("44khz", "8kbps"): "0.0.1", |
| | ("24khz", "8kbps"): "0.0.4", |
| | ("16khz", "8kbps"): "0.0.5", |
| | ("44khz", "16kbps"): "1.0.0", |
| | } |
| |
|
| | __MODEL_URLS__ = { |
| | ( |
| | "44khz", |
| | "0.0.1", |
| | "8kbps", |
| | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", |
| | ( |
| | "24khz", |
| | "0.0.4", |
| | "8kbps", |
| | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", |
| | ( |
| | "16khz", |
| | "0.0.5", |
| | "8kbps", |
| | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", |
| | ( |
| | "44khz", |
| | "1.0.0", |
| | "16kbps", |
| | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", |
| | } |
| |
|
| |
|
@argbind.bind(group="download", positional=True, without_prefix=True)
def download(
    model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
):
    """
    Download the weights file for a model unless a local cached copy exists.

    Parameters
    ----------
    model_type : str
        The type of model to download. Must be one of "44khz", "24khz", or "16khz".
        Case-insensitive. Defaults to "44khz".
    model_bitrate : str
        Bitrate of the model. Must be one of "8kbps", or "16kbps".
        Case-insensitive. Defaults to "8kbps".
        Only 44khz model supports 16kbps.
    tag : str
        The tag of the model to download. Case-insensitive. "latest" is resolved
        through ``__MODEL_LATEST_TAGS__``. Defaults to "latest".

    Returns
    -------
    Path
        Path of the cached weights file, located under
        ``~/.cache/descript/dac``.

    Raises
    ------
    ValueError
        If ``model_type`` or ``model_bitrate`` is not a supported value, if no
        weights are published for the requested combination, or if the server
        answers with a status code other than 200.
    requests.exceptions.RequestException
        If the HTTP request itself fails (connection error, timeout, ...).
    """
    model_type = model_type.lower()
    model_bitrate = model_bitrate.lower()
    tag = tag.lower()

    # Explicit raises instead of `assert`: assertions are stripped under
    # `python -O`, which would let invalid input through.
    if model_type not in ("44khz", "24khz", "16khz"):
        raise ValueError("model_type must be one of '44khz', '24khz', or '16khz'")

    if model_bitrate not in ("8kbps", "16kbps"):
        raise ValueError("model_bitrate must be one of '8kbps', or '16kbps'")

    if tag == "latest":
        # Not every (type, bitrate) pair has a release, e.g. 24khz at 16kbps.
        latest_tag = __MODEL_LATEST_TAGS__.get((model_type, model_bitrate))
        if latest_tag is None:
            raise ValueError(
                f"Could not find a released model with model type {model_type} "
                f"and bitrate {model_bitrate}"
            )
        tag = latest_tag

    download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate))

    if download_link is None:
        raise ValueError(
            f"Could not find model with tag {tag}, model type {model_type} "
            f"and bitrate {model_bitrate}"
        )

    local_path = (
        Path.home()
        / ".cache"
        / "descript"
        / "dac"
        / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
    )
    if not local_path.exists():
        local_path.parent.mkdir(parents=True, exist_ok=True)

        # Download the model. `requests` is imported here so it is only
        # required when a download actually has to happen.
        import requests

        # Without a timeout `requests` may block forever on a stalled server.
        response = requests.get(download_link, timeout=60)

        if response.status_code != 200:
            raise ValueError(
                f"Could not download model. Received response code {response.status_code}"
            )

        # Write to a sibling file first and rename into place, so that an
        # interrupted write never leaves a truncated file at `local_path`
        # that later calls would mistake for a valid cache entry.
        partial_path = local_path.with_name(local_path.name + ".part")
        try:
            partial_path.write_bytes(response.content)
            partial_path.replace(local_path)
        finally:
            if partial_path.exists():
                partial_path.unlink()

    return local_path
| |
|
| |
|
def load_model(
    model_type: str = "44khz",
    model_bitrate: str = "8kbps",
    tag: str = "latest",
    load_path: str = None,
):
    """
    Load a DAC model from a weights file.

    Parameters
    ----------
    model_type : str
        Forwarded to ``download`` when ``load_path`` is not given.
    model_bitrate : str
        Forwarded to ``download`` when ``load_path`` is not given.
    tag : str
        Forwarded to ``download`` when ``load_path`` is not given.
    load_path : str
        Location of a weights file to load. When empty or ``None``, the path
        returned by ``download`` for the other three arguments is used instead.

    Returns
    -------
    The object returned by ``DAC.load`` for the resolved path
    (presumably a ``DAC`` instance -- confirm against ``dac.model``).
    """
    # Any falsy `load_path` (None or "") falls back to the downloaded weights.
    weights_location = load_path or download(
        model_type=model_type, model_bitrate=model_bitrate, tag=tag
    )
    return DAC.load(weights_location)
| |
|