Spaces:
Running on Zero
Running on Zero
File size: 3,541 Bytes
94ff1b9 c1e206f 5b8fa42 aced7ae c49e7b8 aced7ae 94ff1b9 26f1fe3 94ff1b9 aced7ae 94ff1b9 464a533 aced7ae 8705e46 46eb3c3 aced7ae 4f422de aced7ae 5b8fa42 94ff1b9 aced7ae 46eb3c3 c49e7b8 46eb3c3 c49e7b8 46eb3c3 c49e7b8 46eb3c3 c49e7b8 5d306a8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 | from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
import gradio as gr
from lightning.fabric import seed_everything
import time
import os
import spaces
import subprocess
import gzip
import gemmi
@spaces.GPU(duration=300)
def test_rfd3_from_notebook():
    """Smoke-test RFD3 unconditional generation.

    Seeds all RNGs with 0, builds an unconditional-generation config
    (length 40, 2 designs per diffusion batch), runs a single batch with
    results kept in memory, and summarizes what was produced.

    Returns
    -------
    str
        A report listing, per batch, the number of structures and the atom
        count of each structure; or ``"Error: <message>"`` if engine
        construction or generation raised.
    """
    # Fixed seed so repeated smoke tests are comparable.
    seed_everything(0)
    config = RFD3InferenceConfig(
        specification={
            'length': 40,  # 40-residue proteins
        },
        diffusion_batch_size=2,  # 2 structures per diffusion batch
    )
    try:
        model = RFD3InferenceEngine(**config)
        outputs = model.run(
            inputs=None,  # None -> unconditional generation
            out_dir=None,  # None -> results returned in memory, no file output
            n_batches=1,  # a single batch
        )
        lines = ["RFD3 test passed! Generated structures:"]
        for idx, data in outputs.items():
            lines.append(f"Batch {idx}: {len(data)} structure(s)")
            for i, struct in enumerate(data, start=1):
                lines.append(f"Structure {i}: {struct.atom_array.array_length()} Atoms")
        return "\n".join(lines) + "\n"
    except Exception as e:
        # Deliberate UI boundary: report failures as text instead of raising.
        return f"Error: {str(e)}"
# Gradio handler: unconditional RFD3 generation into a per-session directory.
@spaces.GPU(duration=300)
def unconditional_generation(num_batches, num_designs_per_batch, length, request: gr.Request = None):
    """Run unconditional RFD3 generation and convert every design to PDB.

    Parameters
    ----------
    num_batches : int
        Number of batches passed to ``RFD3InferenceEngine.run``. Numeric UI
        values (e.g. floats) are coerced with ``int()``.
    num_designs_per_batch : int
        Designs per batch (``diffusion_batch_size``); coerced with ``int()``.
    length : int
        Target length placed in the generation ``specification``.
    request : gr.Request, optional
        Injected by Gradio when this function is an event handler; its
        ``session_hash`` names the output directory. When absent (or it has
        no session hash), a random UUID is used instead so concurrent calls
        do not collide on the same directory.

    Returns
    -------
    tuple[str, list[dict]]
        The output directory, and one dict per design with keys ``batch``,
        ``design``, ``cif_path`` and ``pdb_path``.

    Raises
    ------
    FileExistsError
        If the output directory already exists.
    RuntimeError
        If engine construction, generation, or PDB conversion fails; the
        original exception is chained as ``__cause__``.
    """
    import uuid  # local import keeps the file's top-level import block untouched

    # range() below rejects floats, which numeric UI components may deliver.
    num_batches = int(num_batches)
    num_designs_per_batch = int(num_designs_per_batch)
    config = RFD3InferenceConfig(
        specification={
            'length': length,
        },
        diffusion_batch_size=num_designs_per_batch,  # designs generated per batch
    )
    # A bare gr.Request() carries no session data (session_hash is None), so
    # use the request Gradio injects and fall back to a unique token.
    session_hash = getattr(request, "session_hash", None) or uuid.uuid4().hex
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    directory = f"./outputs/unconditional_generation/session_{session_hash}_{time_stamp}"
    os.makedirs(directory, exist_ok=False)
    try:
        model = RFD3InferenceEngine(**config)
        model.run(
            inputs=None,  # None -> unconditional generation
            out_dir=directory,  # designs are written into this directory
            n_batches=num_batches,
        )
        results = []
        for batch in range(num_batches):
            for design in range(num_designs_per_batch):
                # NOTE(review): assumes the engine names its outputs
                # "_{batch}_model_{design}.cif.gz" — confirm against RFD3's writer.
                file_name = os.path.join(directory, f"_{batch}_model_{design}.cif.gz")
                results.append({"batch": batch, "design": design, "cif_path": file_name, "pdb_path": mcif_gz_to_pdb(file_name)})
        print(results)
        return directory, results
    except Exception as e:
        raise RuntimeError(f"Error during generation: {str(e)}") from e
def collect_outputs(gen_directory, num_batches, num_designs_per_batch):
    """Return a recursive ``ls -R`` listing of a generation output directory.

    Parameters
    ----------
    gen_directory : str
        Directory to list.
    num_batches, num_designs_per_batch :
        Currently unused by this function.

    Returns
    -------
    str
        The decoded stdout of ``ls -R``, or ``"Error: <message>"`` if listing
        fails (e.g. the directory does not exist).
    """
    try:
        # Pass argv as a list (no shell) so a crafted or space-containing
        # directory name cannot inject commands or split into several
        # arguments; "--" stops it from being parsed as an ls option.
        file_list = subprocess.check_output(["ls", "-R", "--", str(gen_directory)]).decode()
        return file_list
    except Exception as e:
        return f"Error: {str(e)}"
def mcif_gz_to_pdb(file_path: str) -> str:
    """
    Converts an mmCIF file (e.g. ``.cif.gz``) to PDB and saves it to the same
    directory. Returns the path to the PDB file.

    The output path is ``file_path`` with its trailing ``.mmcif.gz``,
    ``.cif.gz``, ``.mmcif`` or ``.cif`` extension replaced by ``.pdb``. Any
    other path gets ``.pdb`` appended, so the output never overwrites the
    input file.

    Parameters
    ----------
    file_path : str
        Path to the input file, e.g. ``.../_0_model_0.cif.gz``.

    Returns
    -------
    str
        Path to the generated PDB file.
    """
    st = gemmi.read_structure(file_path)
    st.setup_entities()  # Recommended for consistent entity handling
    # Replace only the trailing extension (not matches elsewhere in the path),
    # and never reuse the input path as the output path.
    pdb_path = file_path + ".pdb"
    for suffix in (".mmcif.gz", ".cif.gz", ".mmcif", ".cif"):
        if file_path.endswith(suffix):
            pdb_path = file_path[: -len(suffix)] + ".pdb"
            break
    st.write_minimal_pdb(pdb_path)
    return pdb_path
|