Spaces:
Sleeping
Sleeping
REVERT
Browse files
app.py
CHANGED
|
@@ -53,15 +53,10 @@ def depth_to_normal(depth):
|
|
| 53 |
# CORE PROCESSING FUNCTION
|
| 54 |
# ===============================
|
| 55 |
def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
|
| 56 |
-
# ===============================
|
| 57 |
-
# 0) Prep: base to RGB/np
|
| 58 |
-
# ===============================
|
| 59 |
img_pil = base_image.convert("RGB")
|
| 60 |
img_np = np.array(img_pil)
|
| 61 |
|
| 62 |
-
#
|
| 63 |
-
# 1) Depth inference (kept as-is)
|
| 64 |
-
# ===============================
|
| 65 |
img_resized = img_pil.resize((384, 384))
|
| 66 |
img_tensor = torch.from_numpy(np.array(img_resized)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
|
| 67 |
mean = torch.as_tensor([0.5, 0.5, 0.5], device=img_tensor.device).view(1, 3, 1, 1)
|
|
@@ -72,151 +67,83 @@ def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
|
|
| 72 |
model = SimpleDPT(backbone_name='vit_base_patch16_384').to(device)
|
| 73 |
model.eval()
|
| 74 |
|
|
|
|
| 75 |
with torch.no_grad():
|
| 76 |
target_size = img_pil.size[::-1]
|
| 77 |
depth_map = model(img_tensor.to(device), target_size=target_size)
|
| 78 |
depth_map = depth_map.squeeze().cpu().numpy()
|
| 79 |
|
| 80 |
-
# Normalize depth
|
| 81 |
-
depth_vis = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min()
|
|
|
|
|
|
|
| 82 |
normal_map = depth_to_normal(depth_vis)
|
| 83 |
|
| 84 |
-
#
|
| 85 |
-
# 2) Shading map (CLAHE)
|
| 86 |
-
# ===============================
|
| 87 |
img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
|
| 88 |
l_channel, _, _ = cv2.split(img_lab)
|
| 89 |
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
| 90 |
l_clahe = clahe.apply(l_channel)
|
| 91 |
shading_map = l_clahe / 255.0
|
| 92 |
-
shading_map_loaded = np.stack([shading_map] * 3, axis=-1) # (H,W,3)
|
| 93 |
-
|
| 94 |
-
# ===============================
|
| 95 |
-
# 3) OVERLAY alpha feather (NEW)
|
| 96 |
-
# ===============================
|
| 97 |
-
# pattern_np = np.array(pattern_image.convert("RGB")) # <-- ORIGINAL (kills alpha) [COMMENTED]
|
| 98 |
-
pattern_rgba_full = np.array(pattern_image.convert("RGBA")) # keep alpha
|
| 99 |
-
alpha = pattern_rgba_full[:, :, 3].astype(np.float32) / 255.0
|
| 100 |
-
|
| 101 |
-
# feather alpha a little to soften edges
|
| 102 |
-
alpha_feathered = cv2.GaussianBlur(alpha, (5, 5), sigmaX=2, sigmaY=2)
|
| 103 |
-
alpha_feathered = np.clip(alpha_feathered, 0.0, 1.0)
|
| 104 |
-
|
| 105 |
-
# premultiply RGB by feathered alpha
|
| 106 |
-
rgb = pattern_rgba_full[:, :, :3].astype(np.float32) / 255.0
|
| 107 |
-
rgb_pm = rgb * alpha_feathered[..., None] # premultiplied RGB in [0,1]
|
| 108 |
-
|
| 109 |
-
# Optional: crop to non-transparent bbox to avoid tiling empty margins
|
| 110 |
-
alpha_thresh = 0.01
|
| 111 |
-
ys, xs = np.where(alpha_feathered > alpha_thresh)
|
| 112 |
-
if ys.size > 0:
|
| 113 |
-
y0, y1 = ys.min(), ys.max() + 1
|
| 114 |
-
x0, x1 = xs.min(), xs.max() + 1
|
| 115 |
-
rgb_pm = rgb_pm[y0:y1, x0:x1, :]
|
| 116 |
-
alpha_crop = alpha_feathered[y0:y1, x0:x1]
|
| 117 |
-
else:
|
| 118 |
-
alpha_crop = alpha_feathered # degenerate case
|
| 119 |
-
|
| 120 |
-
ph, pw = alpha_crop.shape[:2]
|
| 121 |
-
|
| 122 |
-
# ===============================
|
| 123 |
-
# 4) Alpha-aware tiling (NEW)
|
| 124 |
-
# ===============================
|
| 125 |
-
target_h, target_w = img_np.shape[:2]
|
| 126 |
|
| 127 |
-
#
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
tile_rgb_pm_src = rgb_pm.astype(np.float32) # (ph,pw,3), premultiplied
|
| 141 |
-
tile_a_src = alpha_crop.astype(np.float32)[..., None] # (ph,pw,1)
|
| 142 |
-
|
| 143 |
-
for y in range(0, target_h, ph):
|
| 144 |
-
for x in range(0, target_w, pw):
|
| 145 |
-
end_y = min(y + ph, target_h)
|
| 146 |
-
end_x = min(x + pw, target_w)
|
| 147 |
-
h = end_y - y
|
| 148 |
-
w = end_x - x
|
| 149 |
-
|
| 150 |
-
src_rgb_pm = tile_rgb_pm_src[:h, :w, :]
|
| 151 |
-
src_a = tile_a_src[:h, :w, :]
|
| 152 |
-
|
| 153 |
-
dst_rgb_pm = canvas_rgb_pm[y:end_y, x:end_x, :]
|
| 154 |
-
dst_a = canvas_a[y:end_y, x:end_x, :]
|
| 155 |
-
|
| 156 |
-
out_rgb_pm = src_rgb_pm + dst_rgb_pm * (1.0 - src_a)
|
| 157 |
-
out_a = src_a + dst_a * (1.0 - src_a)
|
| 158 |
-
|
| 159 |
-
canvas_rgb_pm[y:end_y, x:end_x, :] = out_rgb_pm
|
| 160 |
-
canvas_a[y:end_y, x:end_x, :] = out_a
|
| 161 |
-
|
| 162 |
-
# Un-premultiply to get display RGB; keep tiled overlay alpha
|
| 163 |
-
canvas_a_safe = np.clip(canvas_a, 1e-6, 1.0)
|
| 164 |
-
pattern_rgb_tiled = np.clip(canvas_rgb_pm / canvas_a_safe, 0.0, 1.0) # (H,W,3)
|
| 165 |
-
pattern_alpha_tiled = np.clip(canvas_a[..., 0], 0.0, 1.0) # (H,W)
|
| 166 |
-
|
| 167 |
-
# ===============================
|
| 168 |
-
# 5) Apply shading + normal boost (kept as-is)
|
| 169 |
-
# ===============================
|
| 170 |
normal_map_loaded = normal_map.astype(np.float32)
|
| 171 |
-
|
| 172 |
-
blended_shading = alpha_shading * shading_map_loaded + (1 - alpha_shading)
|
| 173 |
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
# normal_boost = 0.5 + 0.5 * normal_map_loaded[..., 2:3]
|
| 177 |
-
# pattern_folded *= normal_boost
|
| 178 |
-
# pattern_folded = np.clip(pattern_folded, 0, 1)
|
| 179 |
|
| 180 |
-
pattern_folded =
|
| 181 |
normal_boost = 0.5 + 0.5 * normal_map_loaded[..., 2:3]
|
| 182 |
pattern_folded *= normal_boost
|
| 183 |
-
pattern_folded = np.clip(pattern_folded, 0
|
| 184 |
|
| 185 |
-
#
|
| 186 |
-
#
|
| 187 |
-
#
|
| 188 |
buf = BytesIO()
|
| 189 |
base_image.save(buf, format="PNG")
|
| 190 |
base_bytes = buf.getvalue()
|
| 191 |
|
|
|
|
| 192 |
result_no_bg = bgrem_remove(base_bytes)
|
| 193 |
mask_img = Image.open(BytesIO(result_no_bg)).convert("RGBA")
|
| 194 |
|
|
|
|
| 195 |
mask_alpha = np.array(mask_img)[:, :, 3].astype(np.float32) / 255.0
|
| 196 |
|
| 197 |
-
# Slightly stronger shrink
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
| 202 |
mask_blurred = cv2.GaussianBlur(mask_eroded, (15, 15), sigmaX=3, sigmaY=3)
|
| 203 |
-
mask_blurred = mask_blurred.astype(np.float32) / 255.0 # [0,1]
|
| 204 |
|
| 205 |
-
#
|
| 206 |
-
|
| 207 |
-
# ===============================
|
| 208 |
-
overlay_alpha_stack = pattern_alpha_tiled # (H,W) in [0,1]
|
| 209 |
-
alpha_combined = np.clip(mask_blurred * overlay_alpha_stack, 0.0, 1.0)
|
| 210 |
|
| 211 |
-
#
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
|
|
|
|
|
|
| 215 |
|
| 216 |
-
pattern_rgba = np.dstack((pattern_rgb_u8, alpha_u8))
|
| 217 |
return Image.fromarray(pattern_rgba, mode="RGBA")
|
| 218 |
|
| 219 |
-
|
| 220 |
# ===============================
|
| 221 |
# WRAPPER: ACCEPT BYTES OR BASE64
|
| 222 |
# ===============================
|
|
|
|
| 53 |
# CORE PROCESSING FUNCTION
|
| 54 |
# ===============================
|
| 55 |
def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
|
|
|
|
|
|
|
|
|
|
| 56 |
img_pil = base_image.convert("RGB")
|
| 57 |
img_np = np.array(img_pil)
|
| 58 |
|
| 59 |
+
# Prepare tensor
|
|
|
|
|
|
|
| 60 |
img_resized = img_pil.resize((384, 384))
|
| 61 |
img_tensor = torch.from_numpy(np.array(img_resized)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
|
| 62 |
mean = torch.as_tensor([0.5, 0.5, 0.5], device=img_tensor.device).view(1, 3, 1, 1)
|
|
|
|
| 67 |
model = SimpleDPT(backbone_name='vit_base_patch16_384').to(device)
|
| 68 |
model.eval()
|
| 69 |
|
| 70 |
+
# Depth inference
|
| 71 |
with torch.no_grad():
|
| 72 |
target_size = img_pil.size[::-1]
|
| 73 |
depth_map = model(img_tensor.to(device), target_size=target_size)
|
| 74 |
depth_map = depth_map.squeeze().cpu().numpy()
|
| 75 |
|
| 76 |
+
# Normalize depth
|
| 77 |
+
depth_vis = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
|
| 78 |
+
|
| 79 |
+
# Normal map
|
| 80 |
normal_map = depth_to_normal(depth_vis)
|
| 81 |
|
| 82 |
+
# Shading map (CLAHE)
|
|
|
|
|
|
|
| 83 |
img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
|
| 84 |
l_channel, _, _ = cv2.split(img_lab)
|
| 85 |
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
| 86 |
l_clahe = clahe.apply(l_channel)
|
| 87 |
shading_map = l_clahe / 255.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
+
# Tile pattern
|
| 90 |
+
pattern_np = np.array(pattern_image.convert("RGB"))
|
| 91 |
+
target_h, target_w = img_np.shape[:2]
|
| 92 |
+
pattern_h, pattern_w = pattern_np.shape[:2]
|
| 93 |
+
pattern_tiled = np.zeros((target_h, target_w, 3), dtype=np.uint8)
|
| 94 |
+
for y in range(0, target_h, pattern_h):
|
| 95 |
+
for x in range(0, target_w, pattern_w):
|
| 96 |
+
end_y = min(y + pattern_h, target_h)
|
| 97 |
+
end_x = min(x + pattern_w, target_w)
|
| 98 |
+
pattern_tiled[y:end_y, x:end_x] = pattern_np[0:(end_y - y), 0:(end_x - x)]
|
| 99 |
+
|
| 100 |
+
# Blend pattern
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
normal_map_loaded = normal_map.astype(np.float32)
|
| 102 |
+
shading_map_loaded = np.stack([shading_map] * 3, axis=-1)
|
|
|
|
| 103 |
|
| 104 |
+
alpha = 0.7
|
| 105 |
+
blended_shading = alpha * shading_map_loaded + (1 - alpha)
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
+
pattern_folded = pattern_tiled.astype(np.float32) / 255.0 * blended_shading
|
| 108 |
normal_boost = 0.5 + 0.5 * normal_map_loaded[..., 2:3]
|
| 109 |
pattern_folded *= normal_boost
|
| 110 |
+
pattern_folded = np.clip(pattern_folded, 0, 1)
|
| 111 |
|
| 112 |
+
# ==========================================================
|
| 113 |
+
# Background removal with post-processing (no duplicate blur)
|
| 114 |
+
# ==========================================================
|
| 115 |
buf = BytesIO()
|
| 116 |
base_image.save(buf, format="PNG")
|
| 117 |
base_bytes = buf.getvalue()
|
| 118 |
|
| 119 |
+
# Get RGBA from bgrem
|
| 120 |
result_no_bg = bgrem_remove(base_bytes)
|
| 121 |
mask_img = Image.open(BytesIO(result_no_bg)).convert("RGBA")
|
| 122 |
|
| 123 |
+
# Extract alpha and clean edges
|
| 124 |
mask_alpha = np.array(mask_img)[:, :, 3].astype(np.float32) / 255.0
|
| 125 |
|
| 126 |
+
# 1. Slightly stronger shrink (balanced)
|
| 127 |
+
kernel = np.ones((5, 5), np.uint8) # slightly larger kernel
|
| 128 |
+
mask_binary = (mask_alpha > 0.05).astype(np.uint8) * 255 # slightly stricter threshold
|
| 129 |
+
mask_eroded = cv2.erode(mask_binary, kernel, iterations=3) # balanced erosion
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# 2. Feather edges (blur)
|
| 133 |
mask_blurred = cv2.GaussianBlur(mask_eroded, (15, 15), sigmaX=3, sigmaY=3)
|
|
|
|
| 134 |
|
| 135 |
+
# 3. Normalize
|
| 136 |
+
mask_blurred = mask_blurred.astype(np.float32) / 255.0
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
+
# Final RGBA
|
| 139 |
+
mask_stack = np.stack([mask_blurred] * 3, axis=-1)
|
| 140 |
+
pattern_final = pattern_folded * mask_stack
|
| 141 |
+
pattern_rgb = (pattern_final * 255).astype(np.uint8)
|
| 142 |
+
alpha_channel = (mask_blurred * 255).astype(np.uint8)
|
| 143 |
+
pattern_rgba = np.dstack((pattern_rgb, alpha_channel))
|
| 144 |
|
|
|
|
| 145 |
return Image.fromarray(pattern_rgba, mode="RGBA")
|
| 146 |
|
|
|
|
| 147 |
# ===============================
|
| 148 |
# WRAPPER: ACCEPT BYTES OR BASE64
|
| 149 |
# ===============================
|