Spaces:
Runtime error
Runtime error
Cache loaded model
Browse files
app.py
CHANGED
|
@@ -12,11 +12,24 @@ from isegm.inference import clicker as ck
|
|
| 12 |
from isegm.inference import utils
|
| 13 |
from isegm.inference.predictors import get_predictor
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
"
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# Items in the sidebar.
|
| 22 |
model = st.sidebar.selectbox("Select a Model:", tuple(models.keys()))
|
|
@@ -25,24 +38,16 @@ marking_type = st.sidebar.radio("Marking Type:", ("positive", "negative"))
|
|
| 25 |
image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])
|
| 26 |
|
| 27 |
# Objects for prediction.
|
| 28 |
-
clicker = ck.Clicker()
|
| 29 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 30 |
-
predictor = None
|
| 31 |
with st.spinner("Wait for downloading a model..."):
|
| 32 |
if not os.path.exists(models[model]):
|
| 33 |
-
_ = wget.download(f"{
|
| 34 |
|
| 35 |
with st.spinner("Wait for loading a model..."):
|
| 36 |
-
|
| 37 |
-
predictor_params = {"brs_mode": "NoBRS"}
|
| 38 |
-
predictor = get_predictor(model, device=device, **predictor_params)
|
| 39 |
|
| 40 |
# Create a canvas component.
|
| 41 |
-
image = None
|
| 42 |
if image_path:
|
| 43 |
image = Image.open(image_path).convert("RGB")
|
| 44 |
-
canvas_height, canvas_width = 600, 600
|
| 45 |
-
pos_color, neg_color = "#3498DB", "#C70039"
|
| 46 |
|
| 47 |
st.title("Canvas:")
|
| 48 |
canvas_result = st_canvas(
|
|
@@ -66,7 +71,6 @@ if canvas_result.json_data and canvas_result.json_data["objects"] and image:
|
|
| 66 |
image_width, image_height = image.size
|
| 67 |
ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width
|
| 68 |
|
| 69 |
-
err_x, err_y = 5.5, 1.0
|
| 70 |
pos_clicks, neg_clicks = [], []
|
| 71 |
for click in objects:
|
| 72 |
x, y = (click["left"] + err_x) * ratio_w, (click["top"] + err_y) * ratio_h
|
|
|
|
| 12 |
from isegm.inference import utils
|
| 13 |
from isegm.inference.predictors import get_predictor
|
| 14 |
|
| 15 |
+
@st.cache_data
|
| 16 |
+
def load_model(model_path, device):
|
| 17 |
+
model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
|
| 18 |
+
predictor_params = {"brs_mode": "NoBRS"}
|
| 19 |
+
predictor = get_predictor(model, device=device, **predictor_params)
|
| 20 |
+
return predictor
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Objects in the global scope
|
| 24 |
+
url_prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
|
| 25 |
+
models = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}
|
| 26 |
+
clicker = ck.Clicker()
|
| 27 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 28 |
+
pos_color, neg_color = "#3498DB", "#C70039"
|
| 29 |
+
canvas_height, canvas_width = 600, 600
|
| 30 |
+
err_x, err_y = 5.5, 1.0
|
| 31 |
+
predictor = None
|
| 32 |
+
image = None
|
| 33 |
|
| 34 |
# Items in the sidebar.
|
| 35 |
model = st.sidebar.selectbox("Select a Model:", tuple(models.keys()))
|
|
|
|
| 38 |
image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])
|
| 39 |
|
| 40 |
# Objects for prediction.
|
|
|
|
|
|
|
|
|
|
| 41 |
with st.spinner("Wait for downloading a model..."):
|
| 42 |
if not os.path.exists(models[model]):
|
| 43 |
+
_ = wget.download(f"{url_prefix}/{models[model]}")
|
| 44 |
|
| 45 |
with st.spinner("Wait for loading a model..."):
|
| 46 |
+
predictor = load_model(models[model], device)
|
|
|
|
|
|
|
| 47 |
|
| 48 |
# Create a canvas component.
|
|
|
|
| 49 |
if image_path:
|
| 50 |
image = Image.open(image_path).convert("RGB")
|
|
|
|
|
|
|
| 51 |
|
| 52 |
st.title("Canvas:")
|
| 53 |
canvas_result = st_canvas(
|
|
|
|
| 71 |
image_width, image_height = image.size
|
| 72 |
ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width
|
| 73 |
|
|
|
|
| 74 |
pos_clicks, neg_clicks = [], []
|
| 75 |
for click in objects:
|
| 76 |
x, y = (click["left"] + err_x) * ratio_w, (click["top"] + err_y) * ratio_h
|