AI Agent committed on
Commit
32ac9f9
·
1 Parent(s): 956b060

Pipeline rewrite: SAM 3 boxes + rembg background removal + 4x Lanczos upscale

Browse files
Files changed (2) hide show
  1. app.py +80 -81
  2. requirements.txt +1 -0
app.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
6
  import os
7
  import io
8
  import fitz # PyMuPDF
 
9
 
10
  # ── UNCONDITIONAL BFloat16 → Float16 Patch for T4 Turing GPUs ────
11
  # CRITICAL: torch.cuda.is_bf16_supported() returns True on T4 because CUDA
@@ -139,31 +140,22 @@ ASSET_LIBRARY_DIR = os.path.join(tempfile.gettempdir(), "sam3_library")
139
  os.makedirs(ASSET_LIBRARY_DIR, exist_ok=True)
140
  asset_counter = 0
141
 
142
- def mask_iou(m1: np.ndarray, m2: np.ndarray) -> float:
143
- # Threshold uint8 alpha arrays to binary for IoU comparison
144
- b1 = m1 > 128 if m1.dtype == np.uint8 else m1
145
- b2 = m2 > 128 if m2.dtype == np.uint8 else m2
146
- inter = np.logical_and(b1, b2).sum()
147
- union = np.logical_or(b1, b2).sum()
148
- return float(inter) / float(union) if union > 0 else 0.0
149
 
150
- def mask_to_crop(orig_rgb: np.ndarray, alpha: np.ndarray) -> np.ndarray:
151
- """Crop an RGBA image using a binary alpha mask."""
152
- h, w = orig_rgb.shape[:2]
153
-
154
- rgba = np.zeros((h, w, 4), dtype=np.uint8)
155
- rgba[:, :, :3] = orig_rgb
156
- rgba[:, :, 3] = alpha
157
-
158
- ys, xs = np.nonzero(alpha > 0)
159
- if len(ys) == 0:
160
- return rgba
161
- y0, y1 = int(ys.min()), int(ys.max())
162
- x0, x1 = int(xs.min()), int(xs.max())
163
- pad = 4
164
- y0, x0 = max(0, y0 - pad), max(0, x0 - pad)
165
- y1, x1 = min(h - 1, y1 + pad), min(w - 1, x1 + pad)
166
- return rgba[y0:y1+1, x0:x1+1]
167
 
168
  def upscale_4x(rgba: np.ndarray) -> np.ndarray:
169
  """4x Lanczos upscale with unsharp masking for crisp graphics."""
@@ -173,7 +165,7 @@ def upscale_4x(rgba: np.ndarray) -> np.ndarray:
173
  # Upscale full RGBA together with Lanczos4
174
  upscaled = cv2.resize(rgba, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
175
 
176
- # Unsharp mask on RGB only (preserve smooth alpha)
177
  rgb = upscaled[:, :, :3]
178
  blurred = cv2.GaussianBlur(rgb, (0, 0), sigmaX=1.0)
179
  rgb_sharp = cv2.addWeighted(rgb, 1.5, blurred, -0.5, 0)
@@ -198,7 +190,7 @@ def extract_assets(input_image):
198
  print(f">>> Image size: {w}x{h}, area: {img_area}", flush=True)
199
  pil_img = Image.fromarray(orig_rgb)
200
 
201
- all_masks = []
202
  all_scores = []
203
 
204
  with torch.inference_mode():
@@ -210,95 +202,102 @@ def extract_assets(input_image):
210
  print(f">>> Processing concept: '{concept}'...", flush=True)
211
  out = processor.set_text_prompt(state=state, prompt=concept)
212
 
213
- # Log all output keys to find raw logits
214
- print(f" [{concept}] Output keys: {list(out.keys())}", flush=True)
215
-
216
  masks = out["masks"]
217
  scores = out["scores"]
218
-
219
- # Check for raw logits (sub-pixel smooth edges)
220
- raw_logits = None
221
- for logit_key in ["masks_logits", "low_res_masks", "logits", "mask_logits"]:
222
- val = out.get(logit_key)
223
- if val is not None and (not torch.is_tensor(val) or val.numel() > 0):
224
- raw_logits = val
225
- break
226
- if raw_logits is not None:
227
- print(f" [{concept}] RAW LOGITS found: shape={raw_logits.shape}, dtype={raw_logits.dtype}", flush=True)
228
 
229
  if masks is None or len(masks) == 0:
230
- print(f" [{concept}] No masks returned", flush=True)
231
  continue
232
-
233
  if torch.is_tensor(masks): masks = masks.float().cpu().numpy()
234
  if torch.is_tensor(scores): scores = scores.float().cpu().numpy()
235
- if raw_logits is not None and torch.is_tensor(raw_logits):
236
- raw_logits = raw_logits.float().cpu().numpy()
237
 
238
- print(f" [{concept}] MASK shape: {masks.shape}, SCORES max: {scores.max():.4f}, min: {scores.min():.4f}", flush=True)
239
 
240
  for j in range(len(masks)):
241
- m = masks[j]
242
- while m.ndim > 2: m = m[0]
243
- m_bool = m.astype(bool)
244
-
245
  score = float(scores[j]) if scores.ndim > 0 else float(scores)
246
- area = m_bool.sum()
247
 
248
- if score < 0.1 or area < 500 or area > img_area * 0.90:
249
- print(f" mask[{j}] SKIPPED: score={score:.4f}, area={area}", flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  continue
251
 
252
- # Build alpha: use logits for smooth edges if available
253
- if raw_logits is not None and j < len(raw_logits):
254
- logit = raw_logits[j]
255
- while logit.ndim > 2: logit = logit[0]
256
- alpha_smooth = 1.0 / (1.0 + np.exp(-logit.astype(np.float32)))
257
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
258
- dilated = cv2.dilate(m_bool.astype(np.uint8), kernel, iterations=2)
259
- alpha_smooth = alpha_smooth * dilated
260
- alpha_uint8 = (alpha_smooth * 255).clip(0, 255).astype(np.uint8)
261
- alpha_mask = np.where(m_bool, np.uint8(255), alpha_uint8)
262
- else:
263
- alpha_mask = m_bool.astype(np.uint8) * 255
264
 
265
- all_masks.append(alpha_mask)
266
  all_scores.append(score)
267
- print(f" mask[{j}] KEPT: score={score:.4f}, area={area}", flush=True)
268
-
269
- print(f">>> Total masks kept: {len(all_masks)}", flush=True)
270
 
271
- if not all_masks:
272
  gr.Info("No assets found in this image. Try a different slide with more visual elements.")
273
- print(">>> No masks passed filters, returning []", flush=True)
274
  return []
275
-
276
- # Deduplicate
277
- order = sorted(range(len(all_masks)), key=lambda i: all_scores[i], reverse=True)
278
  keep = []
279
  for i in order:
280
  dup = False
281
  for ki in keep:
282
- if mask_iou(all_masks[i], all_masks[ki]) > 0.5:
283
  dup = True
284
  break
285
  if not dup:
286
  keep.append(i)
287
-
288
- # Crop to PNG files and save to library
289
  results = []
290
  global asset_counter
291
  for idx, ki in enumerate(keep):
292
- crop = mask_to_crop(orig_rgb, all_masks[ki])
293
- # 4x upscale with Lanczos + sharpening
294
- crop = upscale_4x(crop)
295
- # Save to persistent library
 
 
 
 
 
 
 
 
 
296
  asset_counter += 1
297
  lib_path = os.path.join(ASSET_LIBRARY_DIR, f"asset_{asset_counter:04d}.png")
298
- Image.fromarray(crop, "RGBA").save(lib_path, format="PNG")
299
  results.append(lib_path)
300
 
301
- print(f">>> Returning {len(results)} cropped PNG assets (library total: {asset_counter})", flush=True)
302
  return results
303
 
304
  except Exception as e:
 
6
  import os
7
  import io
8
  import fitz # PyMuPDF
9
+ from rembg import remove as rembg_remove, new_session as rembg_new_session
10
 
11
  # ── UNCONDITIONAL BFloat16 → Float16 Patch for T4 Turing GPUs ────
12
  # CRITICAL: torch.cuda.is_bf16_supported() returns True on T4 because CUDA
 
140
  os.makedirs(ASSET_LIBRARY_DIR, exist_ok=True)
141
  asset_counter = 0
142
 
143
+ # Initialize rembg session (downloads model once, reuses for all requests)
144
+ print("Loading background removal model...", flush=True)
145
+ rembg_session = rembg_new_session(model_name="isnet-general-use")
146
+ print("Background removal model loaded.", flush=True)
 
 
 
147
 
148
def box_iou(b1, b2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Each box is indexed as ``[x0, y0, x1, y1]``. The overlap extent on each
    axis is clamped at zero, so boxes that do not overlap contribute no
    intersection area. When the union area is not positive the result is
    ``0.0`` rather than a division error.
    """
    # Signed overlap extents; negative means the boxes are apart on that axis.
    overlap_w = min(b1[2], b2[2]) - max(b1[0], b2[0])
    overlap_h = min(b1[3], b2[3]) - max(b1[1], b2[1])
    inter = max(0, overlap_w) * max(0, overlap_h)

    area_first = (b1[2] - b1[0]) * (b1[3] - b1[1])
    area_second = (b2[2] - b2[0]) * (b2[3] - b2[1])
    union = area_first + area_second - inter

    if union > 0:
        return inter / union
    return 0.0
 
 
 
 
 
 
159
 
160
  def upscale_4x(rgba: np.ndarray) -> np.ndarray:
161
  """4x Lanczos upscale with unsharp masking for crisp graphics."""
 
165
  # Upscale full RGBA together with Lanczos4
166
  upscaled = cv2.resize(rgba, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
167
 
168
+ # Unsharp mask on RGB only (preserve alpha from rembg)
169
  rgb = upscaled[:, :, :3]
170
  blurred = cv2.GaussianBlur(rgb, (0, 0), sigmaX=1.0)
171
  rgb_sharp = cv2.addWeighted(rgb, 1.5, blurred, -0.5, 0)
 
190
  print(f">>> Image size: {w}x{h}, area: {img_area}", flush=True)
191
  pil_img = Image.fromarray(orig_rgb)
192
 
193
+ all_boxes = []
194
  all_scores = []
195
 
196
  with torch.inference_mode():
 
202
  print(f">>> Processing concept: '{concept}'...", flush=True)
203
  out = processor.set_text_prompt(state=state, prompt=concept)
204
 
 
 
 
205
  masks = out["masks"]
206
  scores = out["scores"]
207
+ boxes = out.get("boxes")
 
 
 
 
 
 
 
 
 
208
 
209
  if masks is None or len(masks) == 0:
210
+ print(f" [{concept}] No detections", flush=True)
211
  continue
212
+
213
  if torch.is_tensor(masks): masks = masks.float().cpu().numpy()
214
  if torch.is_tensor(scores): scores = scores.float().cpu().numpy()
215
+ if boxes is not None and torch.is_tensor(boxes): boxes = boxes.float().cpu().numpy()
 
216
 
217
+ print(f" [{concept}] Found {len(masks)} detections, boxes: {boxes.shape if boxes is not None else 'None'}", flush=True)
218
 
219
  for j in range(len(masks)):
 
 
 
 
220
  score = float(scores[j]) if scores.ndim > 0 else float(scores)
 
221
 
222
+ # Get bounding box: prefer SAM 3's boxes, fallback to mask bbox
223
+ if boxes is not None and j < len(boxes):
224
+ box = boxes[j].flatten()
225
+ # SAM 3 boxes format: get x0,y0,x1,y1
226
+ if len(box) >= 4:
227
+ x0, y0, x1, y1 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
228
+ else:
229
+ continue
230
+ else:
231
+ # Fallback: derive box from mask
232
+ m = masks[j]
233
+ while m.ndim > 2: m = m[0]
234
+ ys, xs = np.nonzero(m > 0.5)
235
+ if len(ys) == 0: continue
236
+ x0, y0, x1, y1 = int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
237
+
238
+ # Validate box
239
+ box_w = x1 - x0
240
+ box_h = y1 - y0
241
+ box_area = box_w * box_h
242
+
243
+ if score < 0.1 or box_area < 500 or box_area > img_area * 0.90:
244
+ print(f" det[{j}] SKIPPED: score={score:.4f}, area={box_area}", flush=True)
245
  continue
246
 
247
+ # Add padding (10% of box size)
248
+ pad_x = max(10, int(box_w * 0.10))
249
+ pad_y = max(10, int(box_h * 0.10))
250
+ bx0 = max(0, x0 - pad_x)
251
+ by0 = max(0, y0 - pad_y)
252
+ bx1 = min(w, x1 + pad_x)
253
+ by1 = min(h, y1 + pad_y)
 
 
 
 
 
254
 
255
+ all_boxes.append([bx0, by0, bx1, by1])
256
  all_scores.append(score)
257
+ print(f" det[{j}] KEPT: score={score:.4f}, box=[{bx0},{by0},{bx1},{by1}]", flush=True)
258
+
259
+ print(f">>> Total detections kept: {len(all_boxes)}", flush=True)
260
 
261
+ if not all_boxes:
262
  gr.Info("No assets found in this image. Try a different slide with more visual elements.")
263
+ print(">>> No detections passed filters, returning []", flush=True)
264
  return []
265
+
266
+ # Deduplicate by box IoU
267
+ order = sorted(range(len(all_boxes)), key=lambda i: all_scores[i], reverse=True)
268
  keep = []
269
  for i in order:
270
  dup = False
271
  for ki in keep:
272
+ if box_iou(all_boxes[i], all_boxes[ki]) > 0.5:
273
  dup = True
274
  break
275
  if not dup:
276
  keep.append(i)
277
+
278
+ # For each kept box: crop → rembg → upscale → save
279
  results = []
280
  global asset_counter
281
  for idx, ki in enumerate(keep):
282
+ bx0, by0, bx1, by1 = all_boxes[ki]
283
+ crop_rgb = orig_rgb[by0:by1, bx0:bx1]
284
+
285
+ # Background removal with rembg (clean alpha matte)
286
+ crop_pil = Image.fromarray(crop_rgb)
287
+ rgba_pil = rembg_remove(crop_pil, session=rembg_session)
288
+ rgba_np = np.array(rgba_pil)
289
+ print(f" crop[{idx}] rembg done: {rgba_np.shape}", flush=True)
290
+
291
+ # 4x upscale
292
+ rgba_np = upscale_4x(rgba_np)
293
+
294
+ # Save to library
295
  asset_counter += 1
296
  lib_path = os.path.join(ASSET_LIBRARY_DIR, f"asset_{asset_counter:04d}.png")
297
+ Image.fromarray(rgba_np, "RGBA").save(lib_path, format="PNG")
298
  results.append(lib_path)
299
 
300
+ print(f">>> Returning {len(results)} assets (library total: {asset_counter})", flush=True)
301
  return results
302
 
303
  except Exception as e:
requirements.txt CHANGED
@@ -11,3 +11,4 @@ gradio
11
  git+https://github.com/facebookresearch/sam3.git
12
  opencv-python-headless
13
  PyMuPDF
 
 
11
  git+https://github.com/facebookresearch/sam3.git
12
  opencv-python-headless
13
  PyMuPDF
14
+ rembg