import gradio as gr
import numpy as np
import cv2
import torch
from PIL import Image
import os
import io
import fitz  # PyMuPDF

# ── UNCONDITIONAL BFloat16 → Float16 Patch for T4 Turing GPUs ────
# CRITICAL: torch.cuda.is_bf16_supported() returns True on T4 because CUDA
# can *emulate* bfloat16 in software, but the actual kernels crash on
# mixed-dtype operations (linear, conv2d). We MUST patch unconditionally.
if torch.cuda.is_available():
    # 1. Intercept ALL autocast entry points to force float16.
    import torch.amp.autocast_mode

    _OriginalAutocast = torch.amp.autocast_mode.autocast

    class _Fp16Autocast(_OriginalAutocast):
        def __init__(self, device_type, dtype=None, *args, **kwargs):
            # Intercept Meta's bfloat16 request and force float16 for Turing support.
            if dtype == torch.bfloat16:
                dtype = torch.float16
            super().__init__(device_type, dtype=dtype, *args, **kwargs)

    torch.autocast = _Fp16Autocast
    torch.amp.autocast_mode.autocast = _Fp16Autocast
    if hasattr(torch.amp, 'autocast'):
        torch.amp.autocast = _Fp16Autocast
    if hasattr(torch.cuda.amp, 'autocast'):
        torch.cuda.amp.autocast = _Fp16Autocast

    # 2. Patch core math kernels to deterministically cast mismatched float
    # inputs to the weight dtype. This acts as a fallback AMP layer for
    # mismatches that surface deep inside transformer blocks.
    import torch.nn.functional as F

    orig_linear = F.linear

    def patched_linear(input, weight, bias=None):
        if input.is_floating_point() and input.dtype != weight.dtype:
            input = input.to(weight.dtype)
        return orig_linear(input, weight, bias)

    F.linear = patched_linear

    orig_conv2d = F.conv2d

    def patched_conv2d(input, weight, bias=None, stride=1, padding=0,
                       dilation=1, groups=1):
        if input.is_floating_point() and input.dtype != weight.dtype:
            input = input.to(weight.dtype)
        return orig_conv2d(input, weight, bias, stride, padding, dilation, groups)

    F.conv2d = patched_conv2d

    # 3. Patch torchvision.ops.roi_align — Meta's geometry_encoders.py calls
    # boxes_xyxy.float(), which creates float32 boxes while img_feats is float16.
    try:
        import torchvision.ops

        orig_roi_align = torchvision.ops.roi_align

        def patched_roi_align(input, boxes, output_size, spatial_scale=1.0,
                              sampling_ratio=-1, aligned=False):
            # Handle Tensor, list, or tuple (Meta uses .unbind(), which returns a tuple).
            if isinstance(boxes, torch.Tensor):
                if input.is_floating_point() and boxes.dtype != input.dtype:
                    boxes = boxes.to(input.dtype)
            elif isinstance(boxes, (list, tuple)):
                boxes = [b.to(input.dtype)
                         if isinstance(b, torch.Tensor) and b.dtype != input.dtype
                         else b
                         for b in boxes]
            return orig_roi_align(input, boxes, output_size, spatial_scale,
                                  sampling_ratio, aligned)

        torchvision.ops.roi_align = patched_roi_align
    except ImportError:
        pass

    # 4. Patch layer_norm — a common ViT dtype mismatch point (group_norm is a
    # sibling case; a sketch follows below).
    orig_layer_norm = F.layer_norm

    def patched_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
        if weight is not None and input.is_floating_point() and input.dtype != weight.dtype:
            input = input.to(weight.dtype)
        return orig_layer_norm(input, normalized_shape, weight, bias, eps)

    F.layer_norm = patched_layer_norm
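    # The note above names group_norm as a sibling mismatch point but only
    # layer_norm is patched. A minimal companion patch — a sketch following the
    # same cast-input-to-weight-dtype convention, not part of the original —
    # would be:
    orig_group_norm = F.group_norm

    def patched_group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
        if weight is not None and input.is_floating_point() and input.dtype != weight.dtype:
            input = input.to(weight.dtype)
        return orig_group_norm(input, num_groups, weight, bias, eps)

    F.group_norm = patched_group_norm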
# ── Ensure the SAM 3 checkpoint is downloaded ────────────────────
# (Hugging Face Spaces can use the hf_hub_download mechanism.)
from huggingface_hub import hf_hub_download

# ── HF Token Authentication ──────────────────────────────────────
print("Downloading SAM 3 model...")
hf_token = os.environ.get("HF_TOKEN")
ckpt_path = hf_hub_download(repo_id="facebook/sam3", filename="sam3.pt", token=hf_token)

# ── Monkey-patch SAM 3's CUDA hardcoding bug ─────────────────────
# Meta's SAM 3 repo hardcodes `device="cuda"` in many places. This intercepts
# common PyTorch tensor constructors and forces "cpu" when no GPU is available.
if not torch.cuda.is_available():
    import functools

    patch_funcs = ['zeros', 'arange', 'tensor', 'ones', 'empty', 'randn', 'full', 'linspace']
    for name in patch_funcs:
        if hasattr(torch, name):
            orig_fn = getattr(torch, name)

            @functools.wraps(orig_fn)
            def patched_fn(*args, __orig_fn=orig_fn, **kwargs):
                if 'device' in kwargs and str(kwargs['device']).startswith('cuda'):
                    kwargs['device'] = 'cpu'
                return __orig_fn(*args, **kwargs)

            setattr(torch, name, patched_fn)
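# The constructor patch above does not cover checkpoints whose tensors were
# serialized with CUDA storage locations. A minimal companion guard — an
# assumption, since the SAM 3 loader may already pass map_location itself —
# would force CPU deserialization:
if not torch.cuda.is_available():
    import functools

    _orig_torch_load = torch.load

    @functools.wraps(_orig_torch_load)
    def _patched_torch_load(*args, **kwargs):
        kwargs.setdefault("map_location", "cpu")
        return _orig_torch_load(*args, **kwargs)

    torch.load = _patched_torch_load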
# ── SAM 3 Imports ────────────────────────────────────────────────
try:
    from sam3.model_builder import build_sam3_image_model
    from sam3.model.sam3_image_processor import Sam3Processor
    model_installed = True
except ImportError:
    model_installed = False
    print("SAM 3 not installed yet (will be installed by requirements.txt).")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = None

if model_installed:
    print(f"Loading SAM 3 onto {device}...")
    model = build_sam3_image_model(checkpoint_path=ckpt_path)

    # Cast to float16 — T4 has native float16 Tensor Core acceleration;
    # bfloat16 hangs (software-emulated on Turing), and float32 produced zero
    # masks. On CPU the model stays float32, since the dtype patches above are
    # CUDA-only and half precision is poorly supported by CPU kernels.
    if torch.cuda.is_available():
        model.half()

    # Diagnostic: verify the checkpoint loaded correctly.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}", flush=True)
    sample_dtype = next(model.parameters()).dtype
    print(f"Model dtype: {sample_dtype}", flush=True)

    processor = Sam3Processor(model)
    if not torch.cuda.is_available():
        processor.device = "cpu"
    print("Model loaded successfully.")

# Two-pass concept detection: parent (composite) + child (individual) elements.
# Excludes 'text block' (user doesn't want text) and 'logo' (picks up watermarks).
PARENT_CONCEPTS = [
    "chart", "diagram", "graph", "table", "illustration",
    "infographic", "figure", "photo", "picture", "image"
]
CHILD_CONCEPTS = [
    "icon", "symbol", "arrow", "bar", "person",
    "object", "button", "badge", "circle", "label"
]
ALL_CONCEPTS = PARENT_CONCEPTS + CHILD_CONCEPTS

# Persistent asset library
import tempfile, zipfile
ASSET_LIBRARY_DIR = os.path.join(tempfile.gettempdir(), "sam3_library")
os.makedirs(ASSET_LIBRARY_DIR, exist_ok=True)
asset_counter = 0

def box_iou(b1, b2):
    """IoU between two boxes [x0, y0, x1, y1]."""
    x0 = max(b1[0], b2[0])
    y0 = max(b1[1], b2[1])
    x1 = min(b1[2], b2[2])
    y1 = min(b1[3], b2[3])
    inter = max(0, x1 - x0) * max(0, y1 - y0)
    a1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
    a2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
    union = a1 + a2 - inter
    return inter / union if union > 0 else 0.0
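# Worked example (an added import-time sanity check, not in the original):
# two 2x2 boxes overlapping on half their area have intersection 2 and
# union 6, so IoU = 1/3.
assert abs(box_iou([0, 0, 2, 2], [1, 0, 3, 2]) - 1 / 3) < 1e-6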
def remove_color_bg(crop_rgb: np.ndarray, bg_color=(255, 255, 255), tolerance=30) -> np.ndarray:
    """Remove background by flood-filling from edges.

    Only removes pixels CONNECTED to the border that match bg_color.
    White/colored areas INSIDE objects are preserved.
    """
    h, w = crop_rgb.shape[:2]
    if h < 2 or w < 2:
        rgba = np.zeros((h, w, 4), dtype=np.uint8)
        rgba[:, :, :3] = crop_rgb
        rgba[:, :, 3] = 255
        return rgba

    # Mask of pixels matching the background color within tolerance
    # (Euclidean RGB distance).
    bg = np.array(bg_color, dtype=np.float32)
    diff = np.sqrt(np.sum((crop_rgb.astype(np.float32) - bg) ** 2, axis=2))
    color_match = (diff < tolerance).astype(np.uint8) * 255

    # Flood fill from all border pixels to find CONNECTED background.
    # Use floodFill on a padded mask to handle edge connectivity.
    flood_mask = np.zeros((h + 2, w + 2), dtype=np.uint8)
    bg_connected = np.zeros((h, w), dtype=np.uint8)

    # Seed from all border pixels that match the background color.
    border_seeds = []
    for x in range(w):
        if color_match[0, x]:
            border_seeds.append((x, 0))
        if color_match[h - 1, x]:
            border_seeds.append((x, h - 1))
    for y in range(h):
        if color_match[y, 0]:
            border_seeds.append((0, y))
        if color_match[y, w - 1]:
            border_seeds.append((w - 1, y))

    # Flood fill from each border seed not already covered by a previous fill.
    for sx, sy in border_seeds:
        if bg_connected[sy, sx] == 0 and color_match[sy, sx]:
            flood_mask[:] = 0
            cv2.floodFill(color_match.copy(), flood_mask, (sx, sy), 128,
                          loDiff=0, upDiff=0,
                          flags=cv2.FLOODFILL_MASK_ONLY | (8 << 8))
            # flood_mask is nonzero where the fill reached (offset by the 1-pixel pad).
            bg_connected |= flood_mask[1:-1, 1:-1]

    # Alpha: 255 for foreground, 0 for connected background.
    alpha = np.where(bg_connected > 0, np.uint8(0), np.uint8(255))

    # Slight edge anti-aliasing: blur alpha, then re-clamp the interior.
    alpha_f = alpha.astype(np.float32)
    alpha_blur = cv2.GaussianBlur(alpha_f, (3, 3), sigmaX=0.8)
    interior = alpha > 240
    alpha_aa = np.where(interior, 255.0, alpha_blur)
    alpha = alpha_aa.clip(0, 255).astype(np.uint8)

    # Build RGBA.
    rgba = np.zeros((h, w, 4), dtype=np.uint8)
    rgba[:, :, :3] = crop_rgb
    rgba[:, :, 3] = alpha
    return rgba

def upscale_4x(rgba: np.ndarray) -> np.ndarray:
    """4x Lanczos upscale with unsharp masking."""
    h, w = rgba.shape[:2]
    new_w, new_h = w * 4, h * 4
    upscaled = cv2.resize(rgba, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)

    # Unsharp mask on RGB only.
    rgb = upscaled[:, :, :3]
    blurred = cv2.GaussianBlur(rgb, (0, 0), sigmaX=1.0)
    rgb_sharp = cv2.addWeighted(rgb, 1.5, blurred, -0.5, 0)
    upscaled[:, :, :3] = rgb_sharp
    return upscaled

def is_notebooklm_logo(box, img_w, img_h):
    """Filter out small detections in the bottom-right corner (NotebookLM watermark)."""
    x0, y0, x1, y1 = box
    bw, bh = x1 - x0, y1 - y0
    # Skip if small AND centered in the bottom-right 15% of the image.
    if bw < 80 and bh < 80:
        center_x = (x0 + x1) / 2
        center_y = (y0 + y1) / 2
        if center_x > img_w * 0.85 and center_y > img_h * 0.85:
            return True
    return False

def extract_assets(input_image, bg_color_hex="#FFFFFF"):
    import sys, traceback
    try:
        print(">>> extract_assets V2 called", flush=True)
        if input_image is None:
            gr.Info("Please upload an image first.")
            return []
        if processor is None:
            gr.Warning("Model is still loading. Please wait and try again.")
            return []

        # Parse background color
        bg_hex = bg_color_hex.lstrip("#")
        try:
            bg_color = tuple(int(bg_hex[i:i + 2], 16) for i in (0, 2, 4))
        except (ValueError, IndexError):
            bg_color = (255, 255, 255)
        print(f">>> Background color: {bg_color}", flush=True)

        orig_rgb = input_image
        h, w = orig_rgb.shape[:2]
        img_area = h * w
        print(f">>> Image size: {w}x{h}, area: {img_area}", flush=True)

        pil_img = Image.fromarray(orig_rgb)
        all_boxes = []
        all_scores = []

        with torch.inference_mode():
            print(">>> Running set_image...", flush=True)
            state = processor.set_image(pil_img)
            print(">>> set_image complete! Running two-pass detection...", flush=True)

            for concept in ALL_CONCEPTS:
                print(f">>> Concept: '{concept}'...", flush=True)
                out = processor.set_text_prompt(state=state, prompt=concept)
                masks = out["masks"]
                scores = out["scores"]
                if masks is None or len(masks) == 0:
                    print(f"  [{concept}] No detections", flush=True)
                    continue
                if torch.is_tensor(masks):
                    masks = masks.float().cpu().numpy()
                if torch.is_tensor(scores):
                    scores = scores.float().cpu().numpy()
                print(f"  [{concept}] Found {len(masks)} masks", flush=True)

                for j in range(len(masks)):
                    m = masks[j]
                    while m.ndim > 2:
                        m = m[0]
                    m_bool = m.astype(bool)
                    score = float(scores[j]) if scores.ndim > 0 else float(scores)

                    # Derive bounding box from the mask.
                    ys, xs = np.nonzero(m_bool)
                    if len(ys) == 0:
                        continue
                    x0, y0 = int(xs.min()), int(ys.min())
                    x1, y1 = int(xs.max()), int(ys.max())
                    bw, bh = x1 - x0, y1 - y0
                    box_area = bw * bh

                    # Filters
                    if score < 0.1:
                        print(f"  [{j}] SKIP low score: {score:.4f}", flush=True)
                        continue
                    if box_area < 500 or bw < 20 or bh < 20:
                        print(f"  [{j}] SKIP too small: {bw}x{bh}", flush=True)
                        continue
                    if box_area > img_area * 0.90:
                        print(f"  [{j}] SKIP too large", flush=True)
                        continue
                    if is_notebooklm_logo([x0, y0, x1, y1], w, h):
                        print(f"  [{j}] SKIP NotebookLM logo position", flush=True)
                        continue

                    # Add padding (8% of box size).
                    pad_x = max(8, int(bw * 0.08))
                    pad_y = max(8, int(bh * 0.08))
                    bx0 = max(0, x0 - pad_x)
                    by0 = max(0, y0 - pad_y)
                    bx1 = min(w, x1 + pad_x)
                    by1 = min(h, y1 + pad_y)

                    all_boxes.append([bx0, by0, bx1, by1])
                    all_scores.append(score)
                    print(f"  [{j}] KEPT: score={score:.4f}, box=[{bx0},{by0},{bx1},{by1}]", flush=True)

        print(f">>> Total detections: {len(all_boxes)}", flush=True)
        if not all_boxes:
            gr.Info("No visual assets found. Try a different slide with more illustrations.")
            return []

        # Deduplicate by box IoU (keep highest score).
        order = sorted(range(len(all_boxes)), key=lambda i: all_scores[i], reverse=True)
        keep = []
        for i in order:
            dup = False
            for ki in keep:
                if box_iou(all_boxes[i], all_boxes[ki]) > 0.5:
                    dup = True
                    break
            if not dup:
                keep.append(i)
        print(f">>> After dedup: {len(keep)} unique assets", flush=True)

        # For each kept box: crop → flood-fill BG removal → upscale → save.
        results = []
        global asset_counter
        for idx, ki in enumerate(keep):
            bx0, by0, bx1, by1 = all_boxes[ki]
            crop_rgb = orig_rgb[by0:by1, bx0:bx1]

            # Flood-fill background removal (preserves interior fills).
            rgba = remove_color_bg(crop_rgb, bg_color=bg_color, tolerance=30)

            # 4x upscale
            rgba = upscale_4x(rgba)

            asset_counter += 1
            lib_path = os.path.join(ASSET_LIBRARY_DIR, f"asset_{asset_counter:04d}.png")
            Image.fromarray(rgba, "RGBA").save(lib_path, format="PNG")
            results.append(lib_path)
            print(f"  asset[{idx}] saved: {lib_path}", flush=True)

        print(f">>> Returning {len(results)} assets (library: {asset_counter})", flush=True)
        return results
    except Exception as e:
        print(f">>> EXCEPTION in extract_assets: {e}", flush=True)
        traceback.print_exc()
        sys.stdout.flush()
        return []
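# Headless usage sketch (illustrative; "slide.png" is an assumed filename,
# not part of the original):
#   img = np.array(Image.open("slide.png").convert("RGB"))
#   png_paths = extract_assets(img, "#FFFFFF")
# Each returned path is an RGBA PNG written to ASSET_LIBRARY_DIR.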
def extract_from_pdf(pdf_file, bg_color_hex="#FFFFFF", progress=gr.Progress()):
    """Process every page of a PDF through SAM 3 extraction."""
    import sys, traceback
    try:
        if pdf_file is None:
            return []
        pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
        print(f">>> PDF upload: {pdf_path}", flush=True)

        doc = fitz.open(pdf_path)
        total_pages = len(doc)
        print(f">>> PDF has {total_pages} pages", flush=True)

        all_results = []
        for page_num in progress.tqdm(range(total_pages), desc="Processing PDF pages"):
            print(f">>> Processing page {page_num + 1}/{total_pages}...", flush=True)
            page = doc[page_num]
            # Render at 2x zoom (~144 DPI; PyMuPDF's base resolution is 72 DPI).
            mat = fitz.Matrix(2.0, 2.0)
            pix = page.get_pixmap(matrix=mat)
            img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.h, pix.w, pix.n)
            if pix.n == 4:
                img_rgb = img_array[:, :, :3].copy()  # drop the alpha channel
            else:
                img_rgb = img_array.copy()

            page_results = extract_assets(img_rgb, bg_color_hex=bg_color_hex)
            all_results.extend(page_results)
            print(f">>> Page {page_num + 1}: extracted {len(page_results)} assets", flush=True)

        doc.close()
        print(f">>> PDF complete: {len(all_results)} total assets from {total_pages} pages", flush=True)
        return all_results
    except Exception as e:
        print(f">>> EXCEPTION in extract_from_pdf: {e}", flush=True)
        traceback.print_exc()
        sys.stdout.flush()
        return []
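# Batch usage sketch (illustrative; "deck.pdf" is an assumed filename, and the
# gr.Progress() stub may behave differently outside a Gradio event handler):
#   assets = extract_from_pdf("deck.pdf", "#FFFFFF")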
custom_css = """
/* ── Premium Dark Theme ───────────────────────────── */
.gradio-container {
    max-width: 1400px !important;
    margin: auto;
}
#app-title {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #f97316 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 2.2rem !important;
    font-weight: 800 !important;
    margin-bottom: 0 !important;
}
#app-subtitle {
    text-align: center;
    color: #94a3b8 !important;
    font-size: 0.95rem !important;
    margin-top: 0 !important;
}
/* Gallery with hover download */
.gallery-container {
    min-height: 650px !important;
}
.gallery-container .gallery-item {
    position: relative;
    border-radius: 12px;
    overflow: hidden;
    transition: transform 0.2s ease, box-shadow 0.2s ease;
    background: #1e293b;
}
.gallery-container .gallery-item:hover {
    transform: scale(1.03);
    box-shadow: 0 8px 32px rgba(102, 126, 234, 0.3);
}
/* Download button: hidden by default, shown on hover */
.gallery-container .gallery-item button.download {
    opacity: 0 !important;
    transition: opacity 0.25s ease !important;
    position: absolute !important;
    bottom: 8px !important;
    right: 8px !important;
    z-index: 10 !important;
    background: rgba(249, 115, 22, 0.9) !important;
    color: white !important;
    border-radius: 8px !important;
    padding: 6px 14px !important;
    font-weight: 600 !important;
    border: none !important;
    cursor: pointer !important;
}
.gallery-container .gallery-item:hover button.download {
    opacity: 1 !important;
}
/* Extract button styling */
#extract-btn {
    background: linear-gradient(135deg, #f97316 0%, #ea580c 100%) !important;
    border: none !important;
    font-weight: 700 !important;
    font-size: 1.1rem !important;
    padding: 14px 0 !important;
    border-radius: 12px !important;
    transition: all 0.3s ease !important;
}
#extract-btn:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 24px rgba(249, 115, 22, 0.4) !important;
}
/* Upload area */
#upload-area {
    border: 2px dashed #475569 !important;
    border-radius: 12px !important;
    transition: border-color 0.3s ease !important;
}
#upload-area:hover {
    border-color: #667eea !important;
}
/* Color picker label */
#bg-color-picker {
    max-width: 200px;
}
"""

app_theme = gr.themes.Soft(
    primary_hue="orange",
    secondary_hue="blue",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
)

def download_all_zip():
    """Package all extracted assets into a downloadable ZIP."""
    zip_path = os.path.join(tempfile.gettempdir(), "extracted_assets.zip")
    pngs = sorted([f for f in os.listdir(ASSET_LIBRARY_DIR) if f.endswith(".png")])
    if not pngs:
        return None
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for f in pngs:
            zf.write(os.path.join(ASSET_LIBRARY_DIR, f), f)
    return zip_path
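# Note: download_all_zip aggregates every PNG written during the process
# lifetime, not just the latest extraction. A small helper (an addition, not
# wired into the UI) can reset the library between runs:
def clear_library():
    """Delete all saved assets and reset the counter (convenience sketch)."""
    global asset_counter
    for f in os.listdir(ASSET_LIBRARY_DIR):
        if f.endswith(".png"):
            os.remove(os.path.join(ASSET_LIBRARY_DIR, f))
    asset_counter = 0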
with gr.Blocks(title="SAM 3 Asset Extractor", css=custom_css, theme=app_theme) as demo:
    gr.Markdown("# 🎨 SAM 3 Visual Asset Extractor", elem_id="app-title")
    gr.Markdown(
        "Upload a presentation slide or PDF to extract all **visual elements** "
        "(charts, diagrams, icons, illustrations) as **transparent PNGs** ready for "
        "**video editing** — powered by Meta's SAM 3 + intelligent background removal.",
        elem_id="app-subtitle"
    )

    with gr.Row(equal_height=False):
        with gr.Column(scale=1, min_width=340):
            with gr.Tabs():
                with gr.Tab("🖼️ Single Image"):
                    input_image = gr.Image(
                        label="📤 Upload Slide",
                        type="numpy",
                        elem_id="upload-area",
                        height=300,
                    )
                    submit_btn = gr.Button(
                        "🔍 Extract Visual Assets",
                        variant="primary",
                        elem_id="extract-btn",
                        size="lg",
                    )
                with gr.Tab("📄 PDF Batch"):
                    input_pdf = gr.File(
                        label="📤 Upload PDF",
                        file_types=[".pdf"],
                    )
                    pdf_btn = gr.Button(
                        "📄 Extract from All Pages",
                        variant="primary",
                        elem_id="extract-btn",
                        size="lg",
                    )

            bg_color_input = gr.Textbox(
                label="🎨 Background Color to Remove",
                value="#FFFFFF",
                elem_id="bg-color-picker",
                info="Hex color of slide background (e.g. #FFFFFF for white)",
                max_lines=1,
            )
            download_btn = gr.DownloadButton(
                "📦 Download All as ZIP",
                size="lg",
            )
            gr.Markdown(
                "**🔍 Detects:** charts · diagrams · graphs · tables · "
                "illustrations · infographics · figures · photos · "
                "icons · symbols · arrows · bars · persons · badges\n\n"
                "**🚫 Excludes:** text blocks · logos · watermarks",
                elem_id="concept-list"
            )

        with gr.Column(scale=3):
            output_gallery = gr.Gallery(
                label="🎨 Extracted Assets — Hover to download individual PNGs",
                columns=4,
                object_fit="contain",
                height=700,
                format="png",
                elem_classes=["gallery-container"],
            )

    submit_btn.click(
        fn=extract_assets,
        inputs=[input_image, bg_color_input],
        outputs=[output_gallery]
    )
    pdf_btn.click(
        fn=extract_from_pdf,
        inputs=[input_pdf, bg_color_input],
        outputs=[output_gallery]
    )
    download_btn.click(fn=download_all_zip, inputs=[], outputs=[download_btn])

if __name__ == "__main__":
    auth_user = os.environ.get("APP_USERNAME", "veurone")
    auth_pass = os.environ.get("APP_PASSWORD", "sam3extract")
    # css/theme belong to the gr.Blocks() constructor (moved above);
    # Blocks.launch() does not accept them.
    demo.launch(auth=(auth_user, auth_pass))