ysdede commited on
Commit
69a1365
·
1 Parent(s): 315f517

Fix ZeroGPU pickle error: extract gr.SelectData coords before GPU call

Browse files
Files changed (1) hide show
  1. hugging_face/app.py +22 -8
hugging_face/app.py CHANGED
@@ -201,29 +201,35 @@ def get_end_number(track_pause_number_slider, video_state, interactive_state):
201
  return video_state["painted_images"][track_pause_number_slider],interactive_state
202
 
203
  # use sam to get the mask
 
 
 
204
  @spaces.GPU(duration=60)
205
- def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
206
  """
 
207
  Args:
208
- template_frame: PIL.Image
209
- point_prompt: flag for positive or negative button click
210
  click_state: [[points], [labels]]
 
 
211
  """
212
  if point_prompt == "Positive":
213
- coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
214
  interactive_state["positive_click_times"] += 1
215
  else:
216
- coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
217
  interactive_state["negative_click_times"] += 1
218
-
219
  # prompt for sam model
220
  ensure_sam_on_cuda()
221
  model.samcontroler.sam_controler.reset_image()
222
  model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
223
  prompt = get_prompt(click_state=click_state, click_input=coordinate)
224
 
225
- mask, logit, painted_image = model.first_frame_click(
226
- image=video_state["origin_images"][video_state["select_frame_number"]],
227
  points=np.array(prompt["input_point"]),
228
  labels=np.array(prompt["input_label"]),
229
  multimask=prompt["multimask_output"],
@@ -234,6 +240,14 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
234
 
235
  return painted_image, video_state, interactive_state
236
 
 
 
 
 
 
 
 
 
237
  def add_multi_mask(video_state, interactive_state, mask_dropdown):
238
  mask = video_state["masks"][video_state["select_frame_number"]]
239
  interactive_state["multi_mask"]["masks"].append(mask)
 
201
  return video_state["painted_images"][track_pause_number_slider],interactive_state
202
 
203
  # use sam to get the mask
204
+ # ZeroGPU: gr.SelectData cannot be pickled (contains lambdas from Gradio's State.__init__).
205
+ # We split into an outer wrapper that extracts plain data from the event,
206
+ # and an inner @spaces.GPU function that receives only picklable arguments.
207
  @spaces.GPU(duration=60)
208
+ def _sam_refine_gpu(video_state, point_prompt, click_state, interactive_state, click_x, click_y):
209
  """
210
+ Inner GPU function for SAM refinement.
211
  Args:
212
+ video_state: dict with video/image data
213
+ point_prompt: "Positive" or "Negative"
214
  click_state: [[points], [labels]]
215
+ interactive_state: dict with interaction state
216
+ click_x, click_y: integer pixel coordinates extracted from gr.SelectData
217
  """
218
  if point_prompt == "Positive":
219
+ coordinate = "[[{},{},1]]".format(click_x, click_y)
220
  interactive_state["positive_click_times"] += 1
221
  else:
222
+ coordinate = "[[{},{},0]]".format(click_x, click_y)
223
  interactive_state["negative_click_times"] += 1
224
+
225
  # prompt for sam model
226
  ensure_sam_on_cuda()
227
  model.samcontroler.sam_controler.reset_image()
228
  model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
229
  prompt = get_prompt(click_state=click_state, click_input=coordinate)
230
 
231
+ mask, logit, painted_image = model.first_frame_click(
232
+ image=video_state["origin_images"][video_state["select_frame_number"]],
233
  points=np.array(prompt["input_point"]),
234
  labels=np.array(prompt["input_label"]),
235
  multimask=prompt["multimask_output"],
 
240
 
241
  return painted_image, video_state, interactive_state
242
 
243
+ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt: gr.SelectData):
244
+ """
245
+ Outer wrapper: extracts plain picklable coordinates from gr.SelectData,
246
+ then delegates to the @spaces.GPU inner function.
247
+ """
248
+ click_x, click_y = int(evt.index[0]), int(evt.index[1])
249
+ return _sam_refine_gpu(video_state, point_prompt, click_state, interactive_state, click_x, click_y)
250
+
251
  def add_multi_mask(video_state, interactive_state, mask_dropdown):
252
  mask = video_state["masks"][video_state["select_frame_number"]]
253
  interactive_state["multi_mask"]["masks"].append(mask)