# RFdiffusion3 / utils/pipelines.py
# Last change by gabboud, commit 46eb3c3: "rename results fields, modularize event handling"
# (file size: 3.54 kB)
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
import gradio as gr
from lightning.fabric import seed_everything
import time
import os
import spaces
import subprocess
import gzip
import gemmi
@spaces.GPU(duration=300)
def test_rfd3_from_notebook():
    """Smoke-test RFD3 unconditional generation entirely in memory.

    Seeds the RNGs via ``seed_everything(0)``, then generates one batch of two
    40-residue designs without writing any files, and reports the atom count
    of every generated structure.

    Returns:
        str: A human-readable report of the generated structures, or
        ``"Error: <message>"`` if building or running the engine fails.
        (Errors while building the config or seeding are not caught.)
    """
    # Fixed seed so repeated smoke tests are reproducible.
    seed_everything(0)

    config = RFD3InferenceConfig(
        specification={
            "length": 40,  # residues per generated design
        },
        diffusion_batch_size=2,  # designs generated per batch
    )

    try:
        model = RFD3InferenceEngine(**config)
        outputs = model.run(
            inputs=None,  # None -> unconditional generation
            out_dir=None,  # None -> keep results in memory, write no files
            n_batches=1,
        )
        lines = ["RFD3 test passed! Generated structures:"]
        for idx, designs in outputs.items():
            lines.append(f"Batch {idx}: {len(designs)} structure(s)")
            for i, design in enumerate(designs, start=1):
                lines.append(f"Structure {i}: {design.atom_array.array_length()} Atoms")
        # Trailing newline keeps the original report format.
        return "\n".join(lines) + "\n"
    except Exception as e:
        # Best-effort UI endpoint: surface the failure as text instead of raising.
        return f"Error: {e}"
@spaces.GPU(duration=300)
def unconditional_generation(num_batches, num_designs_per_batch, length, request: gr.Request = None):
    """Run RFD3 unconditional generation and write the designs to disk.

    Designs are written under
    ``./outputs/unconditional_generation/session_<session>_<timestamp>``, and
    each expected ``.cif.gz`` output is additionally converted to PDB.

    Parameters
    ----------
    num_batches : int
        Number of diffusion batches to run. Coerced with ``int()`` because UI
        number widgets may deliver floats, which ``range()`` rejects.
    num_designs_per_batch : int
        Designs generated per batch (``diffusion_batch_size``). Also coerced
        with ``int()``.
    length : int
        Passed unchanged as the specification ``length`` of every design.
    request : gr.Request, optional
        Injected by Gradio when this function is an event handler. Its
        ``session_hash`` names the output directory. When it is absent (or has
        no session hash), a random hex id is used so concurrent runs started
        in the same second do not collide.

    Returns
    -------
    tuple[str, list[dict]]
        The output directory and one record per design with the keys
        ``"batch"``, ``"design"``, ``"cif_path"`` and ``"pdb_path"``.

    Raises
    ------
    FileExistsError
        If the output directory already exists.
    RuntimeError
        If generation or PDB conversion fails (original exception chained).
    """
    import uuid  # only needed for the session-id fallback

    num_batches = int(num_batches)
    num_designs_per_batch = int(num_designs_per_batch)

    config = RFD3InferenceConfig(
        specification={
            "length": length,
        },
        diffusion_batch_size=num_designs_per_batch,
    )

    # A freshly constructed gr.Request() carries no session, so the session
    # hash must come from the request Gradio injects into the handler.
    session_hash = getattr(request, "session_hash", None) or uuid.uuid4().hex
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    directory = f"./outputs/unconditional_generation/session_{session_hash}_{time_stamp}"
    # exist_ok=False: every run must get its own fresh directory.
    os.makedirs(directory, exist_ok=False)

    try:
        model = RFD3InferenceEngine(**config)
        model.run(
            inputs=None,  # None -> unconditional generation
            out_dir=directory,  # write the generated structures under `directory`
            n_batches=num_batches,
        )
        results = []
        for batch in range(num_batches):
            for design in range(num_designs_per_batch):
                # NOTE(review): assumes the engine names its outputs
                # "_{batch}_model_{design}.cif.gz" inside out_dir — confirm against RFD3.
                cif_path = os.path.join(directory, f"_{batch}_model_{design}.cif.gz")
                results.append(
                    {
                        "batch": batch,
                        "design": design,
                        "cif_path": cif_path,
                        "pdb_path": mcif_gz_to_pdb(cif_path),
                    }
                )
        print(results)
        return directory, results
    except Exception as e:
        raise RuntimeError(f"Error during generation: {e}") from e
def collect_outputs(gen_directory, num_batches, num_designs_per_batch):
    """Return a recursive ``ls -R`` listing of a generation output directory.

    Parameters
    ----------
    gen_directory : str
        Directory to list.
    num_batches, num_designs_per_batch
        Currently unused; kept so the signature stays compatible with callers.

    Returns
    -------
    str
        The decoded listing, or ``"Error: <message>"`` if listing fails
        (e.g. the directory does not exist).
    """
    try:
        # argv list, no shell: spaces or shell metacharacters in the path can
        # neither split it nor inject commands. "--" stops option parsing so a
        # path starting with "-" is not read as an ls flag.
        return subprocess.check_output(["ls", "-R", "--", gen_directory]).decode()
    except Exception as e:
        # Best-effort UI helper: report the failure as text instead of raising.
        return f"Error: {e}"
def mcif_gz_to_pdb(file_path: str) -> str:
    """
    Convert a gzipped mmCIF (``.cif.gz``) file to PDB and save it next to the input.

    The output path is the input path with its trailing ``.cif.gz`` replaced by
    ``.pdb``. For any other extension, only the last extension is replaced by
    ``.pdb``. The PDB is written with gemmi's ``write_minimal_pdb``.

    Parameters
    ----------
    file_path: str
        Path to the input structure file (normally ``*.cif.gz``).

    Returns
    -------
    str
        Path to the generated PDB file.

    Raises
    ------
    ValueError
        If the derived PDB path equals the input path (e.g. the input already
        ends in ``.pdb``), since writing would overwrite the input file.
    """
    file_path = os.fspath(file_path)
    suffix = ".cif.gz"
    if file_path.endswith(suffix):
        # Strip the extension only at the end; str.replace would also rewrite
        # a ".cif.gz" that happens to appear in a directory name.
        pdb_path = file_path[: -len(suffix)] + ".pdb"
    else:
        pdb_path = os.path.splitext(file_path)[0] + ".pdb"
    if os.path.abspath(pdb_path) == os.path.abspath(file_path):
        raise ValueError(f"Refusing to overwrite input file with its PDB conversion: {file_path}")

    st = gemmi.read_structure(file_path)
    # Set up entity information on the structure before writing the PDB.
    st.setup_entities()
    st.write_minimal_pdb(pdb_path)
    return pdb_path