Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,6 +12,7 @@ from albumentations.pytorch import ToTensorV2
|
|
| 12 |
MODEL_PATH = "s2ds_deeplabv3plus.pth"
|
| 13 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 14 |
NUM_CLASSES = 7
|
|
|
|
| 15 |
|
| 16 |
CLASS_NAMES = {
|
| 17 |
0: "Background",
|
|
@@ -74,13 +75,16 @@ def colorize_mask(mask):
|
|
| 74 |
return color_mask
|
| 75 |
|
| 76 |
# ================================
|
| 77 |
-
# INFERENCE
|
| 78 |
# ================================
|
| 79 |
def segment_image(image):
|
| 80 |
if image is None:
|
| 81 |
return None, ""
|
| 82 |
|
| 83 |
-
#
|
|
|
|
|
|
|
|
|
|
| 84 |
padded, orig_h, orig_w = pad_to_16(image)
|
| 85 |
|
| 86 |
img = normalize(image=padded)["image"]
|
|
@@ -88,13 +92,21 @@ def segment_image(image):
|
|
| 88 |
img = img.unsqueeze(0).to(DEVICE)
|
| 89 |
|
| 90 |
with torch.no_grad():
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
pred_mask = torch.argmax(pred, dim=1)[0].cpu().numpy()
|
| 93 |
|
| 94 |
pred_mask = pred_mask[:orig_h, :orig_w]
|
| 95 |
|
| 96 |
color_mask = colorize_mask(pred_mask)
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
# Image-level classification
|
| 100 |
vals, counts = np.unique(pred_mask, return_counts=True)
|
|
@@ -121,8 +133,13 @@ with gr.Blocks() as demo:
|
|
| 121 |
btn = gr.Button("Run Segmentation")
|
| 122 |
btn.click(segment_image, inputs=input_img, outputs=[output_img, output_text])
|
| 123 |
|
| 124 |
-
with gr.Tab("Live Camera (
|
| 125 |
-
cam = gr.Image(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
cam_out = gr.Image()
|
| 127 |
cam.stream(lambda x: segment_image(x)[0], inputs=cam, outputs=cam_out)
|
| 128 |
|
|
|
|
| 12 |
MODEL_PATH = "s2ds_deeplabv3plus.pth"
|
| 13 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 14 |
NUM_CLASSES = 7
|
| 15 |
+
INFER_SIZE = 512 # 🔥 reduce for speed (important for live feed)
|
| 16 |
|
| 17 |
CLASS_NAMES = {
|
| 18 |
0: "Background",
|
|
|
|
| 75 |
return color_mask
|
| 76 |
|
| 77 |
# ================================
|
| 78 |
+
# FAST INFERENCE FUNCTION
|
| 79 |
# ================================
|
| 80 |
def segment_image(image):
|
| 81 |
if image is None:
|
| 82 |
return None, ""
|
| 83 |
|
| 84 |
+
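# 🔥 Keep a copy of the original frame, then resize to a fixed INFER_SIZE x INFER_SIZE square for speed (aspect ratio is not preserved)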
|
| 85 |
+
original = image.copy()
|
| 86 |
+
image = cv2.resize(image, (INFER_SIZE, INFER_SIZE))
|
| 87 |
+
|
| 88 |
padded, orig_h, orig_w = pad_to_16(image)
|
| 89 |
|
| 90 |
img = normalize(image=padded)["image"]
|
|
|
|
| 92 |
img = img.unsqueeze(0).to(DEVICE)
|
| 93 |
|
| 94 |
with torch.no_grad():
|
| 95 |
+
if DEVICE == "cuda":
|
| 96 |
+
with torch.cuda.amp.autocast():
|
| 97 |
+
pred = model(img)
|
| 98 |
+
else:
|
| 99 |
+
pred = model(img)
|
| 100 |
+
|
| 101 |
pred_mask = torch.argmax(pred, dim=1)[0].cpu().numpy()
|
| 102 |
|
| 103 |
pred_mask = pred_mask[:orig_h, :orig_w]
|
| 104 |
|
| 105 |
color_mask = colorize_mask(pred_mask)
|
| 106 |
+
overlay_small = cv2.addWeighted(image, 0.6, color_mask, 0.4, 0)
|
| 107 |
+
|
| 108 |
+
# 🔥 Resize back to original size
|
| 109 |
+
overlay = cv2.resize(overlay_small, (original.shape[1], original.shape[0]))
|
| 110 |
|
| 111 |
# Image-level classification
|
| 112 |
vals, counts = np.unique(pred_mask, return_counts=True)
|
|
|
|
| 133 |
btn = gr.Button("Run Segmentation")
|
| 134 |
btn.click(segment_image, inputs=input_img, outputs=[output_img, output_text])
|
| 135 |
|
| 136 |
+
with gr.Tab("Live Camera (Fast Mode)"):
|
| 137 |
+
cam = gr.Image(
|
| 138 |
+
sources=["webcam"],
|
| 139 |
+
streaming=True,
|
| 140 |
+
type="numpy",
|
| 141 |
+
webcam_options={"facingMode": "environment"} # 🔥 force back camera
|
| 142 |
+
)
|
| 143 |
cam_out = gr.Image()
|
| 144 |
cam.stream(lambda x: segment_image(x)[0], inputs=cam, outputs=cam_out)
|
| 145 |
|