Spaces:
Runtime error
Runtime error
Fix bug
Browse files
app.py
CHANGED
|
@@ -30,7 +30,7 @@ image = None
|
|
| 30 |
###################################
|
| 31 |
# Functions.
|
| 32 |
###################################
|
| 33 |
-
# @st.cache_resource
|
| 34 |
def load_model(model_path: str, device: torch.device) -> BasePredictor:
|
| 35 |
model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
|
| 36 |
predictor_params = {"brs_mode": "NoBRS"}
|
|
@@ -54,9 +54,7 @@ def feed_clicks(
|
|
| 54 |
clicker.add_click(click)
|
| 55 |
|
| 56 |
|
| 57 |
-
def predict(
|
| 58 |
-
image: Image, mask: torch.Tensor, threshold: float = 0.5
|
| 59 |
-
) -> torch.Tensor:
|
| 60 |
predictor.set_input_image(np.array(image))
|
| 61 |
with st.spinner("Wait for prediction..."):
|
| 62 |
pred = predictor.get_prediction(clicker, prev_mask=mask)
|
|
@@ -120,7 +118,7 @@ if canvas_result.json_data and canvas_result.json_data["objects"] and image:
|
|
| 120 |
feed_clicks(clicker, canvas_result.json_data["objects"], image_width, image_height)
|
| 121 |
|
| 122 |
# Run prediction.
|
| 123 |
-
mask = torch.zeros((1, 1,
|
| 124 |
pred = predict(image, mask, threshold)
|
| 125 |
|
| 126 |
# Show the prediction result.
|
|
|
|
| 30 |
###################################
|
| 31 |
# Functions.
|
| 32 |
###################################
|
| 33 |
+
# @st.cache_resource # TODO: this doesn't work on Huggingface!
|
| 34 |
def load_model(model_path: str, device: torch.device) -> BasePredictor:
|
| 35 |
model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
|
| 36 |
predictor_params = {"brs_mode": "NoBRS"}
|
|
|
|
| 54 |
clicker.add_click(click)
|
| 55 |
|
| 56 |
|
| 57 |
+
def predict(image: Image, mask: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
|
|
|
|
|
|
|
| 58 |
predictor.set_input_image(np.array(image))
|
| 59 |
with st.spinner("Wait for prediction..."):
|
| 60 |
pred = predictor.get_prediction(clicker, prev_mask=mask)
|
|
|
|
| 118 |
feed_clicks(clicker, canvas_result.json_data["objects"], image_width, image_height)
|
| 119 |
|
| 120 |
# Run prediction.
|
| 121 |
+
mask = torch.zeros((1, 1, image_height, image_width), device=device)
|
| 122 |
pred = predict(image, mask, threshold)
|
| 123 |
|
| 124 |
# Show the prediction result.
|