nishanth-saka commited on
Commit
142bff5
·
verified ·
1 Parent(s): 2ac33ac
Files changed (1) hide show
  1. app.py +44 -117
app.py CHANGED
@@ -53,15 +53,10 @@ def depth_to_normal(depth):
53
  # CORE PROCESSING FUNCTION
54
  # ===============================
55
  def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
56
- # ===============================
57
- # 0) Prep: base to RGB/np
58
- # ===============================
59
  img_pil = base_image.convert("RGB")
60
  img_np = np.array(img_pil)
61
 
62
- # ===============================
63
- # 1) Depth inference (kept as-is)
64
- # ===============================
65
  img_resized = img_pil.resize((384, 384))
66
  img_tensor = torch.from_numpy(np.array(img_resized)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
67
  mean = torch.as_tensor([0.5, 0.5, 0.5], device=img_tensor.device).view(1, 3, 1, 1)
@@ -72,151 +67,83 @@ def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
72
  model = SimpleDPT(backbone_name='vit_base_patch16_384').to(device)
73
  model.eval()
74
 
 
75
  with torch.no_grad():
76
  target_size = img_pil.size[::-1]
77
  depth_map = model(img_tensor.to(device), target_size=target_size)
78
  depth_map = depth_map.squeeze().cpu().numpy()
79
 
80
- # Normalize depth and build normal map
81
- depth_vis = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-8)
 
 
82
  normal_map = depth_to_normal(depth_vis)
83
 
84
- # ===============================
85
- # 2) Shading map (CLAHE)
86
- # ===============================
87
  img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
88
  l_channel, _, _ = cv2.split(img_lab)
89
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
90
  l_clahe = clahe.apply(l_channel)
91
  shading_map = l_clahe / 255.0
92
- shading_map_loaded = np.stack([shading_map] * 3, axis=-1) # (H,W,3)
93
-
94
- # ===============================
95
- # 3) OVERLAY alpha feather (NEW)
96
- # ===============================
97
- # pattern_np = np.array(pattern_image.convert("RGB")) # <-- ORIGINAL (kills alpha) [COMMENTED]
98
- pattern_rgba_full = np.array(pattern_image.convert("RGBA")) # keep alpha
99
- alpha = pattern_rgba_full[:, :, 3].astype(np.float32) / 255.0
100
-
101
- # feather alpha a little to soften edges
102
- alpha_feathered = cv2.GaussianBlur(alpha, (5, 5), sigmaX=2, sigmaY=2)
103
- alpha_feathered = np.clip(alpha_feathered, 0.0, 1.0)
104
-
105
- # premultiply RGB by feathered alpha
106
- rgb = pattern_rgba_full[:, :, :3].astype(np.float32) / 255.0
107
- rgb_pm = rgb * alpha_feathered[..., None] # premultiplied RGB in [0,1]
108
-
109
- # Optional: crop to non-transparent bbox to avoid tiling empty margins
110
- alpha_thresh = 0.01
111
- ys, xs = np.where(alpha_feathered > alpha_thresh)
112
- if ys.size > 0:
113
- y0, y1 = ys.min(), ys.max() + 1
114
- x0, x1 = xs.min(), xs.max() + 1
115
- rgb_pm = rgb_pm[y0:y1, x0:x1, :]
116
- alpha_crop = alpha_feathered[y0:y1, x0:x1]
117
- else:
118
- alpha_crop = alpha_feathered # degenerate case
119
-
120
- ph, pw = alpha_crop.shape[:2]
121
-
122
- # ===============================
123
- # 4) Alpha-aware tiling (NEW)
124
- # ===============================
125
- target_h, target_w = img_np.shape[:2]
126
 
127
- # --- ORIGINAL hard RGB tiling (caused seams) [COMMENTED] ---
128
- # pattern_h, pattern_w = pattern_np.shape[:2]
129
- # pattern_tiled = np.zeros((target_h, target_w, 3), dtype=np.uint8)
130
- # for y in range(0, target_h, pattern_h):
131
- # for x in range(0, target_w, pattern_w):
132
- # end_y = min(y + pattern_h, target_h)
133
- # end_x = min(x + pattern_w, target_w)
134
- # pattern_tiled[y:end_y, x:end_x] = pattern_np[0:(end_y - y), 0:(end_x - x)]
135
-
136
- # NEW: premultiplied "over" compositing per tile
137
- canvas_rgb_pm = np.zeros((target_h, target_w, 3), dtype=np.float32)
138
- canvas_a = np.zeros((target_h, target_w, 1), dtype=np.float32)
139
-
140
- tile_rgb_pm_src = rgb_pm.astype(np.float32) # (ph,pw,3), premultiplied
141
- tile_a_src = alpha_crop.astype(np.float32)[..., None] # (ph,pw,1)
142
-
143
- for y in range(0, target_h, ph):
144
- for x in range(0, target_w, pw):
145
- end_y = min(y + ph, target_h)
146
- end_x = min(x + pw, target_w)
147
- h = end_y - y
148
- w = end_x - x
149
-
150
- src_rgb_pm = tile_rgb_pm_src[:h, :w, :]
151
- src_a = tile_a_src[:h, :w, :]
152
-
153
- dst_rgb_pm = canvas_rgb_pm[y:end_y, x:end_x, :]
154
- dst_a = canvas_a[y:end_y, x:end_x, :]
155
-
156
- out_rgb_pm = src_rgb_pm + dst_rgb_pm * (1.0 - src_a)
157
- out_a = src_a + dst_a * (1.0 - src_a)
158
-
159
- canvas_rgb_pm[y:end_y, x:end_x, :] = out_rgb_pm
160
- canvas_a[y:end_y, x:end_x, :] = out_a
161
-
162
- # Un-premultiply to get display RGB; keep tiled overlay alpha
163
- canvas_a_safe = np.clip(canvas_a, 1e-6, 1.0)
164
- pattern_rgb_tiled = np.clip(canvas_rgb_pm / canvas_a_safe, 0.0, 1.0) # (H,W,3)
165
- pattern_alpha_tiled = np.clip(canvas_a[..., 0], 0.0, 1.0) # (H,W)
166
-
167
- # ===============================
168
- # 5) Apply shading + normal boost (kept as-is)
169
- # ===============================
170
  normal_map_loaded = normal_map.astype(np.float32)
171
- alpha_shading = 0.7
172
- blended_shading = alpha_shading * shading_map_loaded + (1 - alpha_shading)
173
 
174
- # --- ORIGINAL (kept) ---
175
- # pattern_folded = pattern_tiled.astype(np.float32) / 255.0 * blended_shading
176
- # normal_boost = 0.5 + 0.5 * normal_map_loaded[..., 2:3]
177
- # pattern_folded *= normal_boost
178
- # pattern_folded = np.clip(pattern_folded, 0, 1)
179
 
180
- pattern_folded = pattern_rgb_tiled * blended_shading
181
  normal_boost = 0.5 + 0.5 * normal_map_loaded[..., 2:3]
182
  pattern_folded *= normal_boost
183
- pattern_folded = np.clip(pattern_folded, 0.0, 1.0)
184
 
185
- # ===============================
186
- # 6) Background removal for the BASE (kept, with your tuning)
187
- # ===============================
188
  buf = BytesIO()
189
  base_image.save(buf, format="PNG")
190
  base_bytes = buf.getvalue()
191
 
 
192
  result_no_bg = bgrem_remove(base_bytes)
193
  mask_img = Image.open(BytesIO(result_no_bg)).convert("RGBA")
194
 
 
195
  mask_alpha = np.array(mask_img)[:, :, 3].astype(np.float32) / 255.0
196
 
197
- # Slightly stronger shrink + feather (your settings)
198
- k = 5
199
- kernel = np.ones((k, k), np.uint8)
200
- mask_binary = (mask_alpha > k / 100.0).astype(np.uint8) * 255
201
- mask_eroded = cv2.erode(mask_binary, kernel, iterations=3)
 
 
202
  mask_blurred = cv2.GaussianBlur(mask_eroded, (15, 15), sigmaX=3, sigmaY=3)
203
- mask_blurred = mask_blurred.astype(np.float32) / 255.0 # [0,1]
204
 
205
- # ===============================
206
- # 7) Combine BASE mask with OVERLAY tiled alpha (NEW)
207
- # ===============================
208
- overlay_alpha_stack = pattern_alpha_tiled # (H,W) in [0,1]
209
- alpha_combined = np.clip(mask_blurred * overlay_alpha_stack, 0.0, 1.0)
210
 
211
- # Apply combined alpha to folded pattern
212
- pattern_final_rgb = pattern_folded * alpha_combined[..., None]
213
- pattern_rgb_u8 = (np.clip(pattern_final_rgb, 0.0, 1.0) * 255).astype(np.uint8)
214
- alpha_u8 = (alpha_combined * 255).astype(np.uint8)
 
 
215
 
216
- pattern_rgba = np.dstack((pattern_rgb_u8, alpha_u8))
217
  return Image.fromarray(pattern_rgba, mode="RGBA")
218
 
219
-
220
  # ===============================
221
  # WRAPPER: ACCEPT BYTES OR BASE64
222
  # ===============================
 
53
  # CORE PROCESSING FUNCTION
54
  # ===============================
55
  def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
 
 
 
56
  img_pil = base_image.convert("RGB")
57
  img_np = np.array(img_pil)
58
 
59
+ # Prepare tensor
 
 
60
  img_resized = img_pil.resize((384, 384))
61
  img_tensor = torch.from_numpy(np.array(img_resized)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
62
  mean = torch.as_tensor([0.5, 0.5, 0.5], device=img_tensor.device).view(1, 3, 1, 1)
 
67
  model = SimpleDPT(backbone_name='vit_base_patch16_384').to(device)
68
  model.eval()
69
 
70
+ # Depth inference
71
  with torch.no_grad():
72
  target_size = img_pil.size[::-1]
73
  depth_map = model(img_tensor.to(device), target_size=target_size)
74
  depth_map = depth_map.squeeze().cpu().numpy()
75
 
76
+ # Normalize depth
77
+ depth_vis = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
78
+
79
+ # Normal map
80
  normal_map = depth_to_normal(depth_vis)
81
 
82
+ # Shading map (CLAHE)
 
 
83
  img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
84
  l_channel, _, _ = cv2.split(img_lab)
85
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
86
  l_clahe = clahe.apply(l_channel)
87
  shading_map = l_clahe / 255.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
+ # Tile pattern
90
+ pattern_np = np.array(pattern_image.convert("RGB"))
91
+ target_h, target_w = img_np.shape[:2]
92
+ pattern_h, pattern_w = pattern_np.shape[:2]
93
+ pattern_tiled = np.zeros((target_h, target_w, 3), dtype=np.uint8)
94
+ for y in range(0, target_h, pattern_h):
95
+ for x in range(0, target_w, pattern_w):
96
+ end_y = min(y + pattern_h, target_h)
97
+ end_x = min(x + pattern_w, target_w)
98
+ pattern_tiled[y:end_y, x:end_x] = pattern_np[0:(end_y - y), 0:(end_x - x)]
99
+
100
+ # Blend pattern
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  normal_map_loaded = normal_map.astype(np.float32)
102
+ shading_map_loaded = np.stack([shading_map] * 3, axis=-1)
 
103
 
104
+ alpha = 0.7
105
+ blended_shading = alpha * shading_map_loaded + (1 - alpha)
 
 
 
106
 
107
+ pattern_folded = pattern_tiled.astype(np.float32) / 255.0 * blended_shading
108
  normal_boost = 0.5 + 0.5 * normal_map_loaded[..., 2:3]
109
  pattern_folded *= normal_boost
110
+ pattern_folded = np.clip(pattern_folded, 0, 1)
111
 
112
+ # ==========================================================
113
+ # Background removal with post-processing (no duplicate blur)
114
+ # ==========================================================
115
  buf = BytesIO()
116
  base_image.save(buf, format="PNG")
117
  base_bytes = buf.getvalue()
118
 
119
+ # Get RGBA from bgrem
120
  result_no_bg = bgrem_remove(base_bytes)
121
  mask_img = Image.open(BytesIO(result_no_bg)).convert("RGBA")
122
 
123
+ # Extract alpha and clean edges
124
  mask_alpha = np.array(mask_img)[:, :, 3].astype(np.float32) / 255.0
125
 
126
+ # 1. Slightly stronger shrink (balanced)
127
+ kernel = np.ones((5, 5), np.uint8) # slightly larger kernel
128
+ mask_binary = (mask_alpha > 0.05).astype(np.uint8) * 255 # slightly stricter threshold
129
+ mask_eroded = cv2.erode(mask_binary, kernel, iterations=3) # balanced erosion
130
+
131
+
132
+ # 2. Feather edges (blur)
133
  mask_blurred = cv2.GaussianBlur(mask_eroded, (15, 15), sigmaX=3, sigmaY=3)
 
134
 
135
+ # 3. Normalize
136
+ mask_blurred = mask_blurred.astype(np.float32) / 255.0
 
 
 
137
 
138
+ # Final RGBA
139
+ mask_stack = np.stack([mask_blurred] * 3, axis=-1)
140
+ pattern_final = pattern_folded * mask_stack
141
+ pattern_rgb = (pattern_final * 255).astype(np.uint8)
142
+ alpha_channel = (mask_blurred * 255).astype(np.uint8)
143
+ pattern_rgba = np.dstack((pattern_rgb, alpha_channel))
144
 
 
145
  return Image.fromarray(pattern_rgba, mode="RGBA")
146
 
 
147
  # ===============================
148
  # WRAPPER: ACCEPT BYTES OR BASE64
149
  # ===============================