Spaces:
Running on Zero
Running on Zero
correct test based on colab notebook
Browse files
app.py
CHANGED
|
@@ -1,10 +1,14 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
|
| 3 |
import os
|
| 4 |
import subprocess
|
| 5 |
from pathlib import Path
|
| 6 |
import shutil
|
| 7 |
import spaces
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
# Download model weights (skips already-downloaded models automatically)
|
| 10 |
# In total, ~6GB (3GB for RFD3, 3GB for RF3, <100MB for MPNN); may take a few minutes depending on your connection speed
|
|
@@ -58,7 +62,39 @@ def test_rfd3():
|
|
| 58 |
return f"RFD3 test failed: {result.stderr}"
|
| 59 |
except Exception as e:
|
| 60 |
return f"Error: {str(e)}"
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
# Gradio UI
|
| 64 |
with gr.Blocks(title="RFD3 Test") as demo:
|
|
@@ -68,7 +104,7 @@ with gr.Blocks(title="RFD3 Test") as demo:
|
|
| 68 |
test_btn = gr.Button("Run RFD3 Test")
|
| 69 |
output = gr.Textbox(label="Test Result")
|
| 70 |
|
| 71 |
-
test_btn.click(
|
| 72 |
|
| 73 |
if __name__ == "__main__":
|
| 74 |
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import warnings
|
| 3 |
import os
|
| 4 |
import subprocess
|
| 5 |
from pathlib import Path
|
| 6 |
import shutil
|
| 7 |
import spaces
|
| 8 |
+
from atomworks.io.utils.visualize import view
|
| 9 |
+
from lightning.fabric import seed_everything
|
| 10 |
+
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
|
| 11 |
+
|
| 12 |
|
| 13 |
# Download model weights (skips already-downloaded models automatically)
|
| 14 |
# In total, ~6GB (3GB for RFD3, 3GB for RF3, <100MB for MPNN); may take a few minutes depending on your connection speed
|
|
|
|
| 62 |
return f"RFD3 test failed: {result.stderr}"
|
| 63 |
except Exception as e:
|
| 64 |
return f"Error: {str(e)}"
|
| 65 |
+
|
| 66 |
+
@spaces.GPU(duration=300)
|
| 67 |
+
def test_rfd3_from_notebook():
|
| 68 |
+
# Set seed for reproducibility
|
| 69 |
+
seed_everything(0)
|
| 70 |
+
|
| 71 |
+
# Configure RFD3 inference
|
| 72 |
+
config = RFD3InferenceConfig(
|
| 73 |
+
specification={
|
| 74 |
+
'length': 40, # Generate 80-residue proteins
|
| 75 |
+
},
|
| 76 |
+
diffusion_batch_size=2, # Generate 2 structures per batch
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# Initialize engine and run generation
|
| 80 |
+
try:
|
| 81 |
+
model = RFD3InferenceEngine(**config)
|
| 82 |
+
outputs = model.run(
|
| 83 |
+
inputs=None, # None for unconditional generation
|
| 84 |
+
out_dir=None, # None to return in memory (no file output)
|
| 85 |
+
n_batches=1, # Generate 1 batch
|
| 86 |
+
)
|
| 87 |
+
return_str = "RDF3 test passed! Generated structures:\n"
|
| 88 |
+
|
| 89 |
+
for idx, data in outputs.items():
|
| 90 |
+
return_str += f"Batch {idx}: {len(data)} structure(s)\n"
|
| 91 |
+
for i, struct in enumerate(data):
|
| 92 |
+
return_str += f"Structure {i+1}: {len(struct)} atoms\n"
|
| 93 |
+
return_str += struct.atom_array
|
| 94 |
+
return return_str
|
| 95 |
+
except Exception as e:
|
| 96 |
+
return f"Error: {str(e)}"
|
| 97 |
+
|
| 98 |
|
| 99 |
# Gradio UI
|
| 100 |
with gr.Blocks(title="RFD3 Test") as demo:
|
|
|
|
| 104 |
test_btn = gr.Button("Run RFD3 Test")
|
| 105 |
output = gr.Textbox(label="Test Result")
|
| 106 |
|
| 107 |
+
test_btn.click(test_rfd3_from_notebook, outputs=output)
|
| 108 |
|
| 109 |
if __name__ == "__main__":
|
| 110 |
demo.launch()
|