|
|
|
|
|
import os |
|
|
import tempfile |
|
|
|
|
|
import click |
|
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
|
import shutil |
|
|
|
|
|
BASE_MODEL_FILES = [ |
|
|
|
|
|
("genmo/mochi-1-preview", "decoder.safetensors", "decoder.safetensors"), |
|
|
("genmo/mochi-1-preview", "encoder.safetensors", "encoder.safetensors"), |
|
|
("genmo/mochi-1-preview", "dit.safetensors", "dit.safetensors"), |
|
|
] |
|
|
|
|
|
FAST_MODEL_FILE = ("FastVideo/FastMochi", "dit.safetensors", "dit.fast.safetensors") |
|
|
|
|
|
|
|
|
@click.command() |
|
|
@click.argument('output_dir', required=True) |
|
|
@click.option('--fast_model', is_flag=True, help='Download FastMochi model instead of standard model') |
|
|
@click.option('--hf_transfer', is_flag=True, help='Enable faster downloads using hf_transfer (requires: pip install "huggingface_hub[hf_transfer]")') |
|
|
def download_weights(output_dir, fast_model, hf_transfer): |
|
|
if not os.path.exists(output_dir): |
|
|
print(f"Creating output directory: {output_dir}") |
|
|
os.makedirs(output_dir, exist_ok=True) |
|
|
|
|
|
if hf_transfer: |
|
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" |
|
|
print("Using hf_transfer for faster downloads (requires: pip install 'huggingface_hub[hf_transfer]')") |
|
|
|
|
|
model_files = BASE_MODEL_FILES |
|
|
if fast_model: |
|
|
|
|
|
model_files = [f for f in model_files if not f[2].startswith("dit.")] |
|
|
model_files.append(FAST_MODEL_FILE) |
|
|
|
|
|
for repo_id, remote_path, local_path in model_files: |
|
|
local_file_path = os.path.join(output_dir, local_path) |
|
|
if not os.path.exists(local_file_path): |
|
|
if hf_transfer: |
|
|
|
|
|
print(f"Downloading {local_path} from {repo_id} to: {local_file_path}") |
|
|
out_path = hf_hub_download( |
|
|
repo_id=repo_id, |
|
|
filename=remote_path, |
|
|
local_dir=output_dir, |
|
|
) |
|
|
print(f"Copying {out_path} to {local_file_path}") |
|
|
|
|
|
shutil.copy2(out_path, local_file_path) |
|
|
else: |
|
|
with tempfile.TemporaryDirectory() as tmp_dir: |
|
|
snapshot_download( |
|
|
repo_id=repo_id, |
|
|
allow_patterns=[f"*{remote_path}*"], |
|
|
local_dir=tmp_dir, |
|
|
local_dir_use_symlinks=False, |
|
|
) |
|
|
shutil.move(os.path.join(tmp_dir, remote_path), local_file_path) |
|
|
else: |
|
|
print(f"{local_path} already exists in: {local_file_path}") |
|
|
assert os.path.exists(local_file_path), f"File {local_file_path} does not exist" |
|
|
|
|
|
if __name__ == "__main__": |
|
|
download_weights() |
|
|
|