File size: 1,395 Bytes
f19aa8b
 
 
 
94ff1b9
 
f19aa8b
 
 
 
 
 
 
 
 
 
 
 
 
76275b4
f19aa8b
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import subprocess
import os
from pathlib import Path

# Names of the foundry models whose checkpoints download_weights() fetches by default.
# Add further model names here (e.g. "ligandmpnn", "rf3") to install them as well.
MODELS = ["rfd3"]


# The `foundry` package is installed automatically when the Space is initialized,
# because the Gradio SDK installs everything listed in requirements.txt.
# Model weights, however, are not included in the package and must be downloaded separately.
# `foundry install ...` already skips models that are present in the checkpoint (cache) directory;
# we nevertheless check manually so the log shows which models were skipped (useful for debugging).

def download_weights(models=None, checkpoint_dir=None):
    """Download model weights with the foundry CLI, skipping models already present.

    For each model name, the checkpoint is looked up at
    ``<checkpoint_dir>/<model>.ckpt``. If that file exists, the download is
    skipped. Otherwise ``foundry install <model> --checkpoint-dir <checkpoint_dir>``
    is run. Failures are reported on stdout rather than raised, so one failing
    model does not prevent the remaining ones from being processed.

    Args:
        models: Iterable of model names to install. Defaults to the
            module-level ``MODELS`` list.
        checkpoint_dir: Directory (str or Path) where checkpoints live.
            Defaults to ``~/.foundry/checkpoints``.
    """
    if models is None:
        models = MODELS
    if checkpoint_dir is None:
        checkpoint_dir = Path.home() / ".foundry" / "checkpoints"
    checkpoint_dir = Path(checkpoint_dir)

    for model in models:
        model_path = checkpoint_dir / f"{model}.ckpt"
        if model_path.exists():
            print(f"{model} already exists at {model_path}, skipping download.")
            continue

        # Pass the arguments as a list (no shell) so a checkpoint path with
        # spaces or shell metacharacters reaches foundry verbatim.
        cmd = ["foundry", "install", model, "--checkpoint-dir", str(checkpoint_dir)]
        print(f"Installing {model}...")
        try:
            result = subprocess.run(cmd, capture_output=True, text=True)
        except OSError as err:
            # Without a shell, a missing `foundry` executable raises instead of
            # producing a non-zero exit code; report it and keep going.
            print(f"Error installing {model}: {err}")
            continue

        if result.returncode == 0:
            print(f"{model} installed successfully.")
        else:
            print(f"Error installing {model}: {result.stderr}")
            print(result.stdout)
            print(result.returncode)