genmoai / scripts /download_weights.py
versantus's picture
Upload folder using huggingface_hub
d6c2737 verified
#! /usr/bin/env python3
import os
import tempfile
import click
from huggingface_hub import hf_hub_download, snapshot_download
import shutil
# Download manifest entries have the shape
# (repo_id, remote_file_path, local_file_path).
# The standard weights keep their remote file name when stored locally.
BASE_MODEL_FILES = [
    ("genmo/mochi-1-preview", weight_file, weight_file)
    for weight_file in (
        "decoder.safetensors",
        "encoder.safetensors",
        "dit.safetensors",
    )
]

# The fast DiT shares its remote file name with the standard DiT, so it is
# stored locally under a distinct name.
FAST_MODEL_FILE = (
    "FastVideo/FastMochi",
    "dit.safetensors",
    "dit.fast.safetensors",
)
@click.command()
@click.argument('output_dir', required=True)
@click.option('--fast_model', is_flag=True, help='Download FastMochi model instead of standard model')
@click.option('--hf_transfer', is_flag=True, help='Enable faster downloads using hf_transfer (requires: pip install "huggingface_hub[hf_transfer]")')
def download_weights(output_dir, fast_model, hf_transfer):
    """Download the Mochi model weight files into OUTPUT_DIR.

    \f
    Files that already exist in ``output_dir`` are reported and skipped, so
    the command can be re-run to fetch only what is missing.

    Args:
        output_dir: Directory that receives the weight files; created when
            it does not exist yet.
        fast_model: When set, every ``BASE_MODEL_FILES`` entry whose local
            name starts with ``"dit."`` is replaced by ``FAST_MODEL_FILE``.
        hf_transfer: When set, ``HF_HUB_ENABLE_HF_TRANSFER=1`` is exported
            and each file is fetched with ``hf_hub_download`` instead of
            ``snapshot_download``.

    Raises:
        FileNotFoundError: If a file is still missing from ``output_dir``
            after its download step finished.
    """
    if not os.path.exists(output_dir):
        print(f"Creating output directory: {output_dir}")
        os.makedirs(output_dir, exist_ok=True)

    if hf_transfer:
        # NOTE(review): huggingface_hub may read this variable only when it is
        # imported, in which case setting it here is too late to take effect.
        # Confirm against the installed version.
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
        print("Using hf_transfer for faster downloads (requires: pip install 'huggingface_hub[hf_transfer]')")

    # Always work on a fresh list so the module-level manifest is never mutated.
    if fast_model:
        # Replace the standard DIT model with the fast model
        model_files = [f for f in BASE_MODEL_FILES if not f[2].startswith("dit.")]
        model_files.append(FAST_MODEL_FILE)
    else:
        model_files = list(BASE_MODEL_FILES)

    for repo_id, remote_path, local_path in model_files:
        local_file_path = os.path.join(output_dir, local_path)
        if os.path.exists(local_file_path):
            print(f"{local_path} already exists in: {local_file_path}")
            continue

        print(f"Downloading {local_path} from {repo_id} to: {local_file_path}")
        # Stage the download in a scratch directory inside output_dir rather
        # than in output_dir itself. Downloading straight into output_dir puts
        # the file at output_dir/<remote_path>: for the base files that is the
        # destination itself (copying a file onto itself fails), and for the
        # fast model it leaves a stray "dit.safetensors" that a later run
        # would mistake for the standard DiT weights. Keeping the scratch
        # directory under output_dir also keeps the final move on the same
        # filesystem.
        with tempfile.TemporaryDirectory(dir=output_dir) as tmp_dir:
            if hf_transfer:
                # I don't know if `hf_transfer` works with `snapshot_download`
                staged_path = hf_hub_download(
                    repo_id=repo_id,
                    filename=remote_path,
                    local_dir=tmp_dir,
                )
            else:
                snapshot_download(
                    repo_id=repo_id,
                    allow_patterns=[f"*{remote_path}*"],
                    local_dir=tmp_dir,
                    local_dir_use_symlinks=False,
                )
                staged_path = os.path.join(tmp_dir, remote_path)

            if os.path.islink(staged_path):
                # The hub client handed back a link (presumably into its
                # shared cache): copy instead of mv to avoid destroying
                # huggingface cache or relocating a relative link.
                shutil.copy2(staged_path, local_file_path)
            else:
                shutil.move(staged_path, local_file_path)

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(local_file_path):
            raise FileNotFoundError(f"File {local_file_path} does not exist")
# Script entry point: download_weights is a click command, so it is invoked
# without arguments and click supplies output_dir and the flags from the
# command line.
if __name__ == "__main__":
    download_weights()