gabboud commited on
Commit
a887836
·
1 Parent(s): 8adc15c

correct test based on colab notebook

Browse files
Files changed (1) hide show
  1. app.py +39 -3
app.py CHANGED
@@ -1,10 +1,14 @@
1
  import gradio as gr
2
-
3
  import os
4
  import subprocess
5
  from pathlib import Path
6
  import shutil
7
  import spaces
 
 
 
 
8
 
9
  # Download model weights (skips already-downloaded models automatically)
10
  # In total, ~6GB (3GB for RFD3, 3GB for RF3, <100MB for MPNN); may take a few minutes depending on your connection speed
@@ -58,7 +62,39 @@ def test_rfd3():
58
  return f"RFD3 test failed: {result.stderr}"
59
  except Exception as e:
60
  return f"Error: {str(e)}"
61
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  # Gradio UI
64
  with gr.Blocks(title="RFD3 Test") as demo:
@@ -68,7 +104,7 @@ with gr.Blocks(title="RFD3 Test") as demo:
68
  test_btn = gr.Button("Run RFD3 Test")
69
  output = gr.Textbox(label="Test Result")
70
 
71
- test_btn.click(test_rfd3, outputs=output)
72
 
73
  if __name__ == "__main__":
74
  demo.launch()
 
1
  import gradio as gr
2
+ import warnings
3
  import os
4
  import subprocess
5
  from pathlib import Path
6
  import shutil
7
  import spaces
8
+ from atomworks.io.utils.visualize import view
9
+ from lightning.fabric import seed_everything
10
+ from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
11
+
12
 
13
  # Download model weights (skips already-downloaded models automatically)
14
  # In total, ~6GB (3GB for RFD3, 3GB for RF3, <100MB for MPNN); may take a few minutes depending on your connection speed
 
62
  return f"RFD3 test failed: {result.stderr}"
63
  except Exception as e:
64
  return f"Error: {str(e)}"
65
+
66
+ @spaces.GPU(duration=300)
67
+ def test_rfd3_from_notebook():
68
+ # Set seed for reproducibility
69
+ seed_everything(0)
70
+
71
+ # Configure RFD3 inference
72
+ config = RFD3InferenceConfig(
73
+ specification={
74
+ 'length': 40, # Generate 80-residue proteins
75
+ },
76
+ diffusion_batch_size=2, # Generate 2 structures per batch
77
+ )
78
+
79
+ # Initialize engine and run generation
80
+ try:
81
+ model = RFD3InferenceEngine(**config)
82
+ outputs = model.run(
83
+ inputs=None, # None for unconditional generation
84
+ out_dir=None, # None to return in memory (no file output)
85
+ n_batches=1, # Generate 1 batch
86
+ )
87
+ return_str = "RDF3 test passed! Generated structures:\n"
88
+
89
+ for idx, data in outputs.items():
90
+ return_str += f"Batch {idx}: {len(data)} structure(s)\n"
91
+ for i, struct in enumerate(data):
92
+ return_str += f"Structure {i+1}: {len(struct)} atoms\n"
93
+ return_str += struct.atom_array
94
+ return return_str
95
+ except Exception as e:
96
+ return f"Error: {str(e)}"
97
+
98
 
99
  # Gradio UI
100
  with gr.Blocks(title="RFD3 Test") as demo:
 
104
  test_btn = gr.Button("Run RFD3 Test")
105
  output = gr.Textbox(label="Test Result")
106
 
107
+ test_btn.click(test_rfd3_from_notebook, outputs=output)
108
 
109
  if __name__ == "__main__":
110
  demo.launch()