gabboud commited on
Commit
f8364c2
·
1 Parent(s): 74f99c5

call rfd3 with process.Popen to allow cancellation

Browse files
Files changed (2) hide show
  1. app.py +16 -9
  2. utils/pipelines.py +51 -7
app.py CHANGED
@@ -13,6 +13,7 @@ from utils.pipelines import *
13
  #from gradio_molecule3d import Molecule3D
14
  from utils.handle_events import *
15
  from utils.handle_files import *
 
16
 
17
  download_weights()
18
 
@@ -66,6 +67,7 @@ with gr.Blocks(title="RFD3 Test") as demo:
66
 
67
  gen_directory = gr.State(None) # the directory where generation results are saved, used to trigger the download of results as zip file
68
  gen_results = gr.State(None) # the results of the generation, which is a list of dicts where each dict contains batch number "batch", design number "design", path to cif file "cif_path", and path to pdb file "pdb_path".
 
69
 
70
  # inputs from user
71
  with gr.Row():
@@ -131,19 +133,13 @@ with gr.Blocks(title="RFD3 Test") as demo:
131
  display_state = gr.Textbox(label="Selected Batch and Design", visible=True)
132
  display_state.value = "Please Select a Batch and Design number to show sequence"
133
 
134
- def generate(config_upload, scaffold_upload, num_batches, num_designs_per_batch, extra_args, max_duration):
135
- if config_upload is None:
136
- return gr.update(), None, None
137
- else:
138
- textbox_update, gen_directory, gen_results = generation_with_input_config(config_upload, scaffold_upload, num_batches, num_designs_per_batch, extra_args, max_duration)
139
- print(textbox_update)
140
- return textbox_update, gen_directory, gen_results
141
 
142
  generation_event = run_btn.click(
143
  lambda: (gr.update(visible=False), gr.update(visible=True)),
144
  outputs=[run_btn, stop_btn]
145
  ).then(
146
- generate, inputs=[config_upload, scaffold_upload, num_batches, num_designs_per_batch, extra_args, max_duration], outputs=[runtextbox, gen_directory, gen_results]
147
  ).then(
148
  update_batch_choices,
149
  inputs=gen_results,
@@ -156,9 +152,19 @@ with gr.Blocks(title="RFD3 Test") as demo:
156
  lambda: (gr.update(visible=True), gr.update(visible=False)),
157
  outputs=[run_btn, stop_btn]
158
  )
 
 
 
 
 
 
 
 
 
159
 
160
  stop_btn.click(
161
- lambda: gr.update(value="Generation cancelled by user."),
 
162
  outputs=runtextbox,
163
  cancels=[generation_event]
164
  ).then(
@@ -183,5 +189,6 @@ with gr.Blocks(title="RFD3 Test") as demo:
183
  #visualize_btn.click(load_viewer, inputs=[batch_dropdown, design_dropdown, gen_results], outputs=viewer)
184
 
185
  if __name__ == "__main__":
 
186
  demo.queue()
187
  demo.launch()
 
13
  #from gradio_molecule3d import Molecule3D
14
  from utils.handle_events import *
15
  from utils.handle_files import *
16
+ import signal
17
 
18
  download_weights()
19
 
 
67
 
68
  gen_directory = gr.State(None) # the directory where generation results are saved, used to trigger the download of results as zip file
69
  gen_results = gr.State(None) # the results of the generation, which is a list of dicts where each dict contains batch number "batch", design number "design", path to cif file "cif_path", and path to pdb file "pdb_path".
70
+ process = gr.State(None) # to store the subprocess.Popen object for the generation process, used to terminate the process when the stop button is clicked.
71
 
72
  # inputs from user
73
  with gr.Row():
 
133
  display_state = gr.Textbox(label="Selected Batch and Design", visible=True)
134
  display_state.value = "Please Select a Batch and Design number to show sequence"
135
 
136
+
 
 
 
 
 
 
137
 
138
  generation_event = run_btn.click(
139
  lambda: (gr.update(visible=False), gr.update(visible=True)),
140
  outputs=[run_btn, stop_btn]
141
  ).then(
142
+ generation_with_input_config, inputs=[config_upload, scaffold_upload, num_batches, num_designs_per_batch, extra_args, max_duration], outputs=[runtextbox, gen_directory, gen_results, process]
143
  ).then(
144
  update_batch_choices,
145
  inputs=gen_results,
 
152
  lambda: (gr.update(visible=True), gr.update(visible=False)),
153
  outputs=[run_btn, stop_btn]
154
  )
155
+
156
+ def stop_generation(proc_state):
157
+ #if proc_state is not None and proc_state.poll() is None:
158
+ # try:
159
+ # # Kill whole process group so child processes die too
160
+ # os.killpg(os.getpgid(proc_state.pid), signal.SIGTERM)
161
+ # except Exception:
162
+ # proc_state.terminate()
163
+ return gr.update(value="Generation cancelled by user.")
164
 
165
  stop_btn.click(
166
+ stop_generation,
167
+ inputs=process,
168
  outputs=runtextbox,
169
  cancels=[generation_event]
170
  ).then(
 
189
  #visualize_btn.click(load_viewer, inputs=[batch_dropdown, design_dropdown, gen_results], outputs=viewer)
190
 
191
  if __name__ == "__main__":
192
+
193
  demo.queue()
194
  demo.launch()
utils/pipelines.py CHANGED
@@ -42,7 +42,7 @@ def generation_with_input_config(input_file, pdb_file, num_batches, num_designs_
42
 
43
 
44
  if input_file is None:
45
- return gr.update(value="Please ensure you have uploaded a configuration file: .yaml or .json"), gr.update(), gr.update()
46
  elif pdb_file is None:
47
  status_update = f"Running generation for {num_batches} batches of {num_designs_per_batch}\n job configuration uploaded from file {os.path.basename(input_file)}\n no scaffold/target provided"
48
  else:
@@ -53,6 +53,7 @@ def generation_with_input_config(input_file, pdb_file, num_batches, num_designs_
53
  directory = f"./outputs/generation_with_input_config/session_{session_hash}_{time_stamp}"
54
  os.makedirs(directory, exist_ok=False)
55
 
 
56
  try:
57
  if pdb_file is not None:
58
  # I need to do this because uploading files to a HF space stores each file in a separate temp directory so I need to copy them again to the same place.
@@ -70,11 +71,37 @@ def generation_with_input_config(input_file, pdb_file, num_batches, num_designs_
70
  print(f"Running command: {command}")
71
  status_update += f"\nRunning command: {command}."
72
  start = perf_counter()
73
- res = subprocess.run(command, shell=True, check=True, text=True, capture_output=True)
74
- print("Command took", perf_counter() - start, "seconds to run.")
75
- status_update += f"\nGeneration successful! Command took {perf_counter() - start:.2f} seconds to run."
76
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
 
78
  results = []
79
  for file_name in os.listdir(directory):
80
  if file_name.endswith(".cif.gz"):
@@ -88,11 +115,28 @@ def generation_with_input_config(input_file, pdb_file, num_batches, num_designs_
88
  results.append({"batch": batch, "design": design, "cif_path": cif_path, "pdb_path": pdb_path})
89
 
90
  print(results)
91
- return gr.update(value=status_update), directory, results
92
 
93
  except subprocess.CalledProcessError as e:
94
  print("subprocess threw an error", e.stderr)
95
- return gr.update(value=f"Generation failed:\n{e.stderr}"), None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
 
98
  #def generation_with_input_config_factory(max_duration):
 
42
 
43
 
44
  if input_file is None:
45
+ return gr.update(value="Please ensure you have uploaded a configuration file: .yaml or .json"), gr.update(), gr.update(), gr.update()
46
  elif pdb_file is None:
47
  status_update = f"Running generation for {num_batches} batches of {num_designs_per_batch}\n job configuration uploaded from file {os.path.basename(input_file)}\n no scaffold/target provided"
48
  else:
 
53
  directory = f"./outputs/generation_with_input_config/session_{session_hash}_{time_stamp}"
54
  os.makedirs(directory, exist_ok=False)
55
 
56
+ process = None
57
  try:
58
  if pdb_file is not None:
59
  # I need to do this because uploading files to a HF space stores each file in a separate temp directory so I need to copy them again to the same place.
 
71
  print(f"Running command: {command}")
72
  status_update += f"\nRunning command: {command}."
73
  start = perf_counter()
74
+
75
+ # Start the process with Popen to allow cancellation
76
+ process = subprocess.Popen(
77
+ command,
78
+ shell=True,
79
+ stdout=subprocess.PIPE,
80
+ stderr=subprocess.PIPE,
81
+ text=True,
82
+ preexec_fn=os.setsid, # For Unix-like systems, start a new process group
83
+ )
84
+
85
+ # Immediately yield to make process available in State
86
+ yield gr.update(value=status_update), gr.update(), gr.update(), process
87
+
88
+ # Poll the process - yield regularly to allow Gradio to stop calling the generator
89
+ while process.poll() is None:
90
+ time.sleep(2) # Wait before checking again
91
+ yield gr.update(value=status_update), gr.update(), gr.update(), process
92
+
93
+ # Get the output after process completes
94
+ stdout, stderr = process.communicate()
95
+
96
+ # Check if process succeeded
97
+ if process.returncode != 0:
98
+ raise subprocess.CalledProcessError(process.returncode, command, stdout, stderr)
99
+
100
+ elapsed_time = perf_counter() - start
101
+ print(f"Command took {elapsed_time:.2f} seconds to run.")
102
+ status_update += f"\nGeneration successful! Command took {elapsed_time:.2f} seconds to run."
103
 
104
+ # Collect results from output directory
105
  results = []
106
  for file_name in os.listdir(directory):
107
  if file_name.endswith(".cif.gz"):
 
115
  results.append({"batch": batch, "design": design, "cif_path": cif_path, "pdb_path": pdb_path})
116
 
117
  print(results)
118
+ return gr.update(value=status_update), directory, results, process
119
 
120
  except subprocess.CalledProcessError as e:
121
  print("subprocess threw an error", e.stderr)
122
+ return gr.update(value=f"Generation failed:\n{e.stderr}"), None, None, None
123
+
124
+ except Exception as e:
125
+ # Handle cancellation or other errors - terminate the subprocess if it's still running
126
+ if process and process.poll() is None:
127
+ print(f"Terminating subprocess due to: {e}")
128
+ try:
129
+ os.killpg(os.getpgid(process.pid), 15) # SIGTERM to process group
130
+ except Exception:
131
+ process.terminate()
132
+ try:
133
+ process.wait(timeout=5) # Wait up to 5 seconds for graceful termination
134
+ except subprocess.TimeoutExpired:
135
+ try:
136
+ os.killpg(os.getpgid(process.pid), 9) # SIGKILL to process group
137
+ except Exception:
138
+ process.kill() # Force kill if killpg fails
139
+ raise # Re-raise the exception so Gradio knows the event was cancelled
140
 
141
 
142
  #def generation_with_input_config_factory(max_duration):