nishanth-saka commited on
Commit
b0d222e
·
verified ·
1 Parent(s): 0d1afea

BG Removal Updated

Browse files
Files changed (1) hide show
  1. app.py +15 -18
app.py CHANGED
@@ -11,6 +11,7 @@ from io import BytesIO
11
  import base64
12
  import traceback
13
  from starlette.exceptions import HTTPException as StarletteHTTPException
 
14
 
15
  # ===============================
16
  # SIMPLE DPT MODEL (DEPTH ESTIMATION)
@@ -85,17 +86,6 @@ def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
85
  l_clahe = clahe.apply(l_channel)
86
  shading_map = l_clahe / 255.0
87
 
88
- # GrabCut mask
89
- img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
90
- grabcut_mask = np.zeros(img_bgr.shape[:2], np.uint8)
91
- height, width = img_bgr.shape[:2]
92
- margin = int(min(width, height) * 0.05)
93
- rect = (margin, margin, width - 2 * margin, height - 2 * margin)
94
- bgdModel = np.zeros((1, 65), np.float64)
95
- fgdModel = np.zeros((1, 65), np.float64)
96
- cv2.grabCut(img_bgr, grabcut_mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_RECT)
97
- mask = np.where((grabcut_mask == cv2.GC_FGD) | (grabcut_mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
98
-
99
  # Tile pattern
100
  pattern_np = np.array(pattern_image.convert("RGB"))
101
  target_h, target_w = img_np.shape[:2]
@@ -119,13 +109,20 @@ def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
119
  pattern_folded *= normal_boost
120
  pattern_folded = np.clip(pattern_folded, 0, 1)
121
 
122
- # Clean mask and feather edges
123
- mask_float = mask.astype(np.float32) / 255.0
124
- kernel = np.ones((3, 3), np.uint8)
125
- mask_clean = cv2.morphologyEx((mask_float * 255).astype(np.uint8), cv2.MORPH_OPEN, kernel)
126
- mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_CLOSE, kernel)
127
- mask_clean = cv2.dilate(mask_clean, kernel, iterations=1)
128
- mask_blurred = cv2.GaussianBlur(mask_clean, (15, 15), sigmaX=5, sigmaY=5)
 
 
 
 
 
 
 
129
  mask_blurred[mask_blurred < 25] = 0
130
  mask_blurred = mask_blurred.astype(np.float32) / 255.0
131
 
 
11
  import base64
12
  import traceback
13
  from starlette.exceptions import HTTPException as StarletteHTTPException
14
+ from bgrem import remove as bgrem_remove # NEW: bgrem import
15
 
16
  # ===============================
17
  # SIMPLE DPT MODEL (DEPTH ESTIMATION)
 
86
  l_clahe = clahe.apply(l_channel)
87
  shading_map = l_clahe / 255.0
88
 
 
 
 
 
 
 
 
 
 
 
 
89
  # Tile pattern
90
  pattern_np = np.array(pattern_image.convert("RGB"))
91
  target_h, target_w = img_np.shape[:2]
 
109
  pattern_folded *= normal_boost
110
  pattern_folded = np.clip(pattern_folded, 0, 1)
111
 
112
+ # ==========================================================
113
+ # NEW: Background removal using bgrem (instead of GrabCut)
114
+ # ==========================================================
115
+ buf = BytesIO()
116
+ base_image.save(buf, format="PNG")
117
+ base_bytes = buf.getvalue()
118
+
119
+ # Get RGBA from bgrem
120
+ result_no_bg = bgrem_remove(base_bytes)
121
+ mask_img = Image.open(BytesIO(result_no_bg)).convert("RGBA")
122
+ mask_alpha = np.array(mask_img)[:, :, 3].astype(np.float32) / 255.0
123
+
124
+ # Feather edges for smoother blending
125
+ mask_blurred = cv2.GaussianBlur((mask_alpha * 255).astype(np.uint8), (15, 15), sigmaX=5, sigmaY=5)
126
  mask_blurred[mask_blurred < 25] = 0
127
  mask_blurred = mask_blurred.astype(np.float32) / 255.0
128