nishanth-saka committed on
Commit
1214ae6
·
verified ·
1 Parent(s): 33d3bc6
Files changed (1) hide show
  1. app.py +107 -101
app.py CHANGED
@@ -50,10 +50,10 @@ def depth_to_normal(depth):
50
  return normal
51
 
52
  # ===============================
53
- # CORE PROCESSING FUNCTION (Unified-mask + overlap blend)
54
  # ===============================
55
  def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
56
- # img_pil = base_image.convert("RGB") # <-- COMMENTED: can inject black behind transparency
57
  # img_np = np.array(img_pil)
58
 
59
  # --- ORIGINAL (white matte) kept for reference ---
@@ -67,28 +67,37 @@ def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
67
  # img_np = _rgb_over_white
68
  # --- end ORIGINAL ---
69
 
70
- # --- Alpha-aware RGB using median interior-boundary matte (prevents black halo) ---
71
  base_rgba = base_image.convert("RGBA")
72
- _arr = np.array(base_rgba).astype(np.float32) # (H,W,4) 0..255
73
- _rgb = _arr[..., :3]
74
- _alpha8 = _arr[..., 3].astype(np.uint8)
75
- _a = (_alpha8.astype(np.float32) / 255.0)[..., None] # (H,W,1)
76
-
77
- _fg_mask = (_alpha8 > 128).astype(np.uint8) * 255
78
- _k3 = np.ones((3, 3), np.uint8)
79
- _er1 = cv2.erode(_fg_mask, _k3, iterations=1)
80
- _boundary = cv2.bitwise_and(_fg_mask, cv2.bitwise_not(_er1))
 
 
 
 
 
81
  if int((_boundary > 0).sum()) < 100:
82
- _d2 = cv2.dilate(_fg_mask, _k3, iterations=2)
83
- _e2 = cv2.erode(_fg_mask, _k3, iterations=2)
84
- _boundary = cv2.subtract(_d2, _e2)
 
85
  _idx = (_boundary > 0)
86
  if not np.any(_idx):
 
87
  _idx = (_fg_mask > 0)
88
 
 
89
  _median_color = np.median(_rgb[_idx], axis=0) if np.any(_idx) else np.array([255.0, 255.0, 255.0], dtype=np.float32)
90
- _median_color = _median_color.reshape(1, 1, 3)
91
 
 
92
  _rgb_over_matte = _rgb * _a + (1.0 - _a) * _median_color
93
  _rgb_over_matte = np.clip(_rgb_over_matte, 0.0, 255.0).astype(np.uint8)
94
 
@@ -96,7 +105,7 @@ def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
96
  img_np = _rgb_over_matte
97
  # --- end NEW ---
98
 
99
- # Prepare tensor (global, once)
100
  img_resized = img_pil.resize((384, 384))
101
  img_tensor = torch.from_numpy(np.array(img_resized)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
102
  mean = torch.as_tensor([0.5, 0.5, 0.5], device=img_tensor.device).view(1, 3, 1, 1)
@@ -107,7 +116,7 @@ def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
107
  model = SimpleDPT(backbone_name='vit_base_patch16_384').to(device)
108
  model.eval()
109
 
110
- # Depth inference (global)
111
  with torch.no_grad():
112
  target_size = img_pil.size[::-1]
113
  depth_map = model(img_tensor.to(device), target_size=target_size)
@@ -116,124 +125,121 @@ def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
116
  # Normalize depth
117
  depth_vis = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
118
 
119
- # Normal & shading maps (global)
120
  normal_map = depth_to_normal(depth_vis)
 
 
121
  img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
122
  l_channel, _, _ = cv2.split(img_lab)
123
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
124
  l_clahe = clahe.apply(l_channel)
125
  shading_map = l_clahe / 255.0
126
 
127
- # Pattern tiling with global origin (no per-region reset)
128
  pattern_np = np.array(pattern_image.convert("RGB"))
129
  target_h, target_w = img_np.shape[:2]
130
- ph, pw = pattern_np.shape[:2]
131
  pattern_tiled = np.zeros((target_h, target_w, 3), dtype=np.uint8)
132
- for y in range(0, target_h, ph):
133
- for x in range(0, target_w, pw):
134
- ey = min(y + ph, target_h)
135
- ex = min(x + pw, target_w)
136
- pattern_tiled[y:ey, x:ex] = pattern_np[0:(ey - y), 0:(ex - x)]
137
-
138
- # Global fold & light
139
- normal_map_f = normal_map.astype(np.float32)
140
- shading_map_f = np.stack([shading_map] * 3, axis=-1)
141
- alpha_lit = 0.7
142
- blended_shading = alpha_lit * shading_map_f + (1 - alpha_lit)
 
143
 
144
  pattern_folded = pattern_tiled.astype(np.float32) / 255.0 * blended_shading
145
- normal_boost = 0.5 + 0.5 * normal_map_f[..., 2:3]
146
  pattern_folded *= normal_boost
147
  pattern_folded = np.clip(pattern_folded, 0, 1)
148
 
149
  # ==========================================================
150
- # Unified mask from rembg, then overlap cross-fade band
151
  # ==========================================================
152
  buf = BytesIO()
153
  base_image.save(buf, format="PNG")
154
  base_bytes = buf.getvalue()
155
 
156
- # Get RGBA from bgrem (unified alpha)
157
  result_no_bg = bgrem_remove(base_bytes)
158
  mask_img = Image.open(BytesIO(result_no_bg)).convert("RGBA")
159
- mask_alpha = np.array(mask_img)[:, :, 3].astype(np.float32) / 255.0 # [0..1]
160
-
161
- # --- ORIGINAL mask steps (commented for reference) ---
162
- # k = 5
163
- # kernel = np.ones((k, k), np.uint8)
164
- # mask_binary = (mask_alpha > k/100).astype(np.uint8) * 255
165
- # mask_eroded = cv2.erode(mask_binary, kernel, iterations=3)
166
- # mask_blurred = cv2.GaussianBlur(mask_eroded, (15, 15), sigmaX=3, sigmaY=3)
167
- # mask_blurred = mask_blurred.astype(np.float32) / 255.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  # mask_stack = np.stack([mask_blurred] * 3, axis=-1)
169
  # pattern_final = pattern_folded * mask_stack
170
- # --- end ORIGINAL ---
171
-
172
- # --- NEW: overlap band via inner/outer shells (distance-field cross-fade) ---
173
- # Tunables:
174
- overlap_px = 10 # width of cross-fade band around edges (px)
175
- feather_sigma = 1.5 # light Gaussian to keep transitions smooth
176
- bleed_iters = 2 # color bleed strength along edge
177
- alpha_floor = 0.02 # minimum alpha to hide hairlines
178
-
179
- # Build binary base (strict) for morphology
180
- bin_strict = (mask_alpha > 0.5).astype(np.uint8) * 255
181
- k5 = np.ones((5, 5), np.uint8)
182
-
183
- # Inner and outer shells
184
- inner = cv2.erode(bin_strict, k5, iterations=max(1, overlap_px // 5)) # shrink inside
185
- outer = cv2.dilate(bin_strict, k5, iterations=max(1, overlap_px // 5)) # grow outside
186
-
187
- # Distance fields (inside to inner edge; outside to outer edge)
188
- d_in = cv2.distanceTransform(inner, cv2.DIST_L2, 3) # distance to 0 within inner mask
189
- d_out = cv2.distanceTransform(255 - outer, cv2.DIST_L2, 3) # distance to 0 outside outer mask
190
-
191
- # Compose smooth alpha:
192
- # 1 inside inner → fully opaque
193
- # 0 outside outer → fully transparent
194
- # linear ramp in the overlap band
195
- alpha_inside = (inner > 0).astype(np.float32)
196
- alpha_outside = (outer == 0).astype(np.float32)
197
- # Normalize distances into 0..1 ramps
198
- ramp_in = np.clip(d_in / max(1.0, overlap_px), 0.0, 1.0)
199
- ramp_out = np.clip(d_out / max(1.0, overlap_px), 0.0, 1.0)
200
-
201
- # Where neither fully inside nor fully outside, use a symmetric blend
202
- alpha_band = np.clip(0.5 * (ramp_in + (1.0 - ramp_out)), 0.0, 1.0)
203
-
204
- alpha_unified = np.where(alpha_inside > 0, 1.0,
205
- np.where(alpha_outside > 0, 0.0, alpha_band))
206
-
207
- # Feather lightly
208
- alpha_unified = cv2.GaussianBlur((alpha_unified * 255).astype(np.uint8),
209
- (7, 7), sigmaX=feather_sigma, sigmaY=feather_sigma)
210
- alpha_unified = alpha_unified.astype(np.float32) / 255.0
211
-
212
- # Premultiplied blend with unified alpha
213
- mask_stack = np.stack([alpha_unified] * 3, axis=-1)
214
- pattern_final = pattern_folded * mask_stack # premultiplied RGB
215
-
216
- # --- Edge color bleed in premultiplied space (thin band only) ---
217
- edge_band = (alpha_unified > 0.0) & (alpha_unified <= min(0.12, overlap_px / max(10.0, overlap_px))) # ~8–12%
218
  if np.any(edge_band):
219
- k3 = np.ones((3, 3), np.uint8)
220
- premul_u8 = (pattern_final * 255).astype(np.uint8)
221
- premul_bleed = cv2.dilate(premul_u8, k3, iterations=int(bleed_iters)).astype(np.float32) / 255.0
222
- pattern_final[edge_band] = premul_bleed[edge_band]
223
 
224
- # Premultiplied → straight alpha for PNG export
225
  eps = 1e-6
226
- A = np.clip(alpha_unified, max(alpha_floor, eps), 1.0)
227
- A3 = A[..., None]
228
- rgb_straight = np.clip(pattern_final / A3, 0.0, 1.0)
 
 
 
229
 
230
  pattern_rgb = (rgb_straight * 255).astype(np.uint8)
231
- alpha_channel = (A * 255).astype(np.uint8)
232
  pattern_rgba = np.dstack((pattern_rgb, alpha_channel))
233
 
234
  return Image.fromarray(pattern_rgba, mode="RGBA")
235
 
236
 
 
237
  # ===============================
238
  # WRAPPER: ACCEPT BYTES OR BASE64
239
  # ===============================
 
50
  return normal
51
 
52
  # ===============================
53
+ # CORE PROCESSING FUNCTION
54
  # ===============================
55
  def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
56
+ # img_pil = base_image.convert("RGB") # <-- COMMENTED: this injects black behind transparency
57
  # img_np = np.array(img_pil)
58
 
59
  # --- ORIGINAL (white matte) kept for reference ---
 
67
  # img_np = _rgb_over_white
68
  # --- end ORIGINAL ---
69
 
70
+ # --- NEW: alpha-aware RGB using median color along interior boundary as matte ---
71
  base_rgba = base_image.convert("RGBA")
72
+ _arr = np.array(base_rgba).astype(np.float32) # (H,W,4), RGB in 0..255, A in 0..255
73
+ _rgb = _arr[..., :3] # (H,W,3)
74
+ _alpha8 = _arr[..., 3].astype(np.uint8) # (H,W) uint8 alpha
75
+ _a = (_alpha8.astype(np.float32) / 255.0)[..., None] # (H,W,1) float alpha
76
+
77
+ # Build a foreground mask from alpha (slightly strict to avoid wispy edges)
78
+ _fg_mask = (_alpha8 > 128).astype(np.uint8) * 255 # (H,W) 0/255
79
+
80
+ # Morphological interior boundary: foreground minus a 1-iteration erosion
81
+ _k = np.ones((3, 3), np.uint8)
82
+ _eroded = cv2.erode(_fg_mask, _k, iterations=1)
83
+ _boundary = cv2.bitwise_and(_fg_mask, cv2.bitwise_not(_eroded)) # thin interior ring
84
+
85
+ # If boundary is too thin/few pixels, widen the ring via morphological gradient
86
  if int((_boundary > 0).sum()) < 100:
87
+ _dil = cv2.dilate(_fg_mask, _k, iterations=2)
88
+ _ero = cv2.erode(_fg_mask, _k, iterations=2)
89
+ _boundary = cv2.subtract(_dil, _ero)
90
+
91
  _idx = (_boundary > 0)
92
  if not np.any(_idx):
93
+ # Fallback: use entire foreground if boundary not found
94
  _idx = (_fg_mask > 0)
95
 
96
+ # Compute median color over the selected boundary pixels (in 0..255 space)
97
  _median_color = np.median(_rgb[_idx], axis=0) if np.any(_idx) else np.array([255.0, 255.0, 255.0], dtype=np.float32)
98
+ _median_color = _median_color.reshape(1, 1, 3) # (1,1,3)
99
 
100
+ # Composite RGB over median matte (avoid introducing black/white bias)
101
  _rgb_over_matte = _rgb * _a + (1.0 - _a) * _median_color
102
  _rgb_over_matte = np.clip(_rgb_over_matte, 0.0, 255.0).astype(np.uint8)
103
 
 
105
  img_np = _rgb_over_matte
106
  # --- end NEW ---
107
 
108
+ # Prepare tensor
109
  img_resized = img_pil.resize((384, 384))
110
  img_tensor = torch.from_numpy(np.array(img_resized)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
111
  mean = torch.as_tensor([0.5, 0.5, 0.5], device=img_tensor.device).view(1, 3, 1, 1)
 
116
  model = SimpleDPT(backbone_name='vit_base_patch16_384').to(device)
117
  model.eval()
118
 
119
+ # Depth inference
120
  with torch.no_grad():
121
  target_size = img_pil.size[::-1]
122
  depth_map = model(img_tensor.to(device), target_size=target_size)
 
125
  # Normalize depth
126
  depth_vis = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
127
 
128
+ # Normal map
129
  normal_map = depth_to_normal(depth_vis)
130
+
131
+ # Shading map (CLAHE)
132
  img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
133
  l_channel, _, _ = cv2.split(img_lab)
134
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
135
  l_clahe = clahe.apply(l_channel)
136
  shading_map = l_clahe / 255.0
137
 
138
+ # Tile pattern
139
  pattern_np = np.array(pattern_image.convert("RGB"))
140
  target_h, target_w = img_np.shape[:2]
141
+ pattern_h, pattern_w = pattern_np.shape[:2]
142
  pattern_tiled = np.zeros((target_h, target_w, 3), dtype=np.uint8)
143
+ for y in range(0, target_h, pattern_h):
144
+ for x in range(0, target_w, pattern_w):
145
+ end_y = min(y + pattern_h, target_h)
146
+ end_x = min(x + pattern_w, target_w)
147
+ pattern_tiled[y:end_y, x:end_x] = pattern_np[0:(end_y - y), 0:(end_x - x)]
148
+
149
+ # Blend pattern
150
+ normal_map_loaded = normal_map.astype(np.float32)
151
+ shading_map_loaded = np.stack([shading_map] * 3, axis=-1)
152
+
153
+ alpha = 0.7
154
+ blended_shading = alpha * shading_map_loaded + (1 - alpha)
155
 
156
  pattern_folded = pattern_tiled.astype(np.float32) / 255.0 * blended_shading
157
+ normal_boost = 0.5 + 0.5 * normal_map_loaded[..., 2:3]
158
  pattern_folded *= normal_boost
159
  pattern_folded = np.clip(pattern_folded, 0, 1)
160
 
161
  # ==========================================================
162
+ # Background removal with post-processing (no duplicate blur)
163
  # ==========================================================
164
  buf = BytesIO()
165
  base_image.save(buf, format="PNG")
166
  base_bytes = buf.getvalue()
167
 
168
+ # Get RGBA from bgrem
169
  result_no_bg = bgrem_remove(base_bytes)
170
  mask_img = Image.open(BytesIO(result_no_bg)).convert("RGBA")
171
+
172
+ # Extract alpha and clean edges
173
+ mask_alpha = np.array(mask_img)[:, :, 3].astype(np.float32) / 255.0
174
+
175
+ # 1. Slightly stronger shrink (balanced)
176
+ k = 5
177
+ kernel = np.ones((k, k), np.uint8) # slightly larger kernel
178
+ mask_binary = (mask_alpha > k/100).astype(np.uint8) * 255 # slightly stricter threshold
179
+ mask_eroded = cv2.erode(mask_binary, kernel, iterations=3) # balanced erosion
180
+
181
+ # 2. Feather edges (blur)
182
+ mask_blurred = cv2.GaussianBlur(mask_eroded, (15, 15), sigmaX=3, sigmaY=3)
183
+
184
+ # 3. Normalize
185
+ mask_blurred = mask_blurred.astype(np.float32) / 255.0
186
+
187
+ # ================================
188
+ # NEW: SEAM-FIX UPSTREAM (3 steps)
189
+ # ================================
190
+ # (A) MASK EXPANSION / OVERLAP: expand slightly to ensure overlap across seams
191
+ overlap_iters = 2 # <-- tune: 1..3 (px-ish with 5x5 kernel)
192
+ # mask_expanded = cv2.dilate(mask_eroded, kernel, iterations=overlap_iters) # old idea
193
+ # --- Better: expand the FLOAT feathered mask to preserve soft edge continuity ---
194
+ _mask_float = (mask_blurred * 255).astype(np.uint8)
195
+ _mask_expanded_u8 = cv2.dilate(_mask_float, kernel, iterations=overlap_iters)
196
+ mask_expanded = _mask_expanded_u8.astype(np.float32) / 255.0 # [0..1]
197
+
198
+ # (B) FEATHER AGAIN after expansion (very light) for a smooth transition band
199
+ mask_expanded = cv2.GaussianBlur((mask_expanded * 255).astype(np.uint8), (7, 7), sigmaX=1.5, sigmaY=1.5)
200
+ mask_expanded = mask_expanded.astype(np.float32) / 255.0
201
+
202
+ # (C) BLEED-COLOR FILLING in premultiplied space for near-edge pixels
203
+ # - Create a thin edge band where alpha is small (e.g., up to 8%)
204
+ edge_upper = 0.08
205
+ # Final RGBA
206
  # mask_stack = np.stack([mask_blurred] * 3, axis=-1)
207
  # pattern_final = pattern_folded * mask_stack
208
+ # --- Replace above with expanded mask for overlap ---
209
+ mask_stack = np.stack([mask_expanded] * 3, axis=-1)
210
+ pattern_final = pattern_folded * mask_stack # premultiplied RGB (color * alpha) with overlap
211
+
212
+ # - Dilate premultiplied RGB slightly so edge pixels borrow nearby garment color
213
+ bleed_iters = 2 # <-- tune: 1..3
214
+ _kernel_bleed = np.ones((3, 3), np.uint8)
215
+ _premul_u8 = (pattern_final * 255).astype(np.uint8)
216
+ _premul_bleed = cv2.dilate(_premul_u8, _kernel_bleed, iterations=bleed_iters).astype(np.float32) / 255.0
217
+
218
+ # - Replace only in the very thin edge band (alpha between 0 and edge_upper)
219
+ edge_band = (mask_expanded > 0.0) & (mask_expanded <= edge_upper)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  if np.any(edge_band):
221
+ pattern_final[edge_band] = _premul_bleed[edge_band]
222
+ # ================================
223
+ # END NEW: SEAM-FIX UPSTREAM
224
+ # ================================
225
 
226
+ # Premultiplied → Straight alpha
227
  eps = 1e-6
228
+ # _alpha = np.clip(mask_blurred, eps, 1.0)
229
+ # --- Use the expanded mask for export, with a small alpha floor to hide hairlines ---
230
+ alpha_floor = 0.02 # 2% floor; increase to 0.03 if a faint line persists
231
+ _alpha = np.clip(mask_expanded, max(alpha_floor, eps), 1.0)
232
+ _alpha3 = _alpha[..., None]
233
+ rgb_straight = np.clip(pattern_final / _alpha3, 0.0, 1.0)
234
 
235
  pattern_rgb = (rgb_straight * 255).astype(np.uint8)
236
+ alpha_channel = (_alpha * 255).astype(np.uint8)
237
  pattern_rgba = np.dstack((pattern_rgb, alpha_channel))
238
 
239
  return Image.fromarray(pattern_rgba, mode="RGBA")
240
 
241
 
242
+
243
  # ===============================
244
  # WRAPPER: ACCEPT BYTES OR BASE64
245
  # ===============================