rayli commited on
Commit
fd27d70
·
verified ·
1 Parent(s): 4f2214c

Show output loading state during inference

Browse files
Files changed (1) hide show
  1. app.py +42 -4
app.py CHANGED
@@ -4768,7 +4768,7 @@ def run_predict_on_gpu(
4768
  enforce_connectivity_per_part: bool,
4769
  joint_decoding_confidence_temperature: float,
4770
  ):
4771
- yield from _get_active_app().predict_segmentation_payload(
4772
  mesh_path_value,
4773
  mesh_hash_value,
4774
  kinematic_tree_json,
@@ -4781,16 +4781,41 @@ def run_predict_on_gpu(
4781
  strict_face_postprocess,
4782
  enforce_connectivity_per_part,
4783
  joint_decoding_confidence_temperature,
4784
- )
 
 
 
 
 
 
 
 
4785
 
4786
 
4787
  def postprocess_segmentation_on_cpu(payload: dict[str, Any] | None):
4788
- return _get_active_app().postprocess_segmentation_payload(payload)
 
 
 
 
 
 
 
 
 
 
4789
 
4790
 
4791
  @_spaces_gpu
4792
  def run_motion_on_gpu(payload: dict[str, Any] | None):
4793
- yield from _get_active_app().predict_motion_payload(payload)
 
 
 
 
 
 
 
4794
 
4795
 
4796
  def finish_predict_on_cpu(payload: dict[str, Any] | None):
@@ -4804,6 +4829,9 @@ def prepare_inference_ui():
4804
  gr.update(interactive=False),
4805
  gr.update(value=None, interactive=False),
4806
  "Running inference...",
 
 
 
4807
  )
4808
 
4809
 
@@ -5139,6 +5167,9 @@ def create_gradio_app(app: InstructParticulateApp) -> gr.Blocks:
5139
  export_urdf_button,
5140
  urdf_zip,
5141
  status,
 
 
 
5142
  ],
5143
  queue=False,
5144
  )
@@ -5162,6 +5193,9 @@ def create_gradio_app(app: InstructParticulateApp) -> gr.Blocks:
5162
  inference_payload,
5163
  status,
5164
  export_urdf_button,
 
 
 
5165
  ],
5166
  )
5167
  postprocess_event = gpu_event.then(
@@ -5172,6 +5206,8 @@ def create_gradio_app(app: InstructParticulateApp) -> gr.Blocks:
5172
  query_visualization,
5173
  status,
5174
  export_urdf_button,
 
 
5175
  ],
5176
  )
5177
  motion_event = postprocess_event.then(
@@ -5181,6 +5217,8 @@ def create_gradio_app(app: InstructParticulateApp) -> gr.Blocks:
5181
  inference_payload,
5182
  status,
5183
  export_urdf_button,
 
 
5184
  ],
5185
  )
5186
  motion_event.then(
 
4768
  enforce_connectivity_per_part: bool,
4769
  joint_decoding_confidence_temperature: float,
4770
  ):
4771
+ for payload, status, export_button in _get_active_app().predict_segmentation_payload(
4772
  mesh_path_value,
4773
  mesh_hash_value,
4774
  kinematic_tree_json,
 
4781
  strict_face_postprocess,
4782
  enforce_connectivity_per_part,
4783
  joint_decoding_confidence_temperature,
4784
+ ):
4785
+ yield (
4786
+ payload,
4787
+ status,
4788
+ export_button,
4789
+ gr.update(),
4790
+ gr.update(),
4791
+ gr.update(),
4792
+ )
4793
 
4794
 
4795
  def postprocess_segmentation_on_cpu(payload: dict[str, Any] | None):
4796
+ next_payload, query_visualization, status, export_button = (
4797
+ _get_active_app().postprocess_segmentation_payload(payload)
4798
+ )
4799
+ return (
4800
+ next_payload,
4801
+ query_visualization,
4802
+ status,
4803
+ export_button,
4804
+ gr.update(),
4805
+ gr.update(),
4806
+ )
4807
 
4808
 
4809
  @_spaces_gpu
4810
  def run_motion_on_gpu(payload: dict[str, Any] | None):
4811
+ for next_payload, status, export_button in _get_active_app().predict_motion_payload(payload):
4812
+ yield (
4813
+ next_payload,
4814
+ status,
4815
+ export_button,
4816
+ gr.update(),
4817
+ gr.update(),
4818
+ )
4819
 
4820
 
4821
  def finish_predict_on_cpu(payload: dict[str, Any] | None):
 
4829
  gr.update(interactive=False),
4830
  gr.update(value=None, interactive=False),
4831
  "Running inference...",
4832
+ gr.update(value=None),
4833
+ gr.update(value=None),
4834
+ gr.update(value=None),
4835
  )
4836
 
4837
 
 
5167
  export_urdf_button,
5168
  urdf_zip,
5169
  status,
5170
+ query_visualization,
5171
+ animated_model,
5172
+ prediction_model,
5173
  ],
5174
  queue=False,
5175
  )
 
5193
  inference_payload,
5194
  status,
5195
  export_urdf_button,
5196
+ query_visualization,
5197
+ animated_model,
5198
+ prediction_model,
5199
  ],
5200
  )
5201
  postprocess_event = gpu_event.then(
 
5206
  query_visualization,
5207
  status,
5208
  export_urdf_button,
5209
+ animated_model,
5210
+ prediction_model,
5211
  ],
5212
  )
5213
  motion_event = postprocess_event.then(
 
5217
  inference_payload,
5218
  status,
5219
  export_urdf_button,
5220
+ animated_model,
5221
+ prediction_model,
5222
  ],
5223
  )
5224
  motion_event.then(