Spaces:
Runtime error
Runtime error
Make gpu compatible
Browse files
app.py
CHANGED
|
@@ -26,7 +26,7 @@ image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "
|
|
| 26 |
|
| 27 |
# Objects for prediction.
|
| 28 |
clicker = ck.Clicker()
|
| 29 |
-
device = torch.device("cpu")
|
| 30 |
predictor = None
|
| 31 |
with st.spinner("Wait for downloading a model..."):
|
| 32 |
if not os.path.exists(models[model]):
|
|
@@ -43,6 +43,7 @@ if image_path:
|
|
| 43 |
image = Image.open(image_path).convert("RGB")
|
| 44 |
canvas_height, canvas_width = 600, 600
|
| 45 |
pos_color, neg_color = "#3498DB", "#C70039"
|
|
|
|
| 46 |
st.title("Canvas:")
|
| 47 |
canvas_result = st_canvas(
|
| 48 |
fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
|
|
@@ -75,11 +76,15 @@ if canvas_result.json_data and canvas_result.json_data["objects"] and image:
|
|
| 75 |
click = ck.Click(is_positive=is_positive, coords=(y, x))
|
| 76 |
clicker.add_click(click)
|
| 77 |
|
| 78 |
-
# prediction.
|
| 79 |
pred = None
|
| 80 |
predictor.set_input_image(np.array(image))
|
|
|
|
|
|
|
| 81 |
with st.spinner("Wait for prediction..."):
|
| 82 |
-
pred = predictor.get_prediction(clicker, prev_mask=
|
| 83 |
pred = cv2.resize(pred, dsize=(canvas_height, canvas_width), interpolation=cv2.INTER_CUBIC)
|
| 84 |
pred = np.where(pred > threshold, 1.0, 0)
|
|
|
|
|
|
|
| 85 |
st.image(pred, caption="")
|
|
|
|
| 26 |
|
| 27 |
# Objects for prediction.
|
| 28 |
clicker = ck.Clicker()
|
| 29 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 30 |
predictor = None
|
| 31 |
with st.spinner("Wait for downloading a model..."):
|
| 32 |
if not os.path.exists(models[model]):
|
|
|
|
| 43 |
image = Image.open(image_path).convert("RGB")
|
| 44 |
canvas_height, canvas_width = 600, 600
|
| 45 |
pos_color, neg_color = "#3498DB", "#C70039"
|
| 46 |
+
|
| 47 |
st.title("Canvas:")
|
| 48 |
canvas_result = st_canvas(
|
| 49 |
fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
|
|
|
|
| 76 |
click = ck.Click(is_positive=is_positive, coords=(y, x))
|
| 77 |
clicker.add_click(click)
|
| 78 |
|
| 79 |
+
# Run prediction.
|
| 80 |
pred = None
|
| 81 |
predictor.set_input_image(np.array(image))
|
| 82 |
+
init_mask = torch.zeros((1, 1, image_height, image_width), device=device)
|
| 83 |
+
|
| 84 |
with st.spinner("Wait for prediction..."):
|
| 85 |
+
pred = predictor.get_prediction(clicker, prev_mask=init_mask)
|
| 86 |
pred = cv2.resize(pred, dsize=(canvas_height, canvas_width), interpolation=cv2.INTER_CUBIC)
|
| 87 |
pred = np.where(pred > threshold, 1.0, 0)
|
| 88 |
+
|
| 89 |
+
# Show the prediction result.
|
| 90 |
st.image(pred, caption="")
|