Commit 640f5b4 · s194649 committed
1 Parent(s): ae1d3bb

encoder decoder setup

Files changed:
- app.py (+9 -7)
- inference.py (+70 -6)
app.py CHANGED

```diff
@@ -10,6 +10,7 @@ from utils import generate_PCL, PCL3, point_cloud
 
 
 sam = SegmentPredictor()
+sam_cpu = SegmentPredictor(device='cpu')
 dpt = DepthPredictor()
 red = (255,0,0)
 blue = (0,0,255)
@@ -30,6 +31,7 @@ with block:
     cutout_idx = gr.State(set())
     pred_masks = gr.State([])
     prompt_masks = gr.State([])
+    embedding = gr.State()
 
     # UI
     with gr.Column():
@@ -73,7 +75,7 @@ with block:
         sam_decode_btn = gr.Button('Predict using points!', variant = 'primary')
 
     # components
-    components = {point_coords, point_labels, image_edit_trigger, masks, cutout_idx, input_image,
+    components = {point_coords, point_labels, image_edit_trigger, masks, cutout_idx, input_image, embedding,
                   point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
                   sam_decode_btn, depth_reconstruction_btn, prompt_image, lbl_image, n_samples, max_depth, min_depth, cube_size, selected_masks_image}
 
@@ -88,7 +90,7 @@ with block:
         return input_image, point_coords_empty(), point_labels_empty(), None, []
     reset_btn.click(on_reset_btn_click, [input_image], [input_image, point_coords, point_labels], queue=False)
 
-    def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks, evt: gr.SelectData):
+    def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks, embedding, evt: gr.SelectData):
         x, y = evt.index
         color = red if point_label_radio == 0 else blue
         if prompt_image is None:
@@ -97,7 +99,7 @@ with block:
         cv2.circle(prompt_image, (x, y), 5, color, -1)
         point_coords.append([x,y])
         point_labels.append(point_label_radio)
-        sam_masks =
+        sam_masks = sam_cpu.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels), embedding=embedding)
         return [ prompt_image,
                  (input_image, sam_masks),
                  point_coords,
@@ -105,7 +107,7 @@ with block:
                  sam_masks ]
 
     prompt_image.select(on_prompt_image_select,
-                        [input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks],
+                        [input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks, embedding],
                         [prompt_image, lbl_image, point_coords, point_labels, pred_masks], queue=False)
 
 
@@ -139,10 +141,10 @@
     def on_click_sam_encode_btn(inputs):
        print("encoding")
        # encode image on click
-       sam.encode(inputs[input_image])
+       embedding = sam.encode(inputs[input_image]).cpu()
        print("encoding done")
-       return
+       return [inputs[input_image], embedding]
-   sam_encode_btn.click(on_click_sam_encode_btn, components, [prompt_image], queue=False)
+   sam_encode_btn.click(on_click_sam_encode_btn, components, [prompt_image, embedding], queue=False)
 
     def on_click_sam_dencode_btn(inputs):
        print("inferencing")
```
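The new wiring encodes the uploaded image once on the GPU predictor (`sam`) and then serves every point-prompt click from the CPU predictor (`sam_cpu`), threading the cached embedding through the `embedding` `gr.State`. A minimal sketch of the same encode-once / decode-per-click flow outside the Gradio UI, assuming the `SegmentPredictor` API from `inference.py` below (the image array and click coordinates are stand-ins):

```python
import numpy as np
from inference import SegmentPredictor

sam = SegmentPredictor()                   # picks CUDA when available; runs the heavy image encoder
sam_cpu = SegmentPredictor(device='cpu')   # only ever runs the lightweight mask decoder

image = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for the uploaded image
embedding = sam.encode(image).cpu()              # encode once, move the embedding off the GPU

# Each click appends one point and re-runs only the cheap decoding step:
point_coords = [[256, 256]]
point_labels = [1]  # 1 = foreground point, 0 = background point
masks = sam_cpu.cond_pred(pts=np.array(point_coords),
                          lbls=np.array(point_labels),
                          embedding=embedding)
```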
inference.py CHANGED

```diff
@@ -11,6 +11,10 @@ import pandas as pd
 import plotly.express as px
 import matplotlib.pyplot as plt
 
+
+
+
+
 def map_image_range(image, min_value, max_value):
     """
     Maps the values of a numpy image array to a specified range.
@@ -188,26 +192,86 @@ class DepthPredictor:
 
 
 
+import numpy as np
+from typing import Optional, Tuple
+
+class CustomSamPredictor(SamPredictor):
+    def __init__(
+        self,
+        sam_model,
+    ) -> None:
+        super().__init__(sam_model)
+
+    def encode_image(self, image: np.ndarray, image_format: str = "RGB") -> torch.Tensor:
+        """
+        Encodes the image and returns its embedding.
+
+        Arguments:
+          image (np.ndarray): The image for which to calculate the embedding.
+          image_format (str): The color format of the image, in ['RGB', 'BGR'].
+
+        Returns:
+          torch.Tensor: The image embedding with shape 1xCxHxW.
+        """
+        self.set_image(image, image_format)
+        return self.get_image_embedding()
+
+    def decode_and_predict(
+        self,
+        embedding: torch.Tensor,
+        point_coords: Optional[np.ndarray] = None,
+        point_labels: Optional[np.ndarray] = None,
+        box: Optional[np.ndarray] = None,
+        mask_input: Optional[np.ndarray] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """
+        Decodes the provided image embedding and makes mask predictions based on prompts.
+
+        Arguments:
+          embedding (torch.Tensor): The image embedding to decode.
+          ... (other arguments from the predict function)
+
+        Returns:
+          (np.ndarray): The output masks in CxHxW format.
+          (np.ndarray): An array of quality predictions for each mask.
+          (np.ndarray): Low resolution mask logits for subsequent iterations.
+        """
+        self.set_torch_image(embedding, (embedding.shape[-2], embedding.shape[-1]))
+        return self.predict(
+            point_coords=point_coords,
+            point_labels=point_labels,
+            box=box,
+            mask_input=mask_input,
+            multimask_output=multimask_output,
+            return_logits=return_logits,
+        )
+
 
 class SegmentPredictor:
-    def __init__(self):
+    def __init__(self, device=None):
         MODEL_TYPE = "vit_h"
         checkpoint = "sam_vit_h_4b8939.pth"
         sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint)
         # Select device
-
+        if device is None:
+            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        else:
+            self.device = device
         sam.to(device=self.device)
         self.mask_generator = SamAutomaticMaskGenerator(sam)
-        self.conditioned_pred =
+        self.conditioned_pred = CustomSamPredictor(sam)
 
     def encode(self, image):
         image = np.array(image)
-        self.
+        return self.encode_image(image)
 
-    def cond_pred(self, pts, lbls):
+    def cond_pred(self, embedding, pts, lbls):
         lbls = np.array(lbls)
         pts = np.array(pts)
-        masks, _, _ = self.conditioned_pred.
+        masks, _, _ = self.conditioned_pred.decode_and_predict(
+            embedding,
         point_coords=pts,
         point_labels=lbls,
         multimask_output=True
```
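`CustomSamPredictor` splits the stock `SamPredictor.predict` flow into an explicit encode step (`encode_image`, which wraps `set_image` and `get_image_embedding`) and a decode step that re-injects a stored embedding. A hypothetical round trip, assuming the `sam_vit_h_4b8939.pth` checkpoint from the diff is on disk and the image is a stand-in:

```python
import numpy as np
import torch
from segment_anything import sam_model_registry
from inference import CustomSamPredictor

device = "cuda" if torch.cuda.is_available() else "cpu"
sam_model = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").to(device)
predictor = CustomSamPredictor(sam_model)

image = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder RGB image
embedding = predictor.encode_image(image)        # 1xCxHxW features via set_image()

masks, scores, logits = predictor.decode_and_predict(
    embedding,
    point_coords=np.array([[256, 256]]),
    point_labels=np.array([1]),  # 1 = foreground
    multimask_output=True,
)
```

One caveat: `decode_and_predict` feeds the embedding back through `set_torch_image`, which in upstream segment-anything runs the image encoder on its input and derives mask sizes from the original image size, so whether this shortcut reproduces `predict`'s behavior depends on the segment-anything version in use.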