Spaces:
Runtime error
Runtime error
s194649
committed on
Commit
·
cf01ea3
1
Parent(s):
5f453af
fix
Browse files- inference.py +21 -7
inference.py
CHANGED
|
@@ -202,18 +202,32 @@ class CustomSamPredictor(SamPredictor):
|
|
| 202 |
) -> None:
|
| 203 |
super().__init__(sam_model)
|
| 204 |
|
| 205 |
-
def encode_image(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
"""
|
| 207 |
-
|
|
|
|
| 208 |
|
| 209 |
Arguments:
|
| 210 |
-
image (np.ndarray): The image for
|
|
|
|
| 211 |
image_format (str): The color format of the image, in ['RGB', 'BGR'].
|
| 212 |
-
|
| 213 |
-
Returns:
|
| 214 |
-
torch.Tensor: The image embedding with shape 1xCxHxW.
|
| 215 |
"""
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
return self.get_image_embedding()
|
| 218 |
|
| 219 |
def decode_and_predict(
|
|
|
|
| 202 |
) -> None:
|
| 203 |
super().__init__(sam_model)
|
| 204 |
|
| 205 |
+
def encode_image(
|
| 206 |
+
self,
|
| 207 |
+
image: np.ndarray,
|
| 208 |
+
image_format: str = "RGB",
|
| 209 |
+
) -> None:
|
| 210 |
"""
|
| 211 |
+
Calculates the image embeddings for the provided image, allowing
|
| 212 |
+
masks to be predicted with the 'predict' method.
|
| 213 |
|
| 214 |
Arguments:
|
| 215 |
+
image (np.ndarray): The image for calculating masks. Expects an
|
| 216 |
+
image in HWC uint8 format, with pixel values in [0, 255].
|
| 217 |
image_format (str): The color format of the image, in ['RGB', 'BGR'].
|
|
|
|
|
|
|
|
|
|
| 218 |
"""
|
| 219 |
+
assert image_format in [
|
| 220 |
+
"RGB",
|
| 221 |
+
"BGR",
|
| 222 |
+
], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
|
| 223 |
+
if image_format != self.model.image_format:
|
| 224 |
+
image = image[..., ::-1]
|
| 225 |
+
|
| 226 |
+
# Transform the image to the form expected by the model
|
| 227 |
+
input_image = self.transform.apply_image(image)
|
| 228 |
+
input_image_torch = torch.as_tensor(input_image, device=self.device)
|
| 229 |
+
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
|
| 230 |
+
self.set_torch_image(input_image_torch, image.shape[:2])
|
| 231 |
return self.get_image_embedding()
|
| 232 |
|
| 233 |
def decode_and_predict(
|