Jazz1508 commited on
Commit
e451a47
·
verified ·
1 Parent(s): dc909e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -6
app.py CHANGED
@@ -12,6 +12,7 @@ from albumentations.pytorch import ToTensorV2
12
  MODEL_PATH = "s2ds_deeplabv3plus.pth"
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
  NUM_CLASSES = 7
 
15
 
16
  CLASS_NAMES = {
17
  0: "Background",
@@ -74,13 +75,16 @@ def colorize_mask(mask):
74
  return color_mask
75
 
76
  # ================================
77
- # INFERENCE
78
  # ================================
79
  def segment_image(image):
80
  if image is None:
81
  return None, ""
82
 
83
- # Gradio provides RGB already
 
 
 
84
  padded, orig_h, orig_w = pad_to_16(image)
85
 
86
  img = normalize(image=padded)["image"]
@@ -88,13 +92,21 @@ def segment_image(image):
88
  img = img.unsqueeze(0).to(DEVICE)
89
 
90
  with torch.no_grad():
91
- pred = model(img)
 
 
 
 
 
92
  pred_mask = torch.argmax(pred, dim=1)[0].cpu().numpy()
93
 
94
  pred_mask = pred_mask[:orig_h, :orig_w]
95
 
96
  color_mask = colorize_mask(pred_mask)
97
- overlay = cv2.addWeighted(image, 0.6, color_mask, 0.4, 0)
 
 
 
98
 
99
  # Image-level classification
100
  vals, counts = np.unique(pred_mask, return_counts=True)
@@ -121,8 +133,13 @@ with gr.Blocks() as demo:
121
  btn = gr.Button("Run Segmentation")
122
  btn.click(segment_image, inputs=input_img, outputs=[output_img, output_text])
123
 
124
- with gr.Tab("Live Camera (Real-Time)"):
125
- cam = gr.Image(sources=["webcam"], streaming=True, type="numpy")
 
 
 
 
 
126
  cam_out = gr.Image()
127
  cam.stream(lambda x: segment_image(x)[0], inputs=cam, outputs=cam_out)
128
 
 
12
  MODEL_PATH = "s2ds_deeplabv3plus.pth"
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
  NUM_CLASSES = 7
15
+ INFER_SIZE = 512 # 🔥 reduce for speed (important for live feed)
16
 
17
  CLASS_NAMES = {
18
  0: "Background",
 
75
  return color_mask
76
 
77
  # ================================
78
+ # FAST INFERENCE FUNCTION
79
  # ================================
80
  def segment_image(image):
81
  if image is None:
82
  return None, ""
83
 
84
+ # 🔥 Downscale for speed
85
+ original = image.copy()
86
+ image = cv2.resize(image, (INFER_SIZE, INFER_SIZE))
87
+
88
  padded, orig_h, orig_w = pad_to_16(image)
89
 
90
  img = normalize(image=padded)["image"]
 
92
  img = img.unsqueeze(0).to(DEVICE)
93
 
94
  with torch.no_grad():
95
+ if DEVICE == "cuda":
96
+ with torch.cuda.amp.autocast():
97
+ pred = model(img)
98
+ else:
99
+ pred = model(img)
100
+
101
  pred_mask = torch.argmax(pred, dim=1)[0].cpu().numpy()
102
 
103
  pred_mask = pred_mask[:orig_h, :orig_w]
104
 
105
  color_mask = colorize_mask(pred_mask)
106
+ overlay_small = cv2.addWeighted(image, 0.6, color_mask, 0.4, 0)
107
+
108
+ # 🔥 Resize back to original size
109
+ overlay = cv2.resize(overlay_small, (original.shape[1], original.shape[0]))
110
 
111
  # Image-level classification
112
  vals, counts = np.unique(pred_mask, return_counts=True)
 
133
  btn = gr.Button("Run Segmentation")
134
  btn.click(segment_image, inputs=input_img, outputs=[output_img, output_text])
135
 
136
+ with gr.Tab("Live Camera (Fast Mode)"):
137
+ cam = gr.Image(
138
+ sources=["webcam"],
139
+ streaming=True,
140
+ type="numpy",
141
+ webcam_options={"facingMode": "environment"} # 🔥 force back camera
142
+ )
143
  cam_out = gr.Image()
144
  cam.stream(lambda x: segment_image(x)[0], inputs=cam, outputs=cam_out)
145