gabboud commited on
Commit
aced7ae
·
1 Parent(s): 5b8fa42

parse rf3d output folder

Browse files
Files changed (2) hide show
  1. app.py +53 -2
  2. utils/pipelines.py +37 -6
app.py CHANGED
@@ -10,6 +10,7 @@ from lightning.fabric import seed_everything
10
  from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
11
  from utils import download_weights
12
  from utils.pipelines import test_rfd3_from_notebook, unconditional_generation
 
13
 
14
 
15
  download_weights()
@@ -50,11 +51,61 @@ with gr.Blocks(title="RFD3 Test") as demo:
50
  maximum=200
51
  )
52
 
 
 
53
  gen_btn = gr.Button("Run Unconditional Generation")
54
- gen_output = gr.Textbox(label="Generation Result")
55
- gen_btn.click(unconditional_generation, inputs=[num_batches, num_designs_per_batch, length], outputs=gen_output)
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  if __name__ == "__main__":
60
  demo.launch()
 
10
  from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
11
  from utils import download_weights
12
  from utils.pipelines import test_rfd3_from_notebook, unconditional_generation
13
+ from gradio_molecule3d import Molecule3D
14
 
15
 
16
  download_weights()
 
51
  maximum=200
52
  )
53
 
54
+ gen_directory = gr.State(None)
55
+ gen_results = gr.State(None)
56
  gen_btn = gr.Button("Run Unconditional Generation")
57
+ #gen_output = gr.Textbox(label="Generation Result")
58
+ gen_btn.click(unconditional_generation, inputs=[num_batches, num_designs_per_batch, length], outputs=[gen_directory, gen_results])
59
 
60
 
61
+ # New visualize section
62
+ #with gr.Row():
63
+ # viz_btn = gr.Button("Visualize", visible=True)
64
+ # batch_dropdown = gr.Dropdown(
65
+ # choices=[],
66
+ # label="Select Batch",
67
+ # visible=False
68
+ # )
69
+ # design_dropdown = gr.Dropdown(
70
+ # choices=[],
71
+ # label="Select Design",
72
+ # visible=False
73
+ # )
74
+ # viewer = Molecule3D(visible=False)
75
+ #
76
+ #def toggle_visualize(result):
77
+ # if result is None:
78
+ # return gr.Dropdown(visible=False), gr.Dropdown(visible=False), Molecule3D(visible=False)
79
+ # batches = sorted(list({d["batch"] for d in result}))
80
+ # return (
81
+ # gr.update(choices=batches, visible=True), # Batch dropdown
82
+ # gr.update(choices=[], visible=True), # Design empty initially
83
+ # gr.update(visible=False)
84
+ # )
85
+ #
86
+ #def update_designs(batch, result):
87
+ # if batch is None:
88
+ # return gr.update(choices=[])
89
+ # designs = sorted(list({d["design"] for d in result if d["batch"] == batch}))
90
+ # return gr.update(choices=designs)
91
+ #
92
+ #def load_viewer(batch, design, result):
93
+ # if batch is None or design is None:
94
+ # return gr.update(visible=False)
95
+ # pdb_data = next(d["pdb"] for d in result if d["batch"] == int(batch) and d["design"] == int(design))
96
+ # return gr.update(value=pdb_data, visible=True, reps=[{"style": "cartoon"}]) # Customize style
97
+ #
98
+ ## Events
99
+ #viz_btn.click(toggle_visualize, inputs=gen_directory, outputs=[batch_dropdown, design_dropdown, viewer])
100
+ #batch_dropdown.change(update_designs, inputs=[batch_dropdown, gen_directory], outputs=design_dropdown)
101
+ #batch_dropdown.select(fn=update_designs, inputs=[batch_dropdown, gen_directory], outputs=design_dropdown) # For selection
102
+ #gr.Dropdown.select(update_designs, batch_dropdown, design_dropdown).then( # Chain
103
+ # lambda b, d, r: load_viewer(b, d, r),
104
+ # inputs=[batch_dropdown, design_dropdown, gen_directory],
105
+ # outputs=viewer
106
+ #)
107
+
108
+
109
 
110
  if __name__ == "__main__":
111
  demo.launch()
utils/pipelines.py CHANGED
@@ -5,6 +5,9 @@ import time
5
  import os
6
  import spaces
7
  import subprocess
 
 
 
8
 
9
 
10
  @spaces.GPU(duration=300)
@@ -54,10 +57,9 @@ def unconditional_generation(num_batches, num_designs_per_batch, length):
54
 
55
  session_hash = gr.Request().session_hash
56
  time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
57
- directory = f"./outputs/session_{session_hash}_{time_stamp}"
58
  os.makedirs(directory, exist_ok=False)
59
 
60
-
61
  try:
62
  model = RFD3InferenceEngine(**config)
63
  outputs = model.run(
@@ -66,10 +68,39 @@ def unconditional_generation(num_batches, num_designs_per_batch, length):
66
  n_batches=num_batches, # Generate 1 batch
67
  )
68
 
69
- cmd = f"ls -R {directory}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  file_list = subprocess.check_output(cmd, shell=True).decode()
71
-
72
  return file_list
73
-
74
  except Exception as e:
75
- return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import os
6
  import spaces
7
  import subprocess
8
+ from Bio.PDB import MMCIFParser, PDBIO
9
+ import gzip
10
+
11
 
12
 
13
  @spaces.GPU(duration=300)
 
57
 
58
  session_hash = gr.Request().session_hash
59
  time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
60
+ directory = f"./outputs/unconditional_generation/session_{session_hash}_{time_stamp}"
61
  os.makedirs(directory, exist_ok=False)
62
 
 
63
  try:
64
  model = RFD3InferenceEngine(**config)
65
  outputs = model.run(
 
68
  n_batches=num_batches, # Generate 1 batch
69
  )
70
 
71
+ results = {}
72
+ for batch in range(num_batches):
73
+ for design in range(num_designs_per_batch):
74
+ file_name = os.path.joint(directory, f"_{batch}_{design}.cif.gz")
75
+ results.append({"batch": batch, "design": design, "file": file_name, "pdb": cif_gz_to_pdb(file_name)})
76
+
77
+ print(results)
78
+ return directory, results
79
+
80
+ except Exception as e:
81
+ return f"Error: {str(e)}"
82
+
83
+ def collect_outputs(gen_directory, num_batches, num_designs_per_batch):
84
+ try:
85
+ cmd = f"ls -R {gen_directory}"
86
  file_list = subprocess.check_output(cmd, shell=True).decode()
 
87
  return file_list
 
88
  except Exception as e:
89
+ return f"Error: {str(e)}"
90
+
91
+
92
+ def cif_gz_to_pdb(cif_gz_path):
93
+ """Convert .cif.gz to PDB string for viewer."""
94
+ # Decompress & parse
95
+ parser = MMCIFParser(QUIET=True)
96
+ with gzip.open(cif_gz_path, 'rt') as f:
97
+ struct = parser.get_structure('model', f)
98
+
99
+ # Write to string
100
+ io = PDBIO()
101
+ io.set_structure(struct)
102
+ pdb_lines = []
103
+ class StringIO:
104
+ def write(self, s): pdb_lines.append(s)
105
+ io.save(StringIO())
106
+ return ''.join(pdb_lines)