gabboud commited on
Commit
27305e9
·
1 Parent(s): 76275b4

unconditional generation

Browse files
Files changed (1) hide show
  1. app.py +61 -28
app.py CHANGED
@@ -11,35 +11,8 @@ from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
11
  from utils import download_weights
12
 
13
 
14
- # foundry is a package installed automatically upon Space initialization through the Gradio SDK because it is listed in requirements.txt.
15
- # model weights are however not included in the package and must be downloaded separately.
16
- # the command "foundry install ..." automatically avoids re-downloading models if they are already present in the cache directory.
17
- #cmd = f"foundry install rfd3 ligandmpnn rf3"
18
- #result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
19
- #if result.returncode == 0:
20
- # print("Models installed successfully.")
21
- #else:
22
- # print(f"Error installing models: {result.stderr}")
23
- # print(result.stdout)
24
- # print(result.returncode)
25
-
26
  download_weights()
27
 
28
- # Run once on startup: Install models if missing
29
- #checkpoint_dir = Path.home() / ".foundry" / "checkpoints"
30
- #os.environ["FOUNDRY_CHECKPOINT_DIRS"] = str(checkpoint_dir)
31
- #
32
- #def install_models():
33
- # """Download rfd3, ligandmpnn, rf3 weights once."""
34
- # #models = ["rfd3", "ligandmpnn", "rf3"]
35
- # models = ["ligandmpnn"] # let's start with only ligand mpnn for testing
36
- # for model in models:
37
- # if not (checkpoint_dir / model).exists():
38
- # print(f"Installing {model}...")
39
- # subprocess.check_call(["foundry", "install", model])
40
- # print("All models installed.")
41
- #
42
- #install_models() # Executes on app.py load
43
 
44
 
45
  @spaces.GPU(duration=300)
@@ -77,13 +50,73 @@ def test_rfd3_from_notebook():
77
 
78
  # Gradio UI
79
  with gr.Blocks(title="RFD3 Test") as demo:
80
- gr.Markdown("# RFdiffusion3 (RFD3) Model Checker")
81
  gr.Markdown("Models auto-downloaded on launch. Click to test.")
 
82
 
83
  test_btn = gr.Button("Run RFD3 Test")
84
  output = gr.Textbox(label="Test Result")
85
 
86
  test_btn.click(test_rfd3_from_notebook, outputs=output)
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  if __name__ == "__main__":
89
  demo.launch()
 
11
  from utils import download_weights
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  download_weights()
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  @spaces.GPU(duration=300)
 
50
 
51
  # Gradio UI
52
  with gr.Blocks(title="RFD3 Test") as demo:
53
+ gr.Markdown("# RFdiffusion3 (RFD3) for Backbone generation")
54
  gr.Markdown("Models auto-downloaded on launch. Click to test.")
55
+
56
 
57
  test_btn = gr.Button("Run RFD3 Test")
58
  output = gr.Textbox(label="Test Result")
59
 
60
  test_btn.click(test_rfd3_from_notebook, outputs=output)
61
 
62
+ gr.Markdown("Unconditional generation of backbones")
63
+ with gr.Row():
64
+ num_designs_per_batch = gr.Number(
65
+ value=2,
66
+ label="Number of Designs per Batch",
67
+ precision=0,
68
+ minimum=1,
69
+ maximum=8
70
+ )
71
+ num_batches = gr.Number(
72
+ value=5,
73
+ label="Number of Batches",
74
+ precision=0,
75
+ minimum=1,
76
+ maximum=10
77
+ )
78
+ length = gr.Number(
79
+ value=40,
80
+ label="Length of Protein (number of residues)",
81
+ precision=0,
82
+ minimum=10,
83
+ maximum=200
84
+ )
85
+ # Configure RFD3 inference
86
+ config = RFD3InferenceConfig(
87
+ specification={
88
+ 'length': length,
89
+
90
+ },
91
+ diffusion_batch_size=num_designs_per_batch, # Generate 2 structures per batch
92
+ )
93
+
94
+ # Initialize engine and run generation
95
+ def unconditional_generation(num_batches, num_designs_per_batch, length):
96
+ try:
97
+ model = RFD3InferenceEngine(**config)
98
+ outputs = model.run(
99
+ inputs=None, # None for unconditional generation
100
+ out_dir=None, # None to return in memory (no file output)
101
+ n_batches=num_batches, # Generate 1 batch
102
+ )
103
+ return_str = "RDF3 test passed! Generated structures:\n"
104
+
105
+ for idx, data in outputs.items():
106
+ return_str += f"Batch {idx}: {len(data)} structure(s)\n"
107
+ for i, struct in enumerate(data):
108
+ return_str += f"Structure {i+1}: {struct.atom_array.array_length()} Atoms\n"
109
+ #return_str += struct.atom_array
110
+ return return_str
111
+
112
+ except Exception as e:
113
+ return f"Error: {str(e)}"
114
+
115
+ gen_btn = gr.Button("Run Unconditional Generation")
116
+ gen_output = gr.Textbox(label="Generation Result")
117
+ gen_btn.click(unconditional_generation, inputs=[num_batches, num_designs_per_batch, length], outputs=gen_output)
118
+
119
+
120
+
121
  if __name__ == "__main__":
122
  demo.launch()