from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
import gradio as gr
from lightning.fabric import seed_everything
import time
import os
import uuid
import spaces
import subprocess
import gzip
import gemmi


@spaces.GPU(duration=300)
def test_rfd3_from_notebook():
    """Smoke-test RFD3 with a small, seeded, in-memory unconditional run.

    Generates one batch of two 40-residue designs and returns a
    human-readable summary with the atom count of each structure.

    Returns
    -------
    str
        The summary text, or ``"Error: <message>"`` if building or running
        the engine raised. Errors are returned rather than raised so the
        string can be displayed directly in the UI.
    """
    # Fixed seed so repeated smoke tests are reproducible.
    seed_everything(0)

    config = RFD3InferenceConfig(
        specification={
            'length': 40,  # 40-residue designs
        },
        diffusion_batch_size=2,  # 2 designs per batch
    )

    try:
        model = RFD3InferenceEngine(**config)
        outputs = model.run(
            inputs=None,  # None -> unconditional generation
            out_dir=None,  # None -> keep results in memory, write no files
            n_batches=1,
        )
        lines = ["RFD3 test passed! Generated structures:"]
        for idx, data in outputs.items():
            lines.append(f"Batch {idx}: {len(data)} structure(s)")
            for i, struct in enumerate(data, start=1):
                lines.append(f"Structure {i}: {struct.atom_array.array_length()} Atoms")
        # Every line, including the last, is newline-terminated.
        return "\n".join(lines) + "\n"
    except Exception as e:
        return f"Error: {str(e)}"


@spaces.GPU(duration=300)
def unconditional_generation(num_batches, num_designs_per_batch, length, request: gr.Request = None):
    """Run unconditional RFD3 generation and convert every design to PDB.

    Parameters
    ----------
    num_batches:
        Number of diffusion batches to run. It is coerced to ``int`` because
        numeric UI components may deliver floats.
    num_designs_per_batch:
        Number of designs per batch (``diffusion_batch_size``). It is also
        coerced to ``int``.
    length:
        Value passed through as ``specification['length']``. An integral
        float is converted to ``int``; anything else is passed unchanged.
    request:
        Injected by Gradio when this function is an event handler. It is
        used only to name the output directory after the session. When it
        is absent (direct call, or no session hash), a random id is used.

    Returns
    -------
    tuple[str, list[dict]]
        The output directory and one dict per design. Each dict has the
        keys ``batch``, ``design``, ``cif_path`` and ``pdb_path``.

    Raises
    ------
    FileExistsError
        If the per-run output directory already exists.
    RuntimeError
        If generation or conversion fails. The original exception is chained.
    """
    num_batches = int(num_batches)
    num_designs_per_batch = int(num_designs_per_batch)
    if isinstance(length, float) and length.is_integer():
        length = int(length)

    config = RFD3InferenceConfig(
        specification={
            'length': length,
        },
        diffusion_batch_size=num_designs_per_batch,
    )

    # A hand-built gr.Request() carries no session (session_hash is None).
    # Only the request Gradio injects via the annotated parameter has one.
    session_id = getattr(request, "session_hash", None) or uuid.uuid4().hex
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    directory = f"./outputs/unconditional_generation/session_{session_id}_{time_stamp}"
    os.makedirs(directory, exist_ok=False)

    try:
        model = RFD3InferenceEngine(**config)
        model.run(
            inputs=None,  # None -> unconditional generation
            out_dir=directory,  # write the generated structures here
            n_batches=num_batches,
        )
        results = []
        for batch in range(num_batches):
            for design in range(num_designs_per_batch):
                # NOTE(review): assumes the engine names outputs
                # "_{batch}_model_{design}.cif.gz" when inputs=None -- confirm
                # against RFD3's output naming.
                file_name = os.path.join(directory, f"_{batch}_model_{design}.cif.gz")
                results.append({"batch": batch, "design": design, "cif_path": file_name, "pdb_path": mcif_gz_to_pdb(file_name)})
        print(results)
        return directory, results
    except Exception as e:
        raise RuntimeError(f"Error during generation: {str(e)}") from e


def collect_outputs(gen_directory, num_batches, num_designs_per_batch):
    """Return the ``ls -R`` listing of *gen_directory* as text.

    ``num_batches`` and ``num_designs_per_batch`` are accepted for interface
    compatibility and are not used.

    Returns ``"Error: <message>"`` instead of raising if the listing fails.
    """
    try:
        # Pass an argv list rather than a shell string, so a crafted directory
        # name cannot inject shell commands. "--" stops ls from parsing a name
        # that starts with "-" as an option.
        return subprocess.check_output(["ls", "-R", "--", gen_directory]).decode()
    except Exception as e:
        return f"Error: {str(e)}"


# Recognised mmCIF extensions (optionally gzipped), longest forms first.
_STRUCTURE_SUFFIXES = (".cif.gz", ".mcif.gz", ".mmcif.gz", ".cif", ".mcif", ".mmcif")


def _pdb_path_for(file_path):
    """Return *file_path* with its structure extension replaced by ``.pdb``."""
    path = os.fspath(file_path)
    lowered = path.lower()
    for suffix in _STRUCTURE_SUFFIXES:
        if lowered.endswith(suffix):
            return path[: -len(suffix)] + ".pdb"
    # Unknown extension: drop a trailing ".gz", then the last extension.
    stem = path[:-3] if lowered.endswith(".gz") else path
    return os.path.splitext(stem)[0] + ".pdb"


def mcif_gz_to_pdb(file_path: str) -> str:
    """
    Convert an mmCIF file (``.cif``/``.mcif``/``.mmcif``, optionally gzipped)
    to PDB and save it next to the input file.

    Only the file extension is replaced; directory names are left untouched.

    Parameters
    ----------
    file_path: str
        Path to the structure file, e.g. ``model.cif.gz``.

    Returns
    -------
    str
        Path to the generated PDB file.

    Raises
    ------
    ValueError
        If the derived PDB path would be the input path itself; the input
        file would otherwise be overwritten.
    """
    pdb_path = _pdb_path_for(file_path)
    if pdb_path == os.fspath(file_path):
        raise ValueError(f"Refusing to overwrite input file {file_path!r} with its PDB conversion")
    st = gemmi.read_structure(os.fspath(file_path))
    st.setup_entities()  # recommended before writing, for consistent entity handling
    st.write_minimal_pdb(pdb_path)
    return pdb_path