Spaces:
Runtime error
Runtime error
Add gr state
Browse files
app.py
CHANGED
|
@@ -19,18 +19,34 @@ def mkstemp(suffix, dir=None):
|
|
| 19 |
return Path(path)
|
| 20 |
|
| 21 |
|
| 22 |
-
|
| 23 |
-
#
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
-
def get_masked_img(img, w, h):
|
| 29 |
point_coords = [w, h]
|
| 30 |
point_labels = [1]
|
| 31 |
dilate_kernel_size = 15
|
| 32 |
|
| 33 |
-
model['sam'].
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
# masks, _, _ = predictor.predict(
|
| 35 |
masks, _, _ = model['sam'].predict(
|
| 36 |
point_coords=np.array([point_coords]),
|
|
@@ -98,6 +114,12 @@ model['lama'] = build_lama_model(lama_config, lama_ckpt, device=device)
|
|
| 98 |
|
| 99 |
|
| 100 |
with gr.Blocks() as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
with gr.Row():
|
| 102 |
img = gr.Image(label="Image")
|
| 103 |
# img_pointed = gr.Image(label='Pointed Image')
|
|
@@ -146,9 +168,11 @@ with gr.Blocks() as demo:
|
|
| 146 |
# []
|
| 147 |
# )
|
| 148 |
# img.change(get_sam_feat, [img], [])
|
|
|
|
|
|
|
| 149 |
sam_mask.click(
|
| 150 |
get_masked_img,
|
| 151 |
-
[img, w, h],
|
| 152 |
[img_with_mask_0, img_with_mask_1, img_with_mask_2, mask_0, mask_1, mask_2]
|
| 153 |
)
|
| 154 |
|
|
|
|
| 19 |
return Path(path)
|
| 20 |
|
| 21 |
|
| 22 |
+
def get_sam_feat(img):
|
| 23 |
+
# predictor.set_image(img)
|
| 24 |
+
model['sam'].set_image(img)
|
| 25 |
+
features = model['sam'].features
|
| 26 |
+
orig_h = model['sam'].orig_h
|
| 27 |
+
orig_w = model['sam'].orig_w
|
| 28 |
+
input_h = model['sam'].input_h
|
| 29 |
+
input_w = model['sam'].input_w
|
| 30 |
+
return features, orig_h, orig_w, input_h, input_w
|
| 31 |
|
| 32 |
|
| 33 |
+
def get_masked_img(img, w, h, features, orig_h, orig_w, input_h, input_w):
|
| 34 |
point_coords = [w, h]
|
| 35 |
point_labels = [1]
|
| 36 |
dilate_kernel_size = 15
|
| 37 |
|
| 38 |
+
# model['sam'].is_image_set = False
|
| 39 |
+
model['sam'].features = features
|
| 40 |
+
model['sam'].orig_h = orig_h
|
| 41 |
+
model['sam'].orig_w = orig_w
|
| 42 |
+
model['sam'].input_h = input_h
|
| 43 |
+
model['sam'].input_w = input_w
|
| 44 |
+
# model['sam'].image_embedding = image_embedding
|
| 45 |
+
# model['sam'].original_size = original_size
|
| 46 |
+
# model['sam'].input_size = input_size
|
| 47 |
+
# model['sam'].is_image_set = True
|
| 48 |
+
|
| 49 |
+
# model['sam'].set_image(img)
|
| 50 |
# masks, _, _ = predictor.predict(
|
| 51 |
masks, _, _ = model['sam'].predict(
|
| 52 |
point_coords=np.array([point_coords]),
|
|
|
|
| 114 |
|
| 115 |
|
| 116 |
with gr.Blocks() as demo:
|
| 117 |
+
features = gr.State(None)
|
| 118 |
+
orig_h = gr.State(None)
|
| 119 |
+
orig_w = gr.State(None)
|
| 120 |
+
input_h = gr.State(None)
|
| 121 |
+
input_w = gr.State(None)
|
| 122 |
+
|
| 123 |
with gr.Row():
|
| 124 |
img = gr.Image(label="Image")
|
| 125 |
# img_pointed = gr.Image(label='Pointed Image')
|
|
|
|
| 168 |
# []
|
| 169 |
# )
|
| 170 |
# img.change(get_sam_feat, [img], [])
|
| 171 |
+
img.upload(get_sam_feat, [img], [features, orig_h, orig_w, input_h, input_w])
|
| 172 |
+
|
| 173 |
sam_mask.click(
|
| 174 |
get_masked_img,
|
| 175 |
+
[img, w, h, features, orig_h, orig_w, input_h, input_w],
|
| 176 |
[img_with_mask_0, img_with_mask_1, img_with_mask_2, mask_0, mask_1, mask_2]
|
| 177 |
)
|
| 178 |
|