Spaces:
Running
Running
| import os | |
| import requests | |
| from tqdm import tqdm | |
| def download(url, path): | |
| """Download file from url to local path with progress bar.""" | |
| if os.path.exists(path): | |
| print(f"File already exists, skipping download: {path}") | |
| return | |
| os.makedirs(os.path.dirname(path), exist_ok=True) | |
| response = requests.get(url, stream=True) | |
| total_size = int(response.headers.get("content-length", 0)) | |
| with ( | |
| open(path, "wb") as f, | |
| tqdm( | |
| desc=path, | |
| total=total_size, | |
| unit="iB", | |
| unit_scale=True, | |
| unit_divisor=1024, | |
| ) as bar, | |
| ): | |
| for data in response.iter_content(chunk_size=1024): | |
| size = f.write(data) | |
| bar.update(size) | |
| def download_all(use_mirror: bool = False): | |
| """Download all required checkpoints. | |
| Args: | |
| use_mirror (bool): If True, use hf-mirror.com (for Mainland China). | |
| """ | |
| base_url = "https://hf-mirror.com" if use_mirror else "https://huggingface.co" | |
| urls = [ | |
| (f"{base_url}/minzwon/MusicFM/resolve/main/msd_stats.json", | |
| os.path.join("ckpts", "MusicFM", "msd_stats.json")), | |
| (f"{base_url}/minzwon/MusicFM/resolve/main/pretrained_msd.pt", | |
| os.path.join("ckpts", "MusicFM", "pretrained_msd.pt")), | |
| (f"{base_url}/ASLP-lab/SongFormer/resolve/main/SongFormer.safetensors", | |
| os.path.join("ckpts", "SongFormer.safetensors")), | |
| # The content of safetensors is the same as pt, it is recommended to use safetensors | |
| # (f"{base_url}/ASLP-lab/SongFormer/resolve/main/SongFormer.pt", | |
| # os.path.join("ckpts", "SongFormer.pt")), | |
| ] | |
| for url, path in urls: | |
| download(url, path) | |
| if __name__ == "__main__": | |
| # By default, use HuggingFace. If you are in Mainland China, change to True | |
| download_all(use_mirror=False) |