Jazz1508 commited on
Commit
dc909e1
·
verified ·
1 Parent(s): b52af01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -15
app.py CHANGED
@@ -44,21 +44,19 @@ model = smp.DeepLabV3Plus(
44
  )
45
 
46
  checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
47
-
48
- if "model_state_dict" in checkpoint:
49
- model.load_state_dict(checkpoint["model_state_dict"])
50
- else:
51
- model.load_state_dict(checkpoint)
52
 
53
  model.to(DEVICE)
54
  model.eval()
55
 
56
- # ================================
57
- # HELPERS
58
- # ================================
59
  normalize = Normalize()
60
  to_tensor = ToTensorV2()
61
 
 
 
 
62
  def pad_to_16(img):
63
  h, w = img.shape[:2]
64
  new_h = (h + 15) // 16 * 16
@@ -76,11 +74,13 @@ def colorize_mask(mask):
76
  return color_mask
77
 
78
  # ================================
79
- # INFERENCE FUNCTION
80
  # ================================
81
  def segment_image(image):
82
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
83
 
 
84
  padded, orig_h, orig_w = pad_to_16(image)
85
 
86
  img = normalize(image=padded)["image"]
@@ -96,6 +96,7 @@ def segment_image(image):
96
  color_mask = colorize_mask(pred_mask)
97
  overlay = cv2.addWeighted(image, 0.6, color_mask, 0.4, 0)
98
 
 
99
  vals, counts = np.unique(pred_mask, return_counts=True)
100
  vals = vals[vals > 0]
101
 
@@ -112,17 +113,17 @@ def segment_image(image):
112
  # ================================
113
  with gr.Blocks() as demo:
114
  gr.Markdown("# 🏗 Structural Defect Segmentation")
115
-
116
  with gr.Tab("Image Upload"):
117
  input_img = gr.Image(type="numpy")
118
  output_img = gr.Image()
119
  output_text = gr.Textbox()
120
  btn = gr.Button("Run Segmentation")
121
  btn.click(segment_image, inputs=input_img, outputs=[output_img, output_text])
122
-
123
- with gr.Tab("Live Camera"):
124
- cam = gr.Image(source="webcam", streaming=True)
125
  cam_out = gr.Image()
126
- cam.stream(segment_image, inputs=cam, outputs=[cam_out])
127
 
128
  demo.launch()
 
44
  )
45
 
46
  checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
47
+ model.load_state_dict(
48
+ checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint
49
+ )
 
 
50
 
51
  model.to(DEVICE)
52
  model.eval()
53
 
 
 
 
54
  normalize = Normalize()
55
  to_tensor = ToTensorV2()
56
 
57
+ # ================================
58
+ # HELPERS
59
+ # ================================
60
  def pad_to_16(img):
61
  h, w = img.shape[:2]
62
  new_h = (h + 15) // 16 * 16
 
74
  return color_mask
75
 
76
  # ================================
77
+ # INFERENCE
78
  # ================================
79
  def segment_image(image):
80
+ if image is None:
81
+ return None, ""
82
 
83
+ # Gradio provides RGB already
84
  padded, orig_h, orig_w = pad_to_16(image)
85
 
86
  img = normalize(image=padded)["image"]
 
96
  color_mask = colorize_mask(pred_mask)
97
  overlay = cv2.addWeighted(image, 0.6, color_mask, 0.4, 0)
98
 
99
+ # Image-level classification
100
  vals, counts = np.unique(pred_mask, return_counts=True)
101
  vals = vals[vals > 0]
102
 
 
113
  # ================================
114
  with gr.Blocks() as demo:
115
  gr.Markdown("# 🏗 Structural Defect Segmentation")
116
+
117
  with gr.Tab("Image Upload"):
118
  input_img = gr.Image(type="numpy")
119
  output_img = gr.Image()
120
  output_text = gr.Textbox()
121
  btn = gr.Button("Run Segmentation")
122
  btn.click(segment_image, inputs=input_img, outputs=[output_img, output_text])
123
+
124
+ with gr.Tab("Live Camera (Real-Time)"):
125
+ cam = gr.Image(sources=["webcam"], streaming=True, type="numpy")
126
  cam_out = gr.Image()
127
+ cam.stream(lambda x: segment_image(x)[0], inputs=cam, outputs=cam_out)
128
 
129
  demo.launch()