Spaces:
Sleeping
Sleeping
AI Agent commited on
Commit ·
32ac9f9
1
Parent(s): 956b060
Pipeline rewrite: SAM 3 boxes + rembg background removal + 4x Lanczos upscale
Browse files- app.py +80 -81
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -6,6 +6,7 @@ from PIL import Image
|
|
| 6 |
import os
|
| 7 |
import io
|
| 8 |
import fitz # PyMuPDF
|
|
|
|
| 9 |
|
| 10 |
# ── UNCONDITIONAL BFloat16 → Float16 Patch for T4 Turing GPUs ────
|
| 11 |
# CRITICAL: torch.cuda.is_bf16_supported() returns True on T4 because CUDA
|
|
@@ -139,31 +140,22 @@ ASSET_LIBRARY_DIR = os.path.join(tempfile.gettempdir(), "sam3_library")
|
|
| 139 |
os.makedirs(ASSET_LIBRARY_DIR, exist_ok=True)
|
| 140 |
asset_counter = 0
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
inter = np.logical_and(b1, b2).sum()
|
| 147 |
-
union = np.logical_or(b1, b2).sum()
|
| 148 |
-
return float(inter) / float(union) if union > 0 else 0.0
|
| 149 |
|
| 150 |
-
def
|
| 151 |
-
"""
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
y0, y1 = int(ys.min()), int(ys.max())
|
| 162 |
-
x0, x1 = int(xs.min()), int(xs.max())
|
| 163 |
-
pad = 4
|
| 164 |
-
y0, x0 = max(0, y0 - pad), max(0, x0 - pad)
|
| 165 |
-
y1, x1 = min(h - 1, y1 + pad), min(w - 1, x1 + pad)
|
| 166 |
-
return rgba[y0:y1+1, x0:x1+1]
|
| 167 |
|
| 168 |
def upscale_4x(rgba: np.ndarray) -> np.ndarray:
|
| 169 |
"""4x Lanczos upscale with unsharp masking for crisp graphics."""
|
|
@@ -173,7 +165,7 @@ def upscale_4x(rgba: np.ndarray) -> np.ndarray:
|
|
| 173 |
# Upscale full RGBA together with Lanczos4
|
| 174 |
upscaled = cv2.resize(rgba, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
|
| 175 |
|
| 176 |
-
# Unsharp mask on RGB only (preserve
|
| 177 |
rgb = upscaled[:, :, :3]
|
| 178 |
blurred = cv2.GaussianBlur(rgb, (0, 0), sigmaX=1.0)
|
| 179 |
rgb_sharp = cv2.addWeighted(rgb, 1.5, blurred, -0.5, 0)
|
|
@@ -198,7 +190,7 @@ def extract_assets(input_image):
|
|
| 198 |
print(f">>> Image size: {w}x{h}, area: {img_area}", flush=True)
|
| 199 |
pil_img = Image.fromarray(orig_rgb)
|
| 200 |
|
| 201 |
-
|
| 202 |
all_scores = []
|
| 203 |
|
| 204 |
with torch.inference_mode():
|
|
@@ -210,95 +202,102 @@ def extract_assets(input_image):
|
|
| 210 |
print(f">>> Processing concept: '{concept}'...", flush=True)
|
| 211 |
out = processor.set_text_prompt(state=state, prompt=concept)
|
| 212 |
|
| 213 |
-
# Log all output keys to find raw logits
|
| 214 |
-
print(f" [{concept}] Output keys: {list(out.keys())}", flush=True)
|
| 215 |
-
|
| 216 |
masks = out["masks"]
|
| 217 |
scores = out["scores"]
|
| 218 |
-
|
| 219 |
-
# Check for raw logits (sub-pixel smooth edges)
|
| 220 |
-
raw_logits = None
|
| 221 |
-
for logit_key in ["masks_logits", "low_res_masks", "logits", "mask_logits"]:
|
| 222 |
-
val = out.get(logit_key)
|
| 223 |
-
if val is not None and (not torch.is_tensor(val) or val.numel() > 0):
|
| 224 |
-
raw_logits = val
|
| 225 |
-
break
|
| 226 |
-
if raw_logits is not None:
|
| 227 |
-
print(f" [{concept}] RAW LOGITS found: shape={raw_logits.shape}, dtype={raw_logits.dtype}", flush=True)
|
| 228 |
|
| 229 |
if masks is None or len(masks) == 0:
|
| 230 |
-
print(f" [{concept}] No
|
| 231 |
continue
|
| 232 |
-
|
| 233 |
if torch.is_tensor(masks): masks = masks.float().cpu().numpy()
|
| 234 |
if torch.is_tensor(scores): scores = scores.float().cpu().numpy()
|
| 235 |
-
if
|
| 236 |
-
raw_logits = raw_logits.float().cpu().numpy()
|
| 237 |
|
| 238 |
-
print(f" [{concept}]
|
| 239 |
|
| 240 |
for j in range(len(masks)):
|
| 241 |
-
m = masks[j]
|
| 242 |
-
while m.ndim > 2: m = m[0]
|
| 243 |
-
m_bool = m.astype(bool)
|
| 244 |
-
|
| 245 |
score = float(scores[j]) if scores.ndim > 0 else float(scores)
|
| 246 |
-
area = m_bool.sum()
|
| 247 |
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
continue
|
| 251 |
|
| 252 |
-
#
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
alpha_smooth = alpha_smooth * dilated
|
| 260 |
-
alpha_uint8 = (alpha_smooth * 255).clip(0, 255).astype(np.uint8)
|
| 261 |
-
alpha_mask = np.where(m_bool, np.uint8(255), alpha_uint8)
|
| 262 |
-
else:
|
| 263 |
-
alpha_mask = m_bool.astype(np.uint8) * 255
|
| 264 |
|
| 265 |
-
|
| 266 |
all_scores.append(score)
|
| 267 |
-
print(f"
|
| 268 |
-
|
| 269 |
-
print(f">>> Total
|
| 270 |
|
| 271 |
-
if not
|
| 272 |
gr.Info("No assets found in this image. Try a different slide with more visual elements.")
|
| 273 |
-
print(">>> No
|
| 274 |
return []
|
| 275 |
-
|
| 276 |
-
# Deduplicate
|
| 277 |
-
order = sorted(range(len(
|
| 278 |
keep = []
|
| 279 |
for i in order:
|
| 280 |
dup = False
|
| 281 |
for ki in keep:
|
| 282 |
-
if
|
| 283 |
dup = True
|
| 284 |
break
|
| 285 |
if not dup:
|
| 286 |
keep.append(i)
|
| 287 |
-
|
| 288 |
-
#
|
| 289 |
results = []
|
| 290 |
global asset_counter
|
| 291 |
for idx, ki in enumerate(keep):
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
asset_counter += 1
|
| 297 |
lib_path = os.path.join(ASSET_LIBRARY_DIR, f"asset_{asset_counter:04d}.png")
|
| 298 |
-
Image.fromarray(
|
| 299 |
results.append(lib_path)
|
| 300 |
|
| 301 |
-
print(f">>> Returning {len(results)}
|
| 302 |
return results
|
| 303 |
|
| 304 |
except Exception as e:
|
|
|
|
| 6 |
import os
|
| 7 |
import io
|
| 8 |
import fitz # PyMuPDF
|
| 9 |
+
from rembg import remove as rembg_remove, new_session as rembg_new_session
|
| 10 |
|
| 11 |
# ── UNCONDITIONAL BFloat16 → Float16 Patch for T4 Turing GPUs ────
|
| 12 |
# CRITICAL: torch.cuda.is_bf16_supported() returns True on T4 because CUDA
|
|
|
|
| 140 |
os.makedirs(ASSET_LIBRARY_DIR, exist_ok=True)
|
| 141 |
asset_counter = 0
|
| 142 |
|
| 143 |
+
# Initialize rembg session (downloads model once, reuses for all requests)
|
| 144 |
+
print("Loading background removal model...", flush=True)
|
| 145 |
+
rembg_session = rembg_new_session(model_name="isnet-general-use")
|
| 146 |
+
print("Background removal model loaded.", flush=True)
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
+
def box_iou(b1, b2):
|
| 149 |
+
"""IoU between two boxes [x0, y0, x1, y1]."""
|
| 150 |
+
x0 = max(b1[0], b2[0])
|
| 151 |
+
y0 = max(b1[1], b2[1])
|
| 152 |
+
x1 = min(b1[2], b2[2])
|
| 153 |
+
y1 = min(b1[3], b2[3])
|
| 154 |
+
inter = max(0, x1 - x0) * max(0, y1 - y0)
|
| 155 |
+
a1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
|
| 156 |
+
a2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
|
| 157 |
+
union = a1 + a2 - inter
|
| 158 |
+
return inter / union if union > 0 else 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
def upscale_4x(rgba: np.ndarray) -> np.ndarray:
|
| 161 |
"""4x Lanczos upscale with unsharp masking for crisp graphics."""
|
|
|
|
| 165 |
# Upscale full RGBA together with Lanczos4
|
| 166 |
upscaled = cv2.resize(rgba, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
|
| 167 |
|
| 168 |
+
# Unsharp mask on RGB only (preserve alpha from rembg)
|
| 169 |
rgb = upscaled[:, :, :3]
|
| 170 |
blurred = cv2.GaussianBlur(rgb, (0, 0), sigmaX=1.0)
|
| 171 |
rgb_sharp = cv2.addWeighted(rgb, 1.5, blurred, -0.5, 0)
|
|
|
|
| 190 |
print(f">>> Image size: {w}x{h}, area: {img_area}", flush=True)
|
| 191 |
pil_img = Image.fromarray(orig_rgb)
|
| 192 |
|
| 193 |
+
all_boxes = []
|
| 194 |
all_scores = []
|
| 195 |
|
| 196 |
with torch.inference_mode():
|
|
|
|
| 202 |
print(f">>> Processing concept: '{concept}'...", flush=True)
|
| 203 |
out = processor.set_text_prompt(state=state, prompt=concept)
|
| 204 |
|
|
|
|
|
|
|
|
|
|
| 205 |
masks = out["masks"]
|
| 206 |
scores = out["scores"]
|
| 207 |
+
boxes = out.get("boxes")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
if masks is None or len(masks) == 0:
|
| 210 |
+
print(f" [{concept}] No detections", flush=True)
|
| 211 |
continue
|
| 212 |
+
|
| 213 |
if torch.is_tensor(masks): masks = masks.float().cpu().numpy()
|
| 214 |
if torch.is_tensor(scores): scores = scores.float().cpu().numpy()
|
| 215 |
+
if boxes is not None and torch.is_tensor(boxes): boxes = boxes.float().cpu().numpy()
|
|
|
|
| 216 |
|
| 217 |
+
print(f" [{concept}] Found {len(masks)} detections, boxes: {boxes.shape if boxes is not None else 'None'}", flush=True)
|
| 218 |
|
| 219 |
for j in range(len(masks)):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
score = float(scores[j]) if scores.ndim > 0 else float(scores)
|
|
|
|
| 221 |
|
| 222 |
+
# Get bounding box: prefer SAM 3's boxes, fallback to mask bbox
|
| 223 |
+
if boxes is not None and j < len(boxes):
|
| 224 |
+
box = boxes[j].flatten()
|
| 225 |
+
# SAM 3 boxes format: get x0,y0,x1,y1
|
| 226 |
+
if len(box) >= 4:
|
| 227 |
+
x0, y0, x1, y1 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
|
| 228 |
+
else:
|
| 229 |
+
continue
|
| 230 |
+
else:
|
| 231 |
+
# Fallback: derive box from mask
|
| 232 |
+
m = masks[j]
|
| 233 |
+
while m.ndim > 2: m = m[0]
|
| 234 |
+
ys, xs = np.nonzero(m > 0.5)
|
| 235 |
+
if len(ys) == 0: continue
|
| 236 |
+
x0, y0, x1, y1 = int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
|
| 237 |
+
|
| 238 |
+
# Validate box
|
| 239 |
+
box_w = x1 - x0
|
| 240 |
+
box_h = y1 - y0
|
| 241 |
+
box_area = box_w * box_h
|
| 242 |
+
|
| 243 |
+
if score < 0.1 or box_area < 500 or box_area > img_area * 0.90:
|
| 244 |
+
print(f" det[{j}] SKIPPED: score={score:.4f}, area={box_area}", flush=True)
|
| 245 |
continue
|
| 246 |
|
| 247 |
+
# Add padding (10% of box size)
|
| 248 |
+
pad_x = max(10, int(box_w * 0.10))
|
| 249 |
+
pad_y = max(10, int(box_h * 0.10))
|
| 250 |
+
bx0 = max(0, x0 - pad_x)
|
| 251 |
+
by0 = max(0, y0 - pad_y)
|
| 252 |
+
bx1 = min(w, x1 + pad_x)
|
| 253 |
+
by1 = min(h, y1 + pad_y)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
+
all_boxes.append([bx0, by0, bx1, by1])
|
| 256 |
all_scores.append(score)
|
| 257 |
+
print(f" det[{j}] KEPT: score={score:.4f}, box=[{bx0},{by0},{bx1},{by1}]", flush=True)
|
| 258 |
+
|
| 259 |
+
print(f">>> Total detections kept: {len(all_boxes)}", flush=True)
|
| 260 |
|
| 261 |
+
if not all_boxes:
|
| 262 |
gr.Info("No assets found in this image. Try a different slide with more visual elements.")
|
| 263 |
+
print(">>> No detections passed filters, returning []", flush=True)
|
| 264 |
return []
|
| 265 |
+
|
| 266 |
+
# Deduplicate by box IoU
|
| 267 |
+
order = sorted(range(len(all_boxes)), key=lambda i: all_scores[i], reverse=True)
|
| 268 |
keep = []
|
| 269 |
for i in order:
|
| 270 |
dup = False
|
| 271 |
for ki in keep:
|
| 272 |
+
if box_iou(all_boxes[i], all_boxes[ki]) > 0.5:
|
| 273 |
dup = True
|
| 274 |
break
|
| 275 |
if not dup:
|
| 276 |
keep.append(i)
|
| 277 |
+
|
| 278 |
+
# For each kept box: crop → rembg → upscale → save
|
| 279 |
results = []
|
| 280 |
global asset_counter
|
| 281 |
for idx, ki in enumerate(keep):
|
| 282 |
+
bx0, by0, bx1, by1 = all_boxes[ki]
|
| 283 |
+
crop_rgb = orig_rgb[by0:by1, bx0:bx1]
|
| 284 |
+
|
| 285 |
+
# Background removal with rembg (clean alpha matte)
|
| 286 |
+
crop_pil = Image.fromarray(crop_rgb)
|
| 287 |
+
rgba_pil = rembg_remove(crop_pil, session=rembg_session)
|
| 288 |
+
rgba_np = np.array(rgba_pil)
|
| 289 |
+
print(f" crop[{idx}] rembg done: {rgba_np.shape}", flush=True)
|
| 290 |
+
|
| 291 |
+
# 4x upscale
|
| 292 |
+
rgba_np = upscale_4x(rgba_np)
|
| 293 |
+
|
| 294 |
+
# Save to library
|
| 295 |
asset_counter += 1
|
| 296 |
lib_path = os.path.join(ASSET_LIBRARY_DIR, f"asset_{asset_counter:04d}.png")
|
| 297 |
+
Image.fromarray(rgba_np, "RGBA").save(lib_path, format="PNG")
|
| 298 |
results.append(lib_path)
|
| 299 |
|
| 300 |
+
print(f">>> Returning {len(results)} assets (library total: {asset_counter})", flush=True)
|
| 301 |
return results
|
| 302 |
|
| 303 |
except Exception as e:
|
requirements.txt
CHANGED
|
@@ -11,3 +11,4 @@ gradio
|
|
| 11 |
git+https://github.com/facebookresearch/sam3.git
|
| 12 |
opencv-python-headless
|
| 13 |
PyMuPDF
|
|
|
|
|
|
| 11 |
git+https://github.com/facebookresearch/sam3.git
|
| 12 |
opencv-python-headless
|
| 13 |
PyMuPDF
|
| 14 |
+
rembg
|