tbuyuktanir commited on
Commit
0f9046d
·
verified ·
1 Parent(s): 441250a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -8
app.py CHANGED
@@ -20,11 +20,14 @@ global_state = {
20
 
21
  # Helper to apply mask overlay
22
  def apply_mask_overlay(image: Image.Image, mask: np.ndarray, color=(255, 0, 0)) -> Image.Image:
23
- mask_img = Image.fromarray(mask.astype(np.uint8) * 255).convert("L")
 
 
 
 
24
  color_mask = Image.new("RGB", image.size, color)
25
- mask_rgb = Image.composite(color_mask, image, mask_img)
26
- blended = Image.blend(image, mask_rgb, alpha=0.5)
27
- return blended
28
 
29
  # Set image
30
  def upload_image(img):
@@ -56,9 +59,15 @@ def run_segmentation():
56
  inputs = processor(image, return_tensors="pt").to(device)
57
 
58
  if global_state["clicks"]:
59
- points = torch.tensor([[[x, y] for (x, y, l) in global_state["clicks"]]], device=device)
60
- labels = torch.tensor([[l for (_, _, l) in global_state["clicks"]]], device=device)
61
- inputs.update({"input_points": points, "input_labels": labels})
 
 
 
 
 
 
62
 
63
  if global_state["bbox"]:
64
  x0, y0, x1, y1 = global_state["bbox"]
@@ -74,7 +83,7 @@ def run_segmentation():
74
  inputs["reshaped_input_sizes"].cpu()
75
  )[0]
76
 
77
- final_mask = masks[0].numpy()
78
  overlayed = apply_mask_overlay(image.convert("RGB"), final_mask)
79
 
80
  return overlayed, "Segmentation complete."
 
20
 
21
  # Helper to apply mask overlay
22
  def apply_mask_overlay(image: Image.Image, mask: np.ndarray, color=(255, 0, 0)) -> Image.Image:
23
+ if mask.ndim == 3:
24
+ mask = mask.squeeze()
25
+ if mask.max() <= 1:
26
+ mask = (mask * 255).astype(np.uint8)
27
+ mask_img = Image.fromarray(mask).convert("L")
28
  color_mask = Image.new("RGB", image.size, color)
29
+ blended = Image.composite(color_mask, image, mask_img)
30
+ return Image.blend(image, blended, alpha=0.5)
 
31
 
32
  # Set image
33
  def upload_image(img):
 
59
  inputs = processor(image, return_tensors="pt").to(device)
60
 
61
  if global_state["clicks"]:
62
+ coords = [[(x, y) for (x, y, l) in global_state["clicks"]]]
63
+ labels = [[l for (_, _, l) in global_state["clicks"]]]
64
+ input_points = torch.tensor([coords], device=device) # shape [1, 1, N, 2]
65
+ input_labels = torch.tensor([labels], device=device) # shape [1, 1, N]
66
+
67
+ inputs.update({
68
+ "input_points": input_points,
69
+ "input_labels": input_labels
70
+ })
71
 
72
  if global_state["bbox"]:
73
  x0, y0, x1, y1 = global_state["bbox"]
 
83
  inputs["reshaped_input_sizes"].cpu()
84
  )[0]
85
 
86
+ final_mask = masks[0].numpy().astype(np.uint8) # shape: (H, W)
87
  overlayed = apply_mask_overlay(image.convert("RGB"), final_mask)
88
 
89
  return overlayed, "Segmentation complete."