gabboud commited on
Commit
f19aa8b
·
1 Parent(s): 0cb79a0

modularize weight download

Browse files
Files changed (3) hide show
  1. app.py +14 -39
  2. utils/__init__.py +3 -0
  3. utils/download_weights.py +30 -0
app.py CHANGED
@@ -8,29 +8,24 @@ import spaces
8
  from atomworks.io.utils.visualize import view
9
  from lightning.fabric import seed_everything
10
  from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
 
11
 
12
 
13
- # Download model weights (skips already-downloaded models automatically)
14
- # In total, ~6GB (3GB for RFD3, 3GB for RF3, <100MB for MPNN); may take a few minutes depending on your connection speed
 
 
 
 
 
 
 
 
 
15
 
 
16
 
17
- cmd = f"foundry install rfd3 ligandmpnn rf3"
18
- print("Global PATH:", os.environ.get("PATH"))
19
-
20
- result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
21
- if result.returncode == 0:
22
- print("Models installed successfully.")
23
- else:
24
- print(f"Error installing models: {result.stderr}")
25
- print(result.stdout)
26
- print(result.returncode)
27
-
28
- #download_dir = "./models/"
29
- #if not os.path.exists(download_dir):
30
- # cmd = "foundry install rfd3 ligandmpnn rf3"
31
- # result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
32
- #
33
- ## Run once on startup: Install models if missing
34
  #checkpoint_dir = Path.home() / ".foundry" / "checkpoints"
35
  #os.environ["FOUNDRY_CHECKPOINT_DIRS"] = str(checkpoint_dir)
36
  #
@@ -46,26 +41,6 @@ else:
46
  #
47
  #install_models() # Executes on app.py load
48
 
49
- @spaces.GPU(duration=300)
50
- def test_rfd3():
51
- """Run a quick rfd3 test design (minimal monomer, 1 step)."""
52
- try:
53
- cmd = [
54
- "rfd3",
55
- "design",
56
- "--seed", "42",
57
- "contigmap.contigs=[A10]", # Tiny 10-res monomer
58
- "--num_designs", "1",
59
- "inference.output_prefix=test_output",
60
- "--inference.num_diffusion_steps", "10" # Fast test
61
- ]
62
- result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
63
- if result.returncode == 0:
64
- return "RFD3 test passed! Check test_output.pdb. Logs:\n" + result.stdout[-500:]
65
- else:
66
- return f"RFD3 test failed: {result.stderr}"
67
- except Exception as e:
68
- return f"Error: {str(e)}"
69
 
70
  @spaces.GPU(duration=300)
71
  def test_rfd3_from_notebook():
 
8
  from atomworks.io.utils.visualize import view
9
  from lightning.fabric import seed_everything
10
  from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
11
+ from utils import download_weights
12
 
13
 
14
+ # foundry is a package installed automatically upon Space initialization through the Gradio SDK because it is listed in requirements.txt.
15
+ # model weights are however not included in the package and must be downloaded separately.
16
+ # the command "foundry install ..." automatically avoids re-downloading models if they are already present in the cache directory.
17
+ #cmd = f"foundry install rfd3 ligandmpnn rf3"
18
+ #result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
19
+ #if result.returncode == 0:
20
+ # print("Models installed successfully.")
21
+ #else:
22
+ # print(f"Error installing models: {result.stderr}")
23
+ # print(result.stdout)
24
+ # print(result.returncode)
25
 
26
+ download_weights()
27
 
28
+ # Run once on startup: Install models if missing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  #checkpoint_dir = Path.home() / ".foundry" / "checkpoints"
30
  #os.environ["FOUNDRY_CHECKPOINT_DIRS"] = str(checkpoint_dir)
31
  #
 
41
  #
42
  #install_models() # Executes on app.py load
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  @spaces.GPU(duration=300)
46
  def test_rfd3_from_notebook():
utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from utils.download_weights import *
2
+
3
+ __all__ = ["download_weights"]
utils/download_weights.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import os
3
+ from pathlib import Path
4
+
5
+ MODELS = ["rfd3", "ligandmpnn", "rf3"]
6
+
7
+
8
+ # foundry is a package installed automatically upon Space initialization through the Gradio SDK because it is listed in requirements.txt.
9
+ # model weights are however not included in the package and must be downloaded separately.
10
+ # the command "foundry install ..." automatically avoids re-downloading models if they are already present in the cache directory.
11
+ # we however manually check for debugging purposes.
12
+
13
+ def download_weights():
14
+ """Download model weights using foundry CLI, skipping already-downloaded models."""
15
+
16
+ checkpoint_dir = Path.home() / ".foundry" / "checkpoints"
17
+ for model in MODELS:
18
+ model_path = os.path.join(checkpoint_dir, model+".ckpt")
19
+ if model_path.exists():
20
+ print(f"{model} already exists at {model_path}, skipping download.")
21
+ else:
22
+ cmd = f"foundry install {model} --checkpoint-dir {checkpoint_dir}"
23
+ print(f"Installing {model}...")
24
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
25
+ if result.returncode == 0:
26
+ print(f"{model} installed successfully.")
27
+ else:
28
+ print(f"Error installing {model}: {result.stderr}")
29
+ print(result.stdout)
30
+ print(result.returncode)