nishanth-saka committed on
Commit
2ac33ac
·
verified ·
1 Parent(s): 42acb28

Feathering Overlays

Browse files
Files changed (1) hide show
  1. app.py +116 -44
app.py CHANGED
@@ -53,10 +53,15 @@ def depth_to_normal(depth):
53
  # CORE PROCESSING FUNCTION
54
  # ===============================
55
  def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
 
 
 
56
  img_pil = base_image.convert("RGB")
57
  img_np = np.array(img_pil)
58
 
59
- # Prepare tensor
 
 
60
  img_resized = img_pil.resize((384, 384))
61
  img_tensor = torch.from_numpy(np.array(img_resized)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
62
  mean = torch.as_tensor([0.5, 0.5, 0.5], device=img_tensor.device).view(1, 3, 1, 1)
@@ -67,84 +72,151 @@ def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
67
  model = SimpleDPT(backbone_name='vit_base_patch16_384').to(device)
68
  model.eval()
69
 
70
- # Depth inference
71
  with torch.no_grad():
72
  target_size = img_pil.size[::-1]
73
  depth_map = model(img_tensor.to(device), target_size=target_size)
74
  depth_map = depth_map.squeeze().cpu().numpy()
75
 
76
- # Normalize depth
77
- depth_vis = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
78
-
79
- # Normal map
80
  normal_map = depth_to_normal(depth_vis)
81
 
82
- # Shading map (CLAHE)
 
 
83
  img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
84
  l_channel, _, _ = cv2.split(img_lab)
85
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
86
  l_clahe = clahe.apply(l_channel)
87
  shading_map = l_clahe / 255.0
88
-
89
- # Tile pattern
90
- pattern_np = np.array(pattern_image.convert("RGB"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  target_h, target_w = img_np.shape[:2]
92
- pattern_h, pattern_w = pattern_np.shape[:2]
93
- pattern_tiled = np.zeros((target_h, target_w, 3), dtype=np.uint8)
94
- for y in range(0, target_h, pattern_h):
95
- for x in range(0, target_w, pattern_w):
96
- end_y = min(y + pattern_h, target_h)
97
- end_x = min(x + pattern_w, target_w)
98
- pattern_tiled[y:end_y, x:end_x] = pattern_np[0:(end_y - y), 0:(end_x - x)]
99
-
100
- # Blend pattern
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  normal_map_loaded = normal_map.astype(np.float32)
102
- shading_map_loaded = np.stack([shading_map] * 3, axis=-1)
 
103
 
104
- alpha = 0.7
105
- blended_shading = alpha * shading_map_loaded + (1 - alpha)
 
 
 
106
 
107
- pattern_folded = pattern_tiled.astype(np.float32) / 255.0 * blended_shading
108
  normal_boost = 0.5 + 0.5 * normal_map_loaded[..., 2:3]
109
  pattern_folded *= normal_boost
110
- pattern_folded = np.clip(pattern_folded, 0, 1)
111
 
112
- # ==========================================================
113
- # Background removal with post-processing (no duplicate blur)
114
- # ==========================================================
115
  buf = BytesIO()
116
  base_image.save(buf, format="PNG")
117
  base_bytes = buf.getvalue()
118
 
119
- # Get RGBA from bgrem
120
  result_no_bg = bgrem_remove(base_bytes)
121
  mask_img = Image.open(BytesIO(result_no_bg)).convert("RGBA")
122
 
123
- # Extract alpha and clean edges
124
  mask_alpha = np.array(mask_img)[:, :, 3].astype(np.float32) / 255.0
125
 
126
- # 1. Slightly stronger shrink (balanced)
127
  k = 5
128
- kernel = np.ones((k, k), np.uint8) # slightly larger kernel
129
- mask_binary = (mask_alpha > k/100).astype(np.uint8) * 255 # slightly stricter threshold
130
- mask_eroded = cv2.erode(mask_binary, kernel, iterations=3) # balanced erosion
131
-
132
-
133
- # 2. Feather edges (blur)
134
  mask_blurred = cv2.GaussianBlur(mask_eroded, (15, 15), sigmaX=3, sigmaY=3)
 
135
 
136
- # 3. Normalize
137
- mask_blurred = mask_blurred.astype(np.float32) / 255.0
 
 
 
138
 
139
- # Final RGBA
140
- mask_stack = np.stack([mask_blurred] * 3, axis=-1)
141
- pattern_final = pattern_folded * mask_stack
142
- pattern_rgb = (pattern_final * 255).astype(np.uint8)
143
- alpha_channel = (mask_blurred * 255).astype(np.uint8)
144
- pattern_rgba = np.dstack((pattern_rgb, alpha_channel))
145
 
 
146
  return Image.fromarray(pattern_rgba, mode="RGBA")
147
 
 
148
  # ===============================
149
  # WRAPPER: ACCEPT BYTES OR BASE64
150
  # ===============================
 
53
  # CORE PROCESSING FUNCTION
54
  # ===============================
55
  def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
56
+ # ===============================
57
+ # 0) Prep: base to RGB/np
58
+ # ===============================
59
  img_pil = base_image.convert("RGB")
60
  img_np = np.array(img_pil)
61
 
62
+ # ===============================
63
+ # 1) Depth inference (kept as-is)
64
+ # ===============================
65
  img_resized = img_pil.resize((384, 384))
66
  img_tensor = torch.from_numpy(np.array(img_resized)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
67
  mean = torch.as_tensor([0.5, 0.5, 0.5], device=img_tensor.device).view(1, 3, 1, 1)
 
72
  model = SimpleDPT(backbone_name='vit_base_patch16_384').to(device)
73
  model.eval()
74
 
 
75
  with torch.no_grad():
76
  target_size = img_pil.size[::-1]
77
  depth_map = model(img_tensor.to(device), target_size=target_size)
78
  depth_map = depth_map.squeeze().cpu().numpy()
79
 
80
+ # Normalize depth and build normal map
81
+ depth_vis = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-8)
 
 
82
  normal_map = depth_to_normal(depth_vis)
83
 
84
+ # ===============================
85
+ # 2) Shading map (CLAHE)
86
+ # ===============================
87
  img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
88
  l_channel, _, _ = cv2.split(img_lab)
89
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
90
  l_clahe = clahe.apply(l_channel)
91
  shading_map = l_clahe / 255.0
92
+ shading_map_loaded = np.stack([shading_map] * 3, axis=-1) # (H,W,3)
93
+
94
+ # ===============================
95
+ # 3) OVERLAY alpha feather (NEW)
96
+ # ===============================
97
+ # pattern_np = np.array(pattern_image.convert("RGB")) # <-- ORIGINAL (kills alpha) [COMMENTED]
98
+ pattern_rgba_full = np.array(pattern_image.convert("RGBA")) # keep alpha
99
+ alpha = pattern_rgba_full[:, :, 3].astype(np.float32) / 255.0
100
+
101
+ # feather alpha a little to soften edges
102
+ alpha_feathered = cv2.GaussianBlur(alpha, (5, 5), sigmaX=2, sigmaY=2)
103
+ alpha_feathered = np.clip(alpha_feathered, 0.0, 1.0)
104
+
105
+ # premultiply RGB by feathered alpha
106
+ rgb = pattern_rgba_full[:, :, :3].astype(np.float32) / 255.0
107
+ rgb_pm = rgb * alpha_feathered[..., None] # premultiplied RGB in [0,1]
108
+
109
+ # Optional: crop to non-transparent bbox to avoid tiling empty margins
110
+ alpha_thresh = 0.01
111
+ ys, xs = np.where(alpha_feathered > alpha_thresh)
112
+ if ys.size > 0:
113
+ y0, y1 = ys.min(), ys.max() + 1
114
+ x0, x1 = xs.min(), xs.max() + 1
115
+ rgb_pm = rgb_pm[y0:y1, x0:x1, :]
116
+ alpha_crop = alpha_feathered[y0:y1, x0:x1]
117
+ else:
118
+ alpha_crop = alpha_feathered # degenerate case
119
+
120
+ ph, pw = alpha_crop.shape[:2]
121
+
122
+ # ===============================
123
+ # 4) Alpha-aware tiling (NEW)
124
+ # ===============================
125
  target_h, target_w = img_np.shape[:2]
126
+
127
+ # --- ORIGINAL hard RGB tiling (caused seams) [COMMENTED] ---
128
+ # pattern_h, pattern_w = pattern_np.shape[:2]
129
+ # pattern_tiled = np.zeros((target_h, target_w, 3), dtype=np.uint8)
130
+ # for y in range(0, target_h, pattern_h):
131
+ # for x in range(0, target_w, pattern_w):
132
+ # end_y = min(y + pattern_h, target_h)
133
+ # end_x = min(x + pattern_w, target_w)
134
+ # pattern_tiled[y:end_y, x:end_x] = pattern_np[0:(end_y - y), 0:(end_x - x)]
135
+
136
+ # NEW: premultiplied "over" compositing per tile
137
+ canvas_rgb_pm = np.zeros((target_h, target_w, 3), dtype=np.float32)
138
+ canvas_a = np.zeros((target_h, target_w, 1), dtype=np.float32)
139
+
140
+ tile_rgb_pm_src = rgb_pm.astype(np.float32) # (ph,pw,3), premultiplied
141
+ tile_a_src = alpha_crop.astype(np.float32)[..., None] # (ph,pw,1)
142
+
143
+ for y in range(0, target_h, ph):
144
+ for x in range(0, target_w, pw):
145
+ end_y = min(y + ph, target_h)
146
+ end_x = min(x + pw, target_w)
147
+ h = end_y - y
148
+ w = end_x - x
149
+
150
+ src_rgb_pm = tile_rgb_pm_src[:h, :w, :]
151
+ src_a = tile_a_src[:h, :w, :]
152
+
153
+ dst_rgb_pm = canvas_rgb_pm[y:end_y, x:end_x, :]
154
+ dst_a = canvas_a[y:end_y, x:end_x, :]
155
+
156
+ out_rgb_pm = src_rgb_pm + dst_rgb_pm * (1.0 - src_a)
157
+ out_a = src_a + dst_a * (1.0 - src_a)
158
+
159
+ canvas_rgb_pm[y:end_y, x:end_x, :] = out_rgb_pm
160
+ canvas_a[y:end_y, x:end_x, :] = out_a
161
+
162
+ # Un-premultiply to get display RGB; keep tiled overlay alpha
163
+ canvas_a_safe = np.clip(canvas_a, 1e-6, 1.0)
164
+ pattern_rgb_tiled = np.clip(canvas_rgb_pm / canvas_a_safe, 0.0, 1.0) # (H,W,3)
165
+ pattern_alpha_tiled = np.clip(canvas_a[..., 0], 0.0, 1.0) # (H,W)
166
+
167
+ # ===============================
168
+ # 5) Apply shading + normal boost (kept as-is)
169
+ # ===============================
170
  normal_map_loaded = normal_map.astype(np.float32)
171
+ alpha_shading = 0.7
172
+ blended_shading = alpha_shading * shading_map_loaded + (1 - alpha_shading)
173
 
174
+ # --- ORIGINAL (kept) ---
175
+ # pattern_folded = pattern_tiled.astype(np.float32) / 255.0 * blended_shading
176
+ # normal_boost = 0.5 + 0.5 * normal_map_loaded[..., 2:3]
177
+ # pattern_folded *= normal_boost
178
+ # pattern_folded = np.clip(pattern_folded, 0, 1)
179
 
180
+ pattern_folded = pattern_rgb_tiled * blended_shading
181
  normal_boost = 0.5 + 0.5 * normal_map_loaded[..., 2:3]
182
  pattern_folded *= normal_boost
183
+ pattern_folded = np.clip(pattern_folded, 0.0, 1.0)
184
 
185
+ # ===============================
186
+ # 6) Background removal for the BASE (kept, with your tuning)
187
+ # ===============================
188
  buf = BytesIO()
189
  base_image.save(buf, format="PNG")
190
  base_bytes = buf.getvalue()
191
 
 
192
  result_no_bg = bgrem_remove(base_bytes)
193
  mask_img = Image.open(BytesIO(result_no_bg)).convert("RGBA")
194
 
 
195
  mask_alpha = np.array(mask_img)[:, :, 3].astype(np.float32) / 255.0
196
 
197
+ # Slightly stronger shrink + feather (your settings)
198
  k = 5
199
+ kernel = np.ones((k, k), np.uint8)
200
+ mask_binary = (mask_alpha > k / 100.0).astype(np.uint8) * 255
201
+ mask_eroded = cv2.erode(mask_binary, kernel, iterations=3)
 
 
 
202
  mask_blurred = cv2.GaussianBlur(mask_eroded, (15, 15), sigmaX=3, sigmaY=3)
203
+ mask_blurred = mask_blurred.astype(np.float32) / 255.0 # [0,1]
204
 
205
+ # ===============================
206
+ # 7) Combine BASE mask with OVERLAY tiled alpha (NEW)
207
+ # ===============================
208
+ overlay_alpha_stack = pattern_alpha_tiled # (H,W) in [0,1]
209
+ alpha_combined = np.clip(mask_blurred * overlay_alpha_stack, 0.0, 1.0)
210
 
211
+ # Apply combined alpha to folded pattern
212
+ pattern_final_rgb = pattern_folded * alpha_combined[..., None]
213
+ pattern_rgb_u8 = (np.clip(pattern_final_rgb, 0.0, 1.0) * 255).astype(np.uint8)
214
+ alpha_u8 = (alpha_combined * 255).astype(np.uint8)
 
 
215
 
216
+ pattern_rgba = np.dstack((pattern_rgb_u8, alpha_u8))
217
  return Image.fromarray(pattern_rgba, mode="RGBA")
218
 
219
+
220
  # ===============================
221
  # WRAPPER: ACCEPT BYTES OR BASE64
222
  # ===============================