AI Agent commited on
Commit
488f9ce
·
1 Parent(s): a387aca

Revert to SAM 3 mask pipeline with sigmoid logits + edge-only AA upscaler, remove rembg

Browse files
Files changed (2) hide show
  1. app.py +78 -83
  2. requirements.txt +0 -2
app.py CHANGED
@@ -7,13 +7,6 @@ import os
7
  import io
8
  import fitz # PyMuPDF
9
 
10
- # Fix: HF Spaces sets OMP_NUM_THREADS to "3500m" which crashes onnxruntime/rembg
11
- omp_val = os.environ.get("OMP_NUM_THREADS", "")
12
- if not omp_val.isdigit():
13
- os.environ["OMP_NUM_THREADS"] = "4"
14
-
15
- from rembg import remove as rembg_remove, new_session as rembg_new_session
16
-
17
  # ── UNCONDITIONAL BFloat16 → Float16 Patch for T4 Turing GPUs ────
18
  # CRITICAL: torch.cuda.is_bf16_supported() returns True on T4 because CUDA
19
  # can *emulate* bfloat16 in software, but the actual kernels crash on mixed
@@ -146,37 +139,53 @@ ASSET_LIBRARY_DIR = os.path.join(tempfile.gettempdir(), "sam3_library")
146
  os.makedirs(ASSET_LIBRARY_DIR, exist_ok=True)
147
  asset_counter = 0
148
 
149
- # Initialize rembg session (downloads model once, reuses for all requests)
150
- print("Loading background removal model...", flush=True)
151
- rembg_session = rembg_new_session(model_name="isnet-general-use")
152
- print("Background removal model loaded.", flush=True)
 
 
153
 
154
- def box_iou(b1, b2):
155
- """IoU between two boxes [x0, y0, x1, y1]."""
156
- x0 = max(b1[0], b2[0])
157
- y0 = max(b1[1], b2[1])
158
- x1 = min(b1[2], b2[2])
159
- y1 = min(b1[3], b2[3])
160
- inter = max(0, x1 - x0) * max(0, y1 - y0)
161
- a1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
162
- a2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
163
- union = a1 + a2 - inter
164
- return inter / union if union > 0 else 0.0
 
 
 
 
 
 
165
 
166
  def upscale_4x(rgba: np.ndarray) -> np.ndarray:
167
- """4x Lanczos upscale with unsharp masking for crisp graphics."""
168
  h, w = rgba.shape[:2]
169
  new_w, new_h = w * 4, h * 4
170
 
171
- # Upscale full RGBA together with Lanczos4
172
  upscaled = cv2.resize(rgba, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
173
 
174
- # Unsharp mask on RGB only (preserve alpha from rembg)
175
  rgb = upscaled[:, :, :3]
176
  blurred = cv2.GaussianBlur(rgb, (0, 0), sigmaX=1.0)
177
  rgb_sharp = cv2.addWeighted(rgb, 1.5, blurred, -0.5, 0)
178
  upscaled[:, :, :3] = rgb_sharp
179
 
 
 
 
 
 
 
 
 
180
  return upscaled
181
 
182
  def extract_assets(input_image):
@@ -196,7 +205,7 @@ def extract_assets(input_image):
196
  print(f">>> Image size: {w}x{h}, area: {img_area}", flush=True)
197
  pil_img = Image.fromarray(orig_rgb)
198
 
199
- all_boxes = []
200
  all_scores = []
201
 
202
  with torch.inference_mode():
@@ -210,97 +219,83 @@ def extract_assets(input_image):
210
 
211
  masks = out["masks"]
212
  scores = out["scores"]
213
- boxes = out.get("boxes")
 
 
 
 
 
 
 
214
 
215
  if masks is None or len(masks) == 0:
216
- print(f" [{concept}] No detections", flush=True)
217
  continue
218
 
219
  if torch.is_tensor(masks): masks = masks.float().cpu().numpy()
220
  if torch.is_tensor(scores): scores = scores.float().cpu().numpy()
221
- if boxes is not None and torch.is_tensor(boxes): boxes = boxes.float().cpu().numpy()
 
222
 
223
- print(f" [{concept}] Found {len(masks)} detections, boxes: {boxes.shape if boxes is not None else 'None'}", flush=True)
224
 
225
  for j in range(len(masks)):
226
- score = float(scores[j]) if scores.ndim > 0 else float(scores)
 
 
227
 
228
- # Get bounding box: prefer SAM 3's boxes, fallback to mask bbox
229
- if boxes is not None and j < len(boxes):
230
- box = boxes[j].flatten()
231
- # SAM 3 boxes format: get x0,y0,x1,y1
232
- if len(box) >= 4:
233
- x0, y0, x1, y1 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
234
- else:
235
- continue
236
- else:
237
- # Fallback: derive box from mask
238
- m = masks[j]
239
- while m.ndim > 2: m = m[0]
240
- ys, xs = np.nonzero(m > 0.5)
241
- if len(ys) == 0: continue
242
- x0, y0, x1, y1 = int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
243
-
244
- # Validate box
245
- box_w = x1 - x0
246
- box_h = y1 - y0
247
- box_area = box_w * box_h
248
 
249
- if score < 0.1 or box_area < 500 or box_area > img_area * 0.90:
250
- print(f" det[{j}] SKIPPED: score={score:.4f}, area={box_area}", flush=True)
251
  continue
252
 
253
- # Add padding (10% of box size)
254
- pad_x = max(10, int(box_w * 0.10))
255
- pad_y = max(10, int(box_h * 0.10))
256
- bx0 = max(0, x0 - pad_x)
257
- by0 = max(0, y0 - pad_y)
258
- bx1 = min(w, x1 + pad_x)
259
- by1 = min(h, y1 + pad_y)
 
 
 
 
 
260
 
261
- all_boxes.append([bx0, by0, bx1, by1])
262
  all_scores.append(score)
263
- print(f" det[{j}] KEPT: score={score:.4f}, box=[{bx0},{by0},{bx1},{by1}]", flush=True)
264
 
265
- print(f">>> Total detections kept: {len(all_boxes)}", flush=True)
266
 
267
- if not all_boxes:
268
  gr.Info("No assets found in this image. Try a different slide with more visual elements.")
269
- print(">>> No detections passed filters, returning []", flush=True)
270
  return []
271
 
272
- # Deduplicate by box IoU
273
- order = sorted(range(len(all_boxes)), key=lambda i: all_scores[i], reverse=True)
274
  keep = []
275
  for i in order:
276
  dup = False
277
  for ki in keep:
278
- if box_iou(all_boxes[i], all_boxes[ki]) > 0.5:
279
  dup = True
280
  break
281
  if not dup:
282
  keep.append(i)
283
 
284
- # For each kept box: crop rembg upscale → save
285
  results = []
286
  global asset_counter
287
  for idx, ki in enumerate(keep):
288
- bx0, by0, bx1, by1 = all_boxes[ki]
289
- crop_rgb = orig_rgb[by0:by1, bx0:bx1]
290
-
291
- # Background removal with rembg (clean alpha matte)
292
- crop_pil = Image.fromarray(crop_rgb)
293
- rgba_pil = rembg_remove(crop_pil, session=rembg_session)
294
- rgba_np = np.array(rgba_pil)
295
- print(f" crop[{idx}] rembg done: {rgba_np.shape}", flush=True)
296
-
297
- # 4x upscale
298
- rgba_np = upscale_4x(rgba_np)
299
-
300
- # Save to library
301
  asset_counter += 1
302
  lib_path = os.path.join(ASSET_LIBRARY_DIR, f"asset_{asset_counter:04d}.png")
303
- Image.fromarray(rgba_np, "RGBA").save(lib_path, format="PNG")
304
  results.append(lib_path)
305
 
306
  print(f">>> Returning {len(results)} assets (library total: {asset_counter})", flush=True)
 
7
  import io
8
  import fitz # PyMuPDF
9
 
 
 
 
 
 
 
 
10
  # ── UNCONDITIONAL BFloat16 → Float16 Patch for T4 Turing GPUs ────
11
  # CRITICAL: torch.cuda.is_bf16_supported() returns True on T4 because CUDA
12
  # can *emulate* bfloat16 in software, but the actual kernels crash on mixed
 
139
  os.makedirs(ASSET_LIBRARY_DIR, exist_ok=True)
140
  asset_counter = 0
141
 
142
+ def mask_iou(m1: np.ndarray, m2: np.ndarray) -> float:
143
+ b1 = m1 > 128 if m1.dtype == np.uint8 else m1
144
+ b2 = m2 > 128 if m2.dtype == np.uint8 else m2
145
+ inter = np.logical_and(b1, b2).sum()
146
+ union = np.logical_or(b1, b2).sum()
147
+ return float(inter) / float(union) if union > 0 else 0.0
148
 
149
+ def mask_to_crop(orig_rgb: np.ndarray, alpha: np.ndarray) -> np.ndarray:
150
+ """Crop RGBA using alpha channel, with edge-only AA smoothing."""
151
+ h, w = orig_rgb.shape[:2]
152
+
153
+ rgba = np.zeros((h, w, 4), dtype=np.uint8)
154
+ rgba[:, :, :3] = orig_rgb
155
+ rgba[:, :, 3] = alpha
156
+
157
+ ys, xs = np.nonzero(alpha > 10)
158
+ if len(ys) == 0:
159
+ return rgba
160
+ y0, y1 = int(ys.min()), int(ys.max())
161
+ x0, x1 = int(xs.min()), int(xs.max())
162
+ pad = 6
163
+ y0, x0 = max(0, y0 - pad), max(0, x0 - pad)
164
+ y1, x1 = min(h - 1, y1 + pad), min(w - 1, x1 + pad)
165
+ return rgba[y0:y1+1, x0:x1+1]
166
 
167
  def upscale_4x(rgba: np.ndarray) -> np.ndarray:
168
+ """4x Lanczos upscale with unsharp masking + edge-only alpha AA."""
169
  h, w = rgba.shape[:2]
170
  new_w, new_h = w * 4, h * 4
171
 
172
+ # Upscale full RGBA with Lanczos4
173
  upscaled = cv2.resize(rgba, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
174
 
175
+ # Unsharp mask on RGB only
176
  rgb = upscaled[:, :, :3]
177
  blurred = cv2.GaussianBlur(rgb, (0, 0), sigmaX=1.0)
178
  rgb_sharp = cv2.addWeighted(rgb, 1.5, blurred, -0.5, 0)
179
  upscaled[:, :, :3] = rgb_sharp
180
 
181
+ # Edge-only AA on alpha: blur then re-harden interior
182
+ alpha = upscaled[:, :, 3].astype(np.float32)
183
+ alpha_blur = cv2.GaussianBlur(alpha, (5, 5), sigmaX=1.2)
184
+ # Keep interior fully opaque, only use blurred values at edges
185
+ interior = alpha > 240 # pixels that were solidly opaque
186
+ alpha_aa = np.where(interior, 255.0, alpha_blur)
187
+ upscaled[:, :, 3] = alpha_aa.clip(0, 255).astype(np.uint8)
188
+
189
  return upscaled
190
 
191
  def extract_assets(input_image):
 
205
  print(f">>> Image size: {w}x{h}, area: {img_area}", flush=True)
206
  pil_img = Image.fromarray(orig_rgb)
207
 
208
+ all_masks = []
209
  all_scores = []
210
 
211
  with torch.inference_mode():
 
219
 
220
  masks = out["masks"]
221
  scores = out["scores"]
222
+
223
+ # Check for raw logits
224
+ raw_logits = None
225
+ for logit_key in ["masks_logits", "low_res_masks", "logits", "mask_logits"]:
226
+ val = out.get(logit_key)
227
+ if val is not None and (not torch.is_tensor(val) or val.numel() > 0):
228
+ raw_logits = val
229
+ break
230
 
231
  if masks is None or len(masks) == 0:
232
+ print(f" [{concept}] No masks returned", flush=True)
233
  continue
234
 
235
  if torch.is_tensor(masks): masks = masks.float().cpu().numpy()
236
  if torch.is_tensor(scores): scores = scores.float().cpu().numpy()
237
+ if raw_logits is not None and torch.is_tensor(raw_logits):
238
+ raw_logits = raw_logits.float().cpu().numpy()
239
 
240
+ print(f" [{concept}] Found {len(masks)} masks", flush=True)
241
 
242
  for j in range(len(masks)):
243
+ m = masks[j]
244
+ while m.ndim > 2: m = m[0]
245
+ m_bool = m.astype(bool)
246
 
247
+ score = float(scores[j]) if scores.ndim > 0 else float(scores)
248
+ area = m_bool.sum()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
+ if score < 0.1 or area < 500 or area > img_area * 0.90:
251
+ print(f" mask[{j}] SKIPPED: score={score:.4f}, area={area}", flush=True)
252
  continue
253
 
254
+ # Build alpha: sigmoid logits for smooth edges + full opacity interior
255
+ if raw_logits is not None and j < len(raw_logits):
256
+ logit = raw_logits[j]
257
+ while logit.ndim > 2: logit = logit[0]
258
+ alpha_smooth = 1.0 / (1.0 + np.exp(-logit.astype(np.float32)))
259
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
260
+ dilated = cv2.dilate(m_bool.astype(np.uint8), kernel, iterations=2)
261
+ alpha_smooth = alpha_smooth * dilated
262
+ alpha_uint8 = (alpha_smooth * 255).clip(0, 255).astype(np.uint8)
263
+ alpha_mask = np.where(m_bool, np.uint8(255), alpha_uint8)
264
+ else:
265
+ alpha_mask = m_bool.astype(np.uint8) * 255
266
 
267
+ all_masks.append(alpha_mask)
268
  all_scores.append(score)
269
+ print(f" mask[{j}] KEPT: score={score:.4f}, area={area}", flush=True)
270
 
271
+ print(f">>> Total masks kept: {len(all_masks)}", flush=True)
272
 
273
+ if not all_masks:
274
  gr.Info("No assets found in this image. Try a different slide with more visual elements.")
275
+ print(">>> No masks passed filters, returning []", flush=True)
276
  return []
277
 
278
+ # Deduplicate by mask IoU
279
+ order = sorted(range(len(all_masks)), key=lambda i: all_scores[i], reverse=True)
280
  keep = []
281
  for i in order:
282
  dup = False
283
  for ki in keep:
284
+ if mask_iou(all_masks[i], all_masks[ki]) > 0.5:
285
  dup = True
286
  break
287
  if not dup:
288
  keep.append(i)
289
 
290
+ # Crop upscale (with edge AA) → save
291
  results = []
292
  global asset_counter
293
  for idx, ki in enumerate(keep):
294
+ crop = mask_to_crop(orig_rgb, all_masks[ki])
295
+ crop = upscale_4x(crop)
 
 
 
 
 
 
 
 
 
 
 
296
  asset_counter += 1
297
  lib_path = os.path.join(ASSET_LIBRARY_DIR, f"asset_{asset_counter:04d}.png")
298
+ Image.fromarray(crop, "RGBA").save(lib_path, format="PNG")
299
  results.append(lib_path)
300
 
301
  print(f">>> Returning {len(results)} assets (library total: {asset_counter})", flush=True)
requirements.txt CHANGED
@@ -11,5 +11,3 @@ gradio
11
  git+https://github.com/facebookresearch/sam3.git
12
  opencv-python-headless
13
  PyMuPDF
14
- rembg
15
- onnxruntime
 
11
  git+https://github.com/facebookresearch/sam3.git
12
  opencv-python-headless
13
  PyMuPDF