Spaces:
Running on Zero
Running on Zero
| from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine | |
| import gradio as gr | |
| from lightning.fabric import seed_everything | |
| import time | |
| import os | |
| import spaces | |
| import subprocess | |
| import gzip | |
| import gemmi | |
def test_rfd3_from_notebook():
    """Smoke-test RFD3 unconditional generation with in-memory output.

    Seeds the RNGs, generates one batch of two 40-residue designs without
    writing any files, and reports the atom count of each structure.

    Returns:
        A multi-line summary string ending in a newline on success, or
        ``"Error: <message>"`` if engine construction or generation raises.
    """
    # Set seed for reproducibility.
    seed_everything(0)

    config = RFD3InferenceConfig(
        specification={
            "length": 40,  # Generate 40-residue proteins
        },
        diffusion_batch_size=2,  # Generate 2 structures per batch
    )

    try:
        model = RFD3InferenceEngine(**config)
        outputs = model.run(
            inputs=None,  # None for unconditional generation
            out_dir=None,  # None to return in memory (no file output)
            n_batches=1,  # Generate 1 batch
        )
        lines = ["RFD3 test passed! Generated structures:"]
        for idx, data in outputs.items():
            lines.append(f"Batch {idx}: {len(data)} structure(s)")
            lines.extend(
                f"Structure {i}: {struct.atom_array.array_length()} Atoms"
                for i, struct in enumerate(data, start=1)
            )
        return "\n".join(lines) + "\n"
    except Exception as e:
        # Broad catch is deliberate: this UI-facing smoke test reports
        # failures as text rather than raising.
        return f"Error: {e}"
def unconditional_generation(
    num_batches,
    num_designs_per_batch,
    length,
    request: "gr.Request" = None,
):
    """Run unconditional RFD3 generation and write designs to a new directory.

    Args:
        num_batches: Number of diffusion batches to run. Coerced to ``int``
            because numeric UI components may deliver floats.
        num_designs_per_batch: Structures generated per batch (passed as
            ``diffusion_batch_size``). Also coerced to ``int``.
        length: Value placed under ``'length'`` in the RFD3 specification.
        request: The current Gradio request, supplied by Gradio when this
            function is used as an event handler (parameters annotated with
            ``gr.Request`` are injected). Optional, so direct Python callers
            keep working; its ``session_hash`` is used in the directory name.

    Returns:
        ``(directory, results)``: ``results`` holds one dict per expected
        design with keys ``batch``, ``design``, ``cif_path`` and ``pdb_path``.

    Raises:
        FileExistsError: If the output directory already exists.
        RuntimeError: If engine construction, generation, or PDB conversion
            fails; the original exception is chained as ``__cause__``.
    """
    import uuid  # local import keeps this fix self-contained

    num_batches = int(num_batches)
    num_designs_per_batch = int(num_designs_per_batch)

    config = RFD3InferenceConfig(
        specification={
            "length": length,
        },
        diffusion_batch_size=num_designs_per_batch,
    )

    # A hand-built `gr.Request()` carries no session (session_hash is None),
    # so read it from the injected request instead. The random suffix stops
    # runs started in the same second from colliding on the directory name,
    # which would make makedirs(exist_ok=False) raise.
    session_hash = getattr(request, "session_hash", None)
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    run_id = uuid.uuid4().hex[:8]
    directory = (
        "./outputs/unconditional_generation/"
        f"session_{session_hash}_{time_stamp}_{run_id}"
    )
    os.makedirs(directory, exist_ok=False)

    try:
        model = RFD3InferenceEngine(**config)
        model.run(
            inputs=None,  # None for unconditional generation
            out_dir=directory,  # write designs under this run's directory
            n_batches=num_batches,
        )
        # NOTE(review): assumes the engine names its outputs
        # "_{batch}_model_{design}.cif.gz" inside out_dir — confirm against
        # RFD3's writer. A missing file surfaces as a conversion error.
        results = []
        for batch in range(num_batches):
            for design in range(num_designs_per_batch):
                cif_path = os.path.join(
                    directory, f"_{batch}_model_{design}.cif.gz"
                )
                results.append(
                    {
                        "batch": batch,
                        "design": design,
                        "cif_path": cif_path,
                        "pdb_path": mcif_gz_to_pdb(cif_path),
                    }
                )
        print(results)
        return directory, results
    except Exception as e:
        raise RuntimeError(f"Error during generation: {e}") from e
def collect_outputs(gen_directory, num_batches, num_designs_per_batch):
    """Return a recursive ``ls -R`` listing of *gen_directory* as text.

    Args:
        gen_directory: Directory to list.
        num_batches: Unused; kept for interface compatibility.
        num_designs_per_batch: Unused; kept for interface compatibility.

    Returns:
        The decoded stdout of ``ls -R``, or ``"Error: <message>"`` if the
        command cannot run or exits non-zero (e.g. the directory is missing).
    """
    try:
        # argv list, no shell: the path is never parsed by a shell, so paths
        # with spaces work and metacharacters cannot inject commands. "--"
        # stops a path beginning with "-" from being read as an ls option.
        return subprocess.check_output(["ls", "-R", "--", gen_directory]).decode()
    except Exception as e:
        # UI-facing helper: report failures as text rather than raising.
        return f"Error: {e}"
def mcif_gz_to_pdb(file_path: str) -> str:
    """Convert a gzipped mmCIF file to PDB format, saved next to the input.

    The output path is the input path with a trailing ``.gz`` removed, then
    one more extension removed, then ``.pdb`` appended — e.g.
    ``dir/x.cif.gz`` -> ``dir/x.pdb`` and ``dir/x.mcif.gz`` -> ``dir/x.pdb``.
    Only the file name's suffix is rewritten; directory names are untouched.

    Parameters
    ----------
    file_path : str
        Path to the ``.cif.gz`` / ``.mcif.gz`` file.

    Returns
    -------
    str
        Path to the generated PDB file.

    Raises
    ------
    ValueError
        If the derived PDB path would be *file_path* itself (e.g. the input
        already ends in ``.pdb``), which would overwrite the input.
    """
    # Strip the suffix from the file name only. A plain str.replace would
    # leave ".mcif.gz" unchanged (overwriting the input) and could also
    # rewrite matching text in directory names.
    stem = file_path[: -len(".gz")] if file_path.endswith(".gz") else file_path
    pdb_path = os.path.splitext(stem)[0] + ".pdb"
    if os.path.abspath(pdb_path) == os.path.abspath(file_path):
        raise ValueError(f"Refusing to overwrite input file: {file_path}")

    st = gemmi.read_structure(file_path)
    st.setup_entities()  # populate entity info before writing the PDB
    st.write_minimal_pdb(pdb_path)
    return pdb_path