Spaces:
Sleeping
Sleeping
AI Agent committed on
Commit ·
488f9ce
1
Parent(s): a387aca
Revert to SAM 3 mask pipeline with sigmoid logits + edge-only AA upscaler, remove rembg
Browse files- app.py +78 -83
- requirements.txt +0 -2
app.py
CHANGED
|
@@ -7,13 +7,6 @@ import os
|
|
| 7 |
import io
|
| 8 |
import fitz # PyMuPDF
|
| 9 |
|
| 10 |
-
# Fix: HF Spaces sets OMP_NUM_THREADS to "3500m" which crashes onnxruntime/rembg
|
| 11 |
-
omp_val = os.environ.get("OMP_NUM_THREADS", "")
|
| 12 |
-
if not omp_val.isdigit():
|
| 13 |
-
os.environ["OMP_NUM_THREADS"] = "4"
|
| 14 |
-
|
| 15 |
-
from rembg import remove as rembg_remove, new_session as rembg_new_session
|
| 16 |
-
|
| 17 |
# ── UNCONDITIONAL BFloat16 → Float16 Patch for T4 Turing GPUs ────
|
| 18 |
# CRITICAL: torch.cuda.is_bf16_supported() returns True on T4 because CUDA
|
| 19 |
# can *emulate* bfloat16 in software, but the actual kernels crash on mixed
|
|
@@ -146,37 +139,53 @@ ASSET_LIBRARY_DIR = os.path.join(tempfile.gettempdir(), "sam3_library")
|
|
| 146 |
os.makedirs(ASSET_LIBRARY_DIR, exist_ok=True)
|
| 147 |
asset_counter = 0
|
| 148 |
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
|
|
|
|
|
|
| 153 |
|
| 154 |
-
def
|
| 155 |
-
"""
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
def upscale_4x(rgba: np.ndarray) -> np.ndarray:
    """Enlarge an image 4x with Lanczos resampling and sharpen its colour.

    Every channel is resized with ``cv2.INTER_LANCZOS4``; the first three
    channels then receive an unsharp mask (1.5 * image - 0.5 * blurred).
    Any remaining channel (e.g. alpha) is left exactly as resized.

    Args:
        rgba: Image array whose first two axes are height and width.

    Returns:
        The resized array, four times larger along both spatial axes.
    """
    src_h, src_w = rgba.shape[:2]
    target_size = (src_w * 4, src_h * 4)  # cv2.resize takes (width, height)

    out = cv2.resize(rgba, target_size, interpolation=cv2.INTER_LANCZOS4)

    # Unsharp mask: boost the image and subtract a Gaussian-blurred copy.
    colour = out[:, :, :3]
    soft = cv2.GaussianBlur(colour, (0, 0), sigmaX=1.0)
    out[:, :, :3] = cv2.addWeighted(colour, 1.5, soft, -0.5, 0)

    return out
|
| 181 |
|
| 182 |
def extract_assets(input_image):
|
|
@@ -196,7 +205,7 @@ def extract_assets(input_image):
|
|
| 196 |
print(f">>> Image size: {w}x{h}, area: {img_area}", flush=True)
|
| 197 |
pil_img = Image.fromarray(orig_rgb)
|
| 198 |
|
| 199 |
-
|
| 200 |
all_scores = []
|
| 201 |
|
| 202 |
with torch.inference_mode():
|
|
@@ -210,97 +219,83 @@ def extract_assets(input_image):
|
|
| 210 |
|
| 211 |
masks = out["masks"]
|
| 212 |
scores = out["scores"]
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
if masks is None or len(masks) == 0:
|
| 216 |
-
print(f" [{concept}] No
|
| 217 |
continue
|
| 218 |
|
| 219 |
if torch.is_tensor(masks): masks = masks.float().cpu().numpy()
|
| 220 |
if torch.is_tensor(scores): scores = scores.float().cpu().numpy()
|
| 221 |
-
if
|
|
|
|
| 222 |
|
| 223 |
-
print(f" [{concept}] Found {len(masks)}
|
| 224 |
|
| 225 |
for j in range(len(masks)):
|
| 226 |
-
|
|
|
|
|
|
|
| 227 |
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
box = boxes[j].flatten()
|
| 231 |
-
# SAM 3 boxes format: get x0,y0,x1,y1
|
| 232 |
-
if len(box) >= 4:
|
| 233 |
-
x0, y0, x1, y1 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
|
| 234 |
-
else:
|
| 235 |
-
continue
|
| 236 |
-
else:
|
| 237 |
-
# Fallback: derive box from mask
|
| 238 |
-
m = masks[j]
|
| 239 |
-
while m.ndim > 2: m = m[0]
|
| 240 |
-
ys, xs = np.nonzero(m > 0.5)
|
| 241 |
-
if len(ys) == 0: continue
|
| 242 |
-
x0, y0, x1, y1 = int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
|
| 243 |
-
|
| 244 |
-
# Validate box
|
| 245 |
-
box_w = x1 - x0
|
| 246 |
-
box_h = y1 - y0
|
| 247 |
-
box_area = box_w * box_h
|
| 248 |
|
| 249 |
-
if score < 0.1 or
|
| 250 |
-
print(f"
|
| 251 |
continue
|
| 252 |
|
| 253 |
-
#
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
|
| 261 |
-
|
| 262 |
all_scores.append(score)
|
| 263 |
-
print(f"
|
| 264 |
|
| 265 |
-
print(f">>> Total
|
| 266 |
|
| 267 |
-
if not
|
| 268 |
gr.Info("No assets found in this image. Try a different slide with more visual elements.")
|
| 269 |
-
print(">>> No
|
| 270 |
return []
|
| 271 |
|
| 272 |
-
# Deduplicate by
|
| 273 |
-
order = sorted(range(len(
|
| 274 |
keep = []
|
| 275 |
for i in order:
|
| 276 |
dup = False
|
| 277 |
for ki in keep:
|
| 278 |
-
if
|
| 279 |
dup = True
|
| 280 |
break
|
| 281 |
if not dup:
|
| 282 |
keep.append(i)
|
| 283 |
|
| 284 |
-
#
|
| 285 |
results = []
|
| 286 |
global asset_counter
|
| 287 |
for idx, ki in enumerate(keep):
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
# Background removal with rembg (clean alpha matte)
|
| 292 |
-
crop_pil = Image.fromarray(crop_rgb)
|
| 293 |
-
rgba_pil = rembg_remove(crop_pil, session=rembg_session)
|
| 294 |
-
rgba_np = np.array(rgba_pil)
|
| 295 |
-
print(f" crop[{idx}] rembg done: {rgba_np.shape}", flush=True)
|
| 296 |
-
|
| 297 |
-
# 4x upscale
|
| 298 |
-
rgba_np = upscale_4x(rgba_np)
|
| 299 |
-
|
| 300 |
-
# Save to library
|
| 301 |
asset_counter += 1
|
| 302 |
lib_path = os.path.join(ASSET_LIBRARY_DIR, f"asset_{asset_counter:04d}.png")
|
| 303 |
-
Image.fromarray(
|
| 304 |
results.append(lib_path)
|
| 305 |
|
| 306 |
print(f">>> Returning {len(results)} assets (library total: {asset_counter})", flush=True)
|
|
|
|
| 7 |
import io
|
| 8 |
import fitz # PyMuPDF
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
# ── UNCONDITIONAL BFloat16 → Float16 Patch for T4 Turing GPUs ────
|
| 11 |
# CRITICAL: torch.cuda.is_bf16_supported() returns True on T4 because CUDA
|
| 12 |
# can *emulate* bfloat16 in software, but the actual kernels crash on mixed
|
|
|
|
| 139 |
os.makedirs(ASSET_LIBRARY_DIR, exist_ok=True)
|
| 140 |
asset_counter = 0
|
| 141 |
|
| 142 |
+
def mask_iou(m1: np.ndarray, m2: np.ndarray) -> float:
    """Return the intersection-over-union of two masks.

    ``uint8`` masks are binarised with ``> 128``; masks of any other dtype
    are handed to the logical operators unchanged, so every non-zero
    element counts as set. An empty union yields ``0.0``.
    """
    def _binarise(mask):
        # Only 8-bit masks carry graded alpha values that need a threshold.
        if mask.dtype == np.uint8:
            return mask > 128
        return mask

    first, second = _binarise(m1), _binarise(m2)

    union_px = np.logical_or(first, second).sum()
    if union_px <= 0:
        return 0.0

    overlap_px = np.logical_and(first, second).sum()
    return float(overlap_px) / float(union_px)
|
| 148 |
|
| 149 |
+
def mask_to_crop(
    orig_rgb: np.ndarray,
    alpha: np.ndarray,
    pad: int = 6,
    alpha_threshold: int = 10,
) -> np.ndarray:
    """Cut a padded RGBA crop around the visible region of an alpha mask.

    The bounding box of every pixel whose alpha exceeds ``alpha_threshold``
    is grown by ``pad`` pixels on each side, clamped to the image, and
    returned as an RGBA array. No smoothing or anti-aliasing is applied
    here; the alpha values are copied through unchanged.

    Args:
        orig_rgb: Source image of shape (H, W, C) with C >= 3. Only the
            first three channels are used, so an RGBA input is accepted.
        alpha: Mask of shape (H, W); values are stored as ``uint8``.
        pad: Margin in pixels added around the bounding box.
        alpha_threshold: Pixels with alpha strictly above this value define
            the bounding box.

    Returns:
        A ``uint8`` array of shape (h, w, 4). When no pixel exceeds the
        threshold the full frame is returned with ``alpha`` as its fourth
        channel.

    Raises:
        ValueError: If ``alpha`` does not have the same height and width as
            ``orig_rgb``.
    """
    h, w = orig_rgb.shape[:2]
    if alpha.shape != (h, w):
        # Fail loudly instead of letting NumPy broadcast a wrongly shaped mask.
        raise ValueError(
            f"alpha shape {alpha.shape} does not match image size {(h, w)}"
        )

    ys, xs = np.nonzero(alpha > alpha_threshold)
    if ys.size == 0:
        # Nothing visible: fall back to the whole frame.
        y0, x0, y1, x1 = 0, 0, h - 1, w - 1
    else:
        y0 = max(0, int(ys.min()) - pad)
        x0 = max(0, int(xs.min()) - pad)
        y1 = min(h - 1, int(ys.max()) + pad)
        x1 = min(w - 1, int(xs.max()) + pad)

    # Allocate only the crop rather than a full-frame RGBA buffer.
    crop = np.empty((y1 - y0 + 1, x1 - x0 + 1, 4), dtype=np.uint8)
    crop[:, :, :3] = orig_rgb[y0:y1 + 1, x0:x1 + 1, :3]
    crop[:, :, 3] = alpha[y0:y1 + 1, x0:x1 + 1]
    return crop
|
| 166 |
|
| 167 |
def upscale_4x(rgba: np.ndarray) -> np.ndarray:
    """Enlarge an RGBA image 4x, sharpen its colour and feather alpha edges.

    Steps:
      1. Resize all channels with ``cv2.INTER_LANCZOS4``.
      2. Unsharp-mask the first three channels (1.5 * image - 0.5 * blurred).
      3. Gaussian-blur the alpha channel, but force every pixel whose
         resized alpha was above 240 back to 255, so only the edges keep
         the blurred values.

    Args:
        rgba: Image array with at least four channels; the fourth is
            treated as alpha.

    Returns:
        The resized array, four times larger along both spatial axes.
    """
    src_h, src_w = rgba.shape[:2]
    target_size = (src_w * 4, src_h * 4)  # cv2.resize takes (width, height)

    out = cv2.resize(rgba, target_size, interpolation=cv2.INTER_LANCZOS4)

    # Unsharp mask on the colour channels only.
    colour = out[:, :, :3]
    soft = cv2.GaussianBlur(colour, (0, 0), sigmaX=1.0)
    out[:, :, :3] = cv2.addWeighted(colour, 1.5, soft, -0.5, 0)

    # Feather the alpha channel, then restore the solid interior.
    alpha_f = out[:, :, 3].astype(np.float32)
    feathered = cv2.GaussianBlur(alpha_f, (5, 5), sigmaX=1.2)
    solid = alpha_f > 240  # decided on the pre-blur alpha
    feathered[solid] = 255.0
    out[:, :, 3] = np.clip(feathered, 0, 255).astype(np.uint8)

    return out
|
| 190 |
|
| 191 |
def extract_assets(input_image):
|
|
|
|
| 205 |
print(f">>> Image size: {w}x{h}, area: {img_area}", flush=True)
|
| 206 |
pil_img = Image.fromarray(orig_rgb)
|
| 207 |
|
| 208 |
+
all_masks = []
|
| 209 |
all_scores = []
|
| 210 |
|
| 211 |
with torch.inference_mode():
|
|
|
|
| 219 |
|
| 220 |
masks = out["masks"]
|
| 221 |
scores = out["scores"]
|
| 222 |
+
|
| 223 |
+
# Check for raw logits
|
| 224 |
+
raw_logits = None
|
| 225 |
+
for logit_key in ["masks_logits", "low_res_masks", "logits", "mask_logits"]:
|
| 226 |
+
val = out.get(logit_key)
|
| 227 |
+
if val is not None and (not torch.is_tensor(val) or val.numel() > 0):
|
| 228 |
+
raw_logits = val
|
| 229 |
+
break
|
| 230 |
|
| 231 |
if masks is None or len(masks) == 0:
|
| 232 |
+
print(f" [{concept}] No masks returned", flush=True)
|
| 233 |
continue
|
| 234 |
|
| 235 |
if torch.is_tensor(masks): masks = masks.float().cpu().numpy()
|
| 236 |
if torch.is_tensor(scores): scores = scores.float().cpu().numpy()
|
| 237 |
+
if raw_logits is not None and torch.is_tensor(raw_logits):
|
| 238 |
+
raw_logits = raw_logits.float().cpu().numpy()
|
| 239 |
|
| 240 |
+
print(f" [{concept}] Found {len(masks)} masks", flush=True)
|
| 241 |
|
| 242 |
for j in range(len(masks)):
|
| 243 |
+
m = masks[j]
|
| 244 |
+
while m.ndim > 2: m = m[0]
|
| 245 |
+
m_bool = m.astype(bool)
|
| 246 |
|
| 247 |
+
score = float(scores[j]) if scores.ndim > 0 else float(scores)
|
| 248 |
+
area = m_bool.sum()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
|
| 250 |
+
if score < 0.1 or area < 500 or area > img_area * 0.90:
|
| 251 |
+
print(f" mask[{j}] SKIPPED: score={score:.4f}, area={area}", flush=True)
|
| 252 |
continue
|
| 253 |
|
| 254 |
+
# Build alpha: sigmoid logits for smooth edges + full opacity interior
|
| 255 |
+
if raw_logits is not None and j < len(raw_logits):
|
| 256 |
+
logit = raw_logits[j]
|
| 257 |
+
while logit.ndim > 2: logit = logit[0]
|
| 258 |
+
alpha_smooth = 1.0 / (1.0 + np.exp(-logit.astype(np.float32)))
|
| 259 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
|
| 260 |
+
dilated = cv2.dilate(m_bool.astype(np.uint8), kernel, iterations=2)
|
| 261 |
+
alpha_smooth = alpha_smooth * dilated
|
| 262 |
+
alpha_uint8 = (alpha_smooth * 255).clip(0, 255).astype(np.uint8)
|
| 263 |
+
alpha_mask = np.where(m_bool, np.uint8(255), alpha_uint8)
|
| 264 |
+
else:
|
| 265 |
+
alpha_mask = m_bool.astype(np.uint8) * 255
|
| 266 |
|
| 267 |
+
all_masks.append(alpha_mask)
|
| 268 |
all_scores.append(score)
|
| 269 |
+
print(f" mask[{j}] KEPT: score={score:.4f}, area={area}", flush=True)
|
| 270 |
|
| 271 |
+
print(f">>> Total masks kept: {len(all_masks)}", flush=True)
|
| 272 |
|
| 273 |
+
if not all_masks:
|
| 274 |
gr.Info("No assets found in this image. Try a different slide with more visual elements.")
|
| 275 |
+
print(">>> No masks passed filters, returning []", flush=True)
|
| 276 |
return []
|
| 277 |
|
| 278 |
+
# Deduplicate by mask IoU
|
| 279 |
+
order = sorted(range(len(all_masks)), key=lambda i: all_scores[i], reverse=True)
|
| 280 |
keep = []
|
| 281 |
for i in order:
|
| 282 |
dup = False
|
| 283 |
for ki in keep:
|
| 284 |
+
if mask_iou(all_masks[i], all_masks[ki]) > 0.5:
|
| 285 |
dup = True
|
| 286 |
break
|
| 287 |
if not dup:
|
| 288 |
keep.append(i)
|
| 289 |
|
| 290 |
+
# Crop → upscale (with edge AA) → save
|
| 291 |
results = []
|
| 292 |
global asset_counter
|
| 293 |
for idx, ki in enumerate(keep):
|
| 294 |
+
crop = mask_to_crop(orig_rgb, all_masks[ki])
|
| 295 |
+
crop = upscale_4x(crop)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
asset_counter += 1
|
| 297 |
lib_path = os.path.join(ASSET_LIBRARY_DIR, f"asset_{asset_counter:04d}.png")
|
| 298 |
+
Image.fromarray(crop, "RGBA").save(lib_path, format="PNG")
|
| 299 |
results.append(lib_path)
|
| 300 |
|
| 301 |
print(f">>> Returning {len(results)} assets (library total: {asset_counter})", flush=True)
|
requirements.txt
CHANGED
|
@@ -11,5 +11,3 @@ gradio
|
|
| 11 |
git+https://github.com/facebookresearch/sam3.git
|
| 12 |
opencv-python-headless
|
| 13 |
PyMuPDF
|
| 14 |
-
rembg
|
| 15 |
-
onnxruntime
|
|
|
|
| 11 |
git+https://github.com/facebookresearch/sam3.git
|
| 12 |
opencv-python-headless
|
| 13 |
PyMuPDF
|
|
|
|
|
|