Kling-Match / src /SongFormer /utils /fetch_pretrained.py
ASLP-lab's picture
fix import bug
952133e
import os
import requests
from tqdm import tqdm
def download(url, path):
"""Download file from url to local path with progress bar."""
if os.path.exists(path):
print(f"File already exists, skipping download: {path}")
return
os.makedirs(os.path.dirname(path), exist_ok=True)
response = requests.get(url, stream=True)
total_size = int(response.headers.get("content-length", 0))
with (
open(path, "wb") as f,
tqdm(
desc=path,
total=total_size,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar,
):
for data in response.iter_content(chunk_size=1024):
size = f.write(data)
bar.update(size)
def download_all(use_mirror: bool = False):
"""Download all required checkpoints.
Args:
use_mirror (bool): If True, use hf-mirror.com (for Mainland China).
"""
base_url = "https://hf-mirror.com" if use_mirror else "https://huggingface.co"
urls = [
(f"{base_url}/minzwon/MusicFM/resolve/main/msd_stats.json",
os.path.join("ckpts", "MusicFM", "msd_stats.json")),
(f"{base_url}/minzwon/MusicFM/resolve/main/pretrained_msd.pt",
os.path.join("ckpts", "MusicFM", "pretrained_msd.pt")),
(f"{base_url}/ASLP-lab/SongFormer/resolve/main/SongFormer.safetensors",
os.path.join("ckpts", "SongFormer.safetensors")),
# The content of safetensors is the same as pt, it is recommended to use safetensors
# (f"{base_url}/ASLP-lab/SongFormer/resolve/main/SongFormer.pt",
# os.path.join("ckpts", "SongFormer.pt")),
]
for url, path in urls:
download(url, path)
if __name__ == "__main__":
# By default, use HuggingFace. If you are in Mainland China, change to True
download_all(use_mirror=False)