Fix ZeroGPU pickle error: extract gr.SelectData coords before GPU call
Browse files- hugging_face/app.py +22 -8
hugging_face/app.py
CHANGED
|
@@ -201,29 +201,35 @@ def get_end_number(track_pause_number_slider, video_state, interactive_state):
|
|
| 201 |
return video_state["painted_images"][track_pause_number_slider],interactive_state
|
| 202 |
|
| 203 |
# use sam to get the mask
|
|
|
|
|
|
|
|
|
|
| 204 |
@spaces.GPU(duration=60)
|
| 205 |
-
def
|
| 206 |
"""
|
|
|
|
| 207 |
Args:
|
| 208 |
-
|
| 209 |
-
point_prompt:
|
| 210 |
click_state: [[points], [labels]]
|
|
|
|
|
|
|
| 211 |
"""
|
| 212 |
if point_prompt == "Positive":
|
| 213 |
-
coordinate = "[[{},{},1]]".format(
|
| 214 |
interactive_state["positive_click_times"] += 1
|
| 215 |
else:
|
| 216 |
-
coordinate = "[[{},{},0]]".format(
|
| 217 |
interactive_state["negative_click_times"] += 1
|
| 218 |
-
|
| 219 |
# prompt for sam model
|
| 220 |
ensure_sam_on_cuda()
|
| 221 |
model.samcontroler.sam_controler.reset_image()
|
| 222 |
model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
|
| 223 |
prompt = get_prompt(click_state=click_state, click_input=coordinate)
|
| 224 |
|
| 225 |
-
mask, logit, painted_image = model.first_frame_click(
|
| 226 |
-
image=video_state["origin_images"][video_state["select_frame_number"]],
|
| 227 |
points=np.array(prompt["input_point"]),
|
| 228 |
labels=np.array(prompt["input_label"]),
|
| 229 |
multimask=prompt["multimask_output"],
|
|
@@ -234,6 +240,14 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
|
|
| 234 |
|
| 235 |
return painted_image, video_state, interactive_state
|
| 236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
def add_multi_mask(video_state, interactive_state, mask_dropdown):
|
| 238 |
mask = video_state["masks"][video_state["select_frame_number"]]
|
| 239 |
interactive_state["multi_mask"]["masks"].append(mask)
|
|
|
|
| 201 |
return video_state["painted_images"][track_pause_number_slider],interactive_state
|
| 202 |
|
| 203 |
# use sam to get the mask
|
| 204 |
+
# ZeroGPU: gr.SelectData cannot be pickled (contains lambdas from Gradio's State.__init__).
|
| 205 |
+
# We split into an outer wrapper that extracts plain data from the event,
|
| 206 |
+
# and an inner @spaces.GPU function that receives only picklable arguments.
|
| 207 |
@spaces.GPU(duration=60)
|
| 208 |
+
def _sam_refine_gpu(video_state, point_prompt, click_state, interactive_state, click_x, click_y):
|
| 209 |
"""
|
| 210 |
+
Inner GPU function for SAM refinement.
|
| 211 |
Args:
|
| 212 |
+
video_state: dict with video/image data
|
| 213 |
+
point_prompt: "Positive" or "Negative"
|
| 214 |
click_state: [[points], [labels]]
|
| 215 |
+
interactive_state: dict with interaction state
|
| 216 |
+
click_x, click_y: integer pixel coordinates extracted from gr.SelectData
|
| 217 |
"""
|
| 218 |
if point_prompt == "Positive":
|
| 219 |
+
coordinate = "[[{},{},1]]".format(click_x, click_y)
|
| 220 |
interactive_state["positive_click_times"] += 1
|
| 221 |
else:
|
| 222 |
+
coordinate = "[[{},{},0]]".format(click_x, click_y)
|
| 223 |
interactive_state["negative_click_times"] += 1
|
| 224 |
+
|
| 225 |
# prompt for sam model
|
| 226 |
ensure_sam_on_cuda()
|
| 227 |
model.samcontroler.sam_controler.reset_image()
|
| 228 |
model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
|
| 229 |
prompt = get_prompt(click_state=click_state, click_input=coordinate)
|
| 230 |
|
| 231 |
+
mask, logit, painted_image = model.first_frame_click(
|
| 232 |
+
image=video_state["origin_images"][video_state["select_frame_number"]],
|
| 233 |
points=np.array(prompt["input_point"]),
|
| 234 |
labels=np.array(prompt["input_label"]),
|
| 235 |
multimask=prompt["multimask_output"],
|
|
|
|
| 240 |
|
| 241 |
return painted_image, video_state, interactive_state
|
| 242 |
|
| 243 |
+
def sam_refine(video_state, point_prompt, click_state, interactive_state, evt: gr.SelectData):
|
| 244 |
+
"""
|
| 245 |
+
Outer wrapper: extracts plain picklable coordinates from gr.SelectData,
|
| 246 |
+
then delegates to the @spaces.GPU inner function.
|
| 247 |
+
"""
|
| 248 |
+
click_x, click_y = int(evt.index[0]), int(evt.index[1])
|
| 249 |
+
return _sam_refine_gpu(video_state, point_prompt, click_state, interactive_state, click_x, click_y)
|
| 250 |
+
|
| 251 |
def add_multi_mask(video_state, interactive_state, mask_dropdown):
|
| 252 |
mask = video_state["masks"][video_state["select_frame_number"]]
|
| 253 |
interactive_state["multi_mask"]["masks"].append(mask)
|