gabboud commited on
Commit
7f14dfe
·
1 Parent(s): 8649d5a

unconditional generation with input file pipeline

Browse files
Files changed (3) hide show
  1. app.py +5 -2
  2. utils/handle_files.py +11 -1
  3. utils/pipelines.py +64 -6
app.py CHANGED
@@ -9,7 +9,7 @@ from atomworks.io.utils.visualize import view
9
  from lightning.fabric import seed_everything
10
  from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
11
  from utils.download_weights import download_weights
12
- from utils.pipelines import test_rfd3_from_notebook, unconditional_generation
13
  #from gradio_molecule3d import Molecule3D
14
  from utils.handle_events import *
15
  from utils.handle_files import *
@@ -52,7 +52,7 @@ with gr.Blocks(title="RFD3 Test") as demo:
52
  gr.Markdown("Set up the configuration for your run through a valid yaml file or by manually setting minimal parameters for an unconditional run.")
53
  with gr.Tabs() as config_tabs:
54
  with gr.TabItem("Upload Config") as upload_tab: # upload a config yaml or json
55
- config_upload = gr.File(label="PDB + Config", file_types=[".pdb", ".yaml", ".json"])
56
  with gr.TabItem("Manual Config") as manual_tab: # minimal config for testing unconditional generation
57
  num_designs_per_batch = gr.Number(
58
  value=2,
@@ -123,6 +123,9 @@ with gr.Blocks(title="RFD3 Test") as demo:
123
  def generate(config_ready, scaffold_ready, num_batches, num_designs_per_batch, length):
124
  if config_ready is None or scaffold_ready is None:
125
  return None, None
 
 
 
126
  if config_ready=="manual" and scaffold_ready=="no_input":
127
  gen_directory, gen_results = unconditional_generation(num_batches, num_designs_per_batch, length)
128
  return gen_directory, gen_results
 
9
  from lightning.fabric import seed_everything
10
  from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
11
  from utils.download_weights import download_weights
12
+ from utils.pipelines import test_rfd3_from_notebook, unconditional_generation, unconditional_generation_with_input_config
13
  #from gradio_molecule3d import Molecule3D
14
  from utils.handle_events import *
15
  from utils.handle_files import *
 
52
  gr.Markdown("Set up the configuration for your run through a valid yaml file or by manually setting minimal parameters for an unconditional run.")
53
  with gr.Tabs() as config_tabs:
54
  with gr.TabItem("Upload Config") as upload_tab: # upload a config yaml or json
55
+ config_upload = gr.File(label="Config file: .yaml or .json", file_types=[".pdb", ".yaml", ".json"])
56
  with gr.TabItem("Manual Config") as manual_tab: # minimal config for testing unconditional generation
57
  num_designs_per_batch = gr.Number(
58
  value=2,
 
123
  def generate(config_ready, scaffold_ready, num_batches, num_designs_per_batch, length):
124
  if config_ready is None or scaffold_ready is None:
125
  return None, None
126
+ if config_ready is "upload" and scaffold_ready is "no_input":
127
+ gen_directory, gen_results = unconditional_generation_with_input_config(config_upload)
128
+ return gen_directory, gen_results
129
  if config_ready=="manual" and scaffold_ready=="no_input":
130
  gen_directory, gen_results = unconditional_generation(num_batches, num_designs_per_batch, length)
131
  return gen_directory, gen_results
utils/handle_files.py CHANGED
@@ -2,6 +2,7 @@ import gemmi
2
  import os
3
  import shutil
4
  import gradio as gr
 
5
 
6
  def mcif_gz_to_pdb(file_path: str) -> str:
7
  """
@@ -40,4 +41,13 @@ def download_results_as_zip(directory):
40
  return gr.update()
41
  zip_path = f"{directory}.zip"
42
  shutil.make_archive(directory, 'zip', directory)
43
- return gr.update(value=zip_path, visible=True)
 
 
 
 
 
 
 
 
 
 
2
  import os
3
  import shutil
4
  import gradio as gr
5
+ import subprocess
6
 
7
  def mcif_gz_to_pdb(file_path: str) -> str:
8
  """
 
41
  return gr.update()
42
  zip_path = f"{directory}.zip"
43
  shutil.make_archive(directory, 'zip', directory)
44
+ return gr.update(value=zip_path, visible=True)
45
+
46
+
47
+ def collect_outputs(gen_directory, num_batches, num_designs_per_batch):
48
+ try:
49
+ cmd = f"ls -R {gen_directory}"
50
+ file_list = subprocess.check_output(cmd, shell=True).decode()
51
+ return file_list
52
+ except Exception as e:
53
+ return f"Error: {str(e)}"
utils/pipelines.py CHANGED
@@ -46,6 +46,26 @@ def test_rfd3_from_notebook():
46
  # Initialize engine and run generation
47
  @spaces.GPU(duration=300)
48
  def unconditional_generation(num_batches, num_designs_per_batch, length):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  config = RFD3InferenceConfig(
51
  specification={
@@ -79,12 +99,50 @@ def unconditional_generation(num_batches, num_designs_per_batch, length):
79
  except Exception as e:
80
  raise RuntimeError(f"Error during generation: {str(e)}")
81
 
82
- def collect_outputs(gen_directory, num_batches, num_designs_per_batch):
83
- try:
84
- cmd = f"ls -R {gen_directory}"
85
- file_list = subprocess.check_output(cmd, shell=True).decode()
86
- return file_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  except Exception as e:
88
- return f"Error: {str(e)}"
89
 
90
 
 
46
  # Initialize engine and run generation
47
  @spaces.GPU(duration=300)
48
  def unconditional_generation(num_batches, num_designs_per_batch, length):
49
+ """
50
+ Runs an unconditional generation with the specified parameters for number of batches, number of designs per batch, and length of the generated proteins. Saves the generated structures to a timestamped directory in the outputs folder and returns the path to the directory along with a list of the generated structures' file paths.
51
+
52
+ Parameters:
53
+ ----------
54
+ num_batches: int or gr.Number,
55
+ The number of batches to generate.
56
+ num_designs_per_batch: int or gr.Number,
57
+ The number of designs to generate per batch.
58
+ length: int or gr.Number,
59
+ The length of the generated proteins.
60
+
61
+ Returns:
62
+ -------
63
+ directory: str,
64
+ The path to the directory where the generated structures are saved.
65
+ results: list of dicts,
66
+ A list of the generated structures' file paths, where each dict contains batch number "batch", design number "design", path to cif file "cif_path", and path to pdb file "pdb_path".
67
+ """
68
+
69
 
70
  config = RFD3InferenceConfig(
71
  specification={
 
99
  except Exception as e:
100
  raise RuntimeError(f"Error during generation: {str(e)}")
101
 
102
+ def unconditional_generation_with_input_config(input_file):
103
+ """
104
+ Runs an unconditional generation with the specified input config file. Saves the generated structures to a timestamped directory in the outputs folder and returns the path to the directory along with a list of the generated structures' file paths.
105
+
106
+ Parameters:
107
+ ----------
108
+ input_file: gr.File,
109
+ gr.File object containing the uploaded config file (yaml or json). input_file.name is the path to the uploaded file on the server.
110
+
111
+ Returns:
112
+ -------
113
+ directory: str,
114
+ The path to the directory where the generated structures are saved.
115
+ results: list of dicts,
116
+ A list of the generated structures' file paths, where each dict contains batch number "batch", design number "design", path to cif file "cif_path", and path to pdb file "pdb_path".
117
+
118
+ """
119
+
120
+ session_hash = gr.Request().session_hash
121
+ time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
122
+ directory = f"./outputs/unconditional_generation/session_{session_hash}_{time_stamp}"
123
+ os.makedirs(directory, exist_ok=False)
124
+
125
+ try:
126
+
127
+ command = f"rfd3 design inputs={input_file.name} out_dir={directory}"
128
+ print(f"Running command: {command}")
129
+ subprocess.run(command, shell=True, check=True)
130
+
131
+ results = []
132
+ for file_name in os.listdir(directory):
133
+ if file_name.endswith(".cif.gz"):
134
+ terms = file_name.split("_")
135
+ model_index = terms.index("model")
136
+ batch =model_index - 1
137
+ design = model_index + 1
138
+ cif_path = os.path.join(directory, file_name)
139
+ pdb_path = mcif_gz_to_pdb(cif_path)
140
+ results.append({"batch": batch, "design": design, "cif_path": cif_path, "pdb_path": pdb_path})
141
+
142
+ print(results)
143
+ return directory, results
144
+
145
  except Exception as e:
146
+ raise RuntimeError(f"Error during generation: {str(e)}")
147
 
148