import gradio as gr
import numpy as np
import cv2
import torch
from PIL import Image
import os
import fitz  # PyMuPDF

# ── UNCONDITIONAL BFloat16 → Float16 Patch for T4 Turing GPUs ────
# CRITICAL: torch.cuda.is_bf16_supported() returns True on T4 because CUDA
# can *emulate* bfloat16 in software, but the actual kernels crash on mixed
# dtype operations (linear, conv2d). We MUST patch unconditionally.
if torch.cuda.is_available():
    # 1. Intercept ALL autocast entry points to force float16
    import torch.amp.autocast_mode

    _OriginalAutocast = torch.amp.autocast_mode.autocast

    class _Fp16Autocast(_OriginalAutocast):
        def __init__(self, device_type, dtype=None, *args, **kwargs):
            # Intercept Meta's bfloat16 request and force float16 for Turing support
            if dtype == torch.bfloat16:
                dtype = torch.float16
            super().__init__(device_type, dtype=dtype, *args, **kwargs)

    torch.autocast = _Fp16Autocast
    torch.amp.autocast_mode.autocast = _Fp16Autocast
    if hasattr(torch.amp, 'autocast'):
        torch.amp.autocast = _Fp16Autocast
    if hasattr(torch.cuda.amp, 'autocast'):
        torch.cuda.amp.autocast = _Fp16Autocast
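
    # Sanity-check sketch (assumes a CUDA device; kept commented out so
    # importing this module stays side-effect free). Any caller requesting
    # bfloat16 autocast should now transparently get a float16 region:
    #
    #   with torch.autocast("cuda", dtype=torch.bfloat16):
    #       a = torch.ones(2, 2, device="cuda")     # float32 inputs
    #       assert (a @ a).dtype == torch.float16   # fp16, not bf16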

    # 2. Patch core math kernels to deterministically auto-cast mismatched
    # float inputs. This is a safety net for code paths that escape autocast
    # deep inside transformer blocks: any float input whose dtype differs
    # from the weights is cast to the weight dtype before the kernel runs.
    import torch.nn.functional as F

    orig_linear = F.linear

    def patched_linear(input, weight, bias=None):
        if input.is_floating_point() and input.dtype != weight.dtype:
            input = input.to(weight.dtype)
        return orig_linear(input, weight, bias)

    F.linear = patched_linear

    orig_conv2d = F.conv2d

    def patched_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        if input.is_floating_point() and input.dtype != weight.dtype:
            input = input.to(weight.dtype)
        return orig_conv2d(input, weight, bias, stride, padding, dilation, groups)

    F.conv2d = patched_conv2d
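
    # The failure mode this prevents, as a commented sketch (hypothetical
    # shapes; unpatched PyTorch raises a Half-vs-Float dtype mismatch here):
    #
    #   lin = torch.nn.Linear(8, 8).cuda().half()
    #   x = torch.randn(2, 8, device="cuda")       # float32 activation
    #   y = F.linear(x, lin.weight, lin.bias)      # patched: returns float16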

    # 3. Patch torchvision.ops.roi_align — Meta's geometry_encoders.py
    # calls boxes_xyxy.float() which creates float32 while img_feats is float16.
    try:
        import torchvision.ops

        orig_roi_align = torchvision.ops.roi_align

        def patched_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
            # Handle Tensor, list, or tuple (Meta uses .unbind() which returns a tuple!)
            if isinstance(boxes, torch.Tensor):
                if input.is_floating_point() and boxes.dtype != input.dtype:
                    boxes = boxes.to(input.dtype)
            elif isinstance(boxes, (list, tuple)):
                boxes = [b.to(input.dtype) if isinstance(b, torch.Tensor) and b.dtype != input.dtype else b for b in boxes]
            return orig_roi_align(input, boxes, output_size, spatial_scale, sampling_ratio, aligned)

        torchvision.ops.roi_align = patched_roi_align
    except ImportError:
        pass

    # 4. Patch layer_norm — a common ViT dtype mismatch point (group_norm
    # could get the same treatment; see the sketch below).
    orig_layer_norm = F.layer_norm

    def patched_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
        if weight is not None and input.is_floating_point() and input.dtype != weight.dtype:
            input = input.to(weight.dtype)
        return orig_layer_norm(input, normalized_shape, weight, bias, eps)

    F.layer_norm = patched_layer_norm
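
    # group_norm sketch (assumption: F.group_norm's standard signature,
    # (input, num_groups, weight=None, bias=None, eps=1e-5)); it would
    # mirror patched_layer_norm exactly:
    #
    #   orig_group_norm = F.group_norm
    #   def patched_group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
    #       if weight is not None and input.is_floating_point() and input.dtype != weight.dtype:
    #           input = input.to(weight.dtype)
    #       return orig_group_norm(input, num_groups, weight, bias, eps)
    #   F.group_norm = patched_group_norm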

# ── Ensure SAM 3 Checkpoint is downloaded ──────────────────────────
# (Hugging Face Spaces can use the hf_hub_download mechanism)
from huggingface_hub import hf_hub_download

# ── HF Token Authentication ────────────────────────────────────────
print("Downloading SAM 3 model...")
hf_token = os.environ.get("HF_TOKEN")
ckpt_path = hf_hub_download(repo_id="facebook/sam3", filename="sam3.pt", token=hf_token)

# ── Monkey-patch SAM 3 CUDA Hardcoding Bug ─────────────────────────
# Meta's SAM 3 repo hardcodes `device="cuda"` in many places!
# This intercepts common PyTorch tensor constructors and forces "cpu"
# when no GPU is available.
if not torch.cuda.is_available():
    patch_funcs = ['zeros', 'arange', 'tensor', 'ones', 'empty', 'randn', 'full', 'linspace']
    for name in patch_funcs:
        if hasattr(torch, name):
            orig_fn = getattr(torch, name)

            # Bind orig_fn as a keyword default so each wrapper captures its
            # own constructor (avoids the late-binding closure pitfall in loops).
            def patched_fn(*args, __orig_fn=orig_fn, **kwargs):
                if 'device' in kwargs and str(kwargs['device']).startswith('cuda'):
                    kwargs['device'] = 'cpu'
                return __orig_fn(*args, **kwargs)

            setattr(torch, name, patched_fn)
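
    # Redirect check, as a commented sketch (only meaningful on a CPU-only
    # host, where the unpatched call would raise):
    #
    #   t = torch.zeros(3, device="cuda")
    #   assert t.device.type == "cpu"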

# ── SAM 3 Imports ───────────────────────────────────────────────────
try:
    from sam3.model_builder import build_sam3_image_model
    from sam3.model.sam3_image_processor import Sam3Processor
    model_installed = True
except ImportError:
    model_installed = False
    print("SAM 3 not installed yet (will be installed by requirements.txt).")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = None
if model_installed:
    print(f"Loading SAM 3 onto {device}...")
    model = build_sam3_image_model(checkpoint_path=ckpt_path)
    # Cast to float16 — the T4 has native float16 Tensor Core acceleration.
    # bfloat16 hangs (software-emulated on Turing); float32 produced zero masks.
    model.half()
    # Diagnostic: verify the checkpoint loaded correctly
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}", flush=True)
    sample_dtype = next(model.parameters()).dtype
    print(f"Model dtype: {sample_dtype}", flush=True)
    processor = Sam3Processor(model)
    if not torch.cuda.is_available():
        processor.device = "cpu"
    print("Model loaded successfully.")

# Two-pass concept detection: parent (composite) + child (individual) elements.
# Excludes 'text block' (users don't want text) and 'logo' (picks up watermarks).
PARENT_CONCEPTS = [
    "chart", "diagram", "graph", "table", "illustration",
    "infographic", "figure", "photo", "picture", "image"
]
CHILD_CONCEPTS = [
    "icon", "symbol", "arrow", "bar", "person",
    "object", "button", "badge", "circle", "label"
]
ALL_CONCEPTS = PARENT_CONCEPTS + CHILD_CONCEPTS

# Persistent asset library
import tempfile, zipfile

ASSET_LIBRARY_DIR = os.path.join(tempfile.gettempdir(), "sam3_library")
os.makedirs(ASSET_LIBRARY_DIR, exist_ok=True)
asset_counter = 0

def box_iou(b1, b2):
    """IoU between two boxes [x0, y0, x1, y1]."""
    x0 = max(b1[0], b2[0])
    y0 = max(b1[1], b2[1])
    x1 = min(b1[2], b2[2])
    y1 = min(b1[3], b2[3])
    inter = max(0, x1 - x0) * max(0, y1 - y0)
    a1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
    a2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
    union = a1 + a2 - inter
    return inter / union if union > 0 else 0.0
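
# Worked example, doubling as a cheap import-time self-test: two 10x10 boxes
# offset by 5px intersect in a 5x5 patch, so IoU = 25 / (100 + 100 - 25)
# ≈ 0.143, well under the 0.5 dedup cutoff used in extract_assets below.
assert abs(box_iou([0, 0, 10, 10], [5, 5, 15, 15]) - 25 / 175) < 1e-9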

def remove_color_bg(crop_rgb: np.ndarray, bg_color=(255, 255, 255), tolerance=30) -> np.ndarray:
    """Remove the background by flood-filling from the edges.

    Only pixels CONNECTED to the border that match bg_color are removed;
    white/colored areas INSIDE objects are preserved.
    """
    h, w = crop_rgb.shape[:2]
    if h < 2 or w < 2:
        rgba = np.zeros((h, w, 4), dtype=np.uint8)
        rgba[:, :, :3] = crop_rgb
        rgba[:, :, 3] = 255
        return rgba
    # Mask of pixels matching the background color within tolerance
    # (Euclidean distance in RGB space)
    bg = np.array(bg_color, dtype=np.float32)
    diff = np.sqrt(np.sum((crop_rgb.astype(np.float32) - bg) ** 2, axis=2))
    color_match = (diff < tolerance).astype(np.uint8) * 255
    # Flood fill from all border pixels to find CONNECTED background.
    # floodFill requires a mask padded by 1px on each side.
    flood_mask = np.zeros((h + 2, w + 2), dtype=np.uint8)
    bg_connected = np.zeros((h, w), dtype=np.uint8)
    # Seed from all border pixels that match the background color
    border_seeds = []
    for x in range(w):
        if color_match[0, x]: border_seeds.append((x, 0))
        if color_match[h - 1, x]: border_seeds.append((x, h - 1))
    for y in range(h):
        if color_match[y, 0]: border_seeds.append((0, y))
        if color_match[y, w - 1]: border_seeds.append((w - 1, y))
    # Flood fill from each border seed, skipping seeds already reached
    for sx, sy in border_seeds:
        if bg_connected[sy, sx] == 0 and color_match[sy, sx]:
            flood_mask[:] = 0
            cv2.floodFill(color_match.copy(), flood_mask, (sx, sy), 128,
                          loDiff=0, upDiff=0, flags=cv2.FLOODFILL_MASK_ONLY | (8 << 8))
            # flood_mask is non-zero wherever the fill reached (bits 8-15 of
            # the flags set the mask fill value, here 8); strip the 1px pad.
            bg_connected |= flood_mask[1:-1, 1:-1]
    # Alpha: 255 for foreground, 0 for connected background
    alpha = np.where(bg_connected > 0, np.uint8(0), np.uint8(255))
    # Slight edge anti-aliasing: blur the alpha, then re-clamp the interior
    alpha_f = alpha.astype(np.float32)
    alpha_blur = cv2.GaussianBlur(alpha_f, (3, 3), sigmaX=0.8)
    interior = alpha > 240
    alpha_aa = np.where(interior, 255.0, alpha_blur)
    alpha = alpha_aa.clip(0, 255).astype(np.uint8)
    # Build RGBA
    rgba = np.zeros((h, w, 4), dtype=np.uint8)
    rgba[:, :, :3] = crop_rgb
    rgba[:, :, 3] = alpha
    return rgba
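
# Behavior sketch on a hypothetical toy input: a solid square on a white
# canvas keeps white pixels ENCLOSED by the square, since they never connect
# to the border (kept commented out to avoid import-time OpenCV work):
#
#   canvas = np.full((64, 64, 3), 255, dtype=np.uint8)
#   canvas[16:48, 16:48] = (200, 30, 30)      # opaque square
#   canvas[28:36, 28:36] = (255, 255, 255)    # white hole inside it
#   out = remove_color_bg(canvas)
#   # out[32, 32, 3] == 255 (hole kept); out[2, 2, 3] == 0 (border removed)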

def upscale_4x(rgba: np.ndarray) -> np.ndarray:
    """4x Lanczos upscale with unsharp masking."""
    h, w = rgba.shape[:2]
    new_w, new_h = w * 4, h * 4
    upscaled = cv2.resize(rgba, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
    # Unsharp mask on RGB only (the alpha channel is left untouched)
    rgb = upscaled[:, :, :3]
    blurred = cv2.GaussianBlur(rgb, (0, 0), sigmaX=1.0)
    rgb_sharp = cv2.addWeighted(rgb, 1.5, blurred, -0.5, 0)
    upscaled[:, :, :3] = rgb_sharp
    return upscaled
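
# e.g. a 100x80 crop comes back as 400x320. The 1.5 / -0.5 weights are the
# classic unsharp-mask identity: original + 0.5 * (original - blurred).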

def is_notebooklm_logo(box, img_w, img_h):
    """Filter out small detections in the bottom-right corner (NotebookLM watermark)."""
    x0, y0, x1, y1 = box
    bw, bh = x1 - x0, y1 - y0
    # Skip if small AND centered in the bottom-right 15% of the image
    if bw < 80 and bh < 80:
        center_x = (x0 + x1) / 2
        center_y = (y0 + y1) / 2
        if center_x > img_w * 0.85 and center_y > img_h * 0.85:
            return True
    return False
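
# Self-test with made-up coordinates: on a 1920x1080 slide, a 60x40 box
# centered at (1800, 1000) is flagged (1800 > 1632 and 1000 > 918), while
# the same box at mid-frame is kept.
assert is_notebooklm_logo([1770, 980, 1830, 1020], 1920, 1080)
assert not is_notebooklm_logo([930, 520, 990, 560], 1920, 1080)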

def extract_assets(input_image, bg_color_hex="#FFFFFF"):
    import sys, traceback
    try:
        print(">>> extract_assets V2 called", flush=True)
        if input_image is None:
            gr.Info("Please upload an image first.")
            return []
        if processor is None:
            gr.Warning("Model is still loading. Please wait and try again.")
            return []
        # Parse the background color; fall back to white on a malformed hex string
        bg_hex = bg_color_hex.lstrip("#")
        try:
            bg_color = tuple(int(bg_hex[i:i + 2], 16) for i in (0, 2, 4))
        except ValueError:
            bg_color = (255, 255, 255)
        print(f">>> Background color: {bg_color}", flush=True)
        orig_rgb = input_image
        h, w = orig_rgb.shape[:2]
        img_area = h * w
        print(f">>> Image size: {w}x{h}, area: {img_area}", flush=True)
        pil_img = Image.fromarray(orig_rgb)
        all_boxes = []
        all_scores = []
        with torch.inference_mode():
            print(">>> Running set_image...", flush=True)
            state = processor.set_image(pil_img)
            print(">>> set_image complete! Running two-pass detection...", flush=True)
            for concept in ALL_CONCEPTS:
                print(f">>> Concept: '{concept}'...", flush=True)
                out = processor.set_text_prompt(state=state, prompt=concept)
                masks = out["masks"]
                scores = out["scores"]
                if masks is None or len(masks) == 0:
                    print(f" [{concept}] No detections", flush=True)
                    continue
                if torch.is_tensor(masks): masks = masks.float().cpu().numpy()
                if torch.is_tensor(scores): scores = scores.float().cpu().numpy()
                print(f" [{concept}] Found {len(masks)} masks", flush=True)
                for j in range(len(masks)):
                    m = masks[j]
                    while m.ndim > 2: m = m[0]
                    m_bool = m.astype(bool)
                    score = float(scores[j]) if scores.ndim > 0 else float(scores)
                    # Derive the bounding box from the mask
                    ys, xs = np.nonzero(m_bool)
                    if len(ys) == 0: continue
                    x0, y0 = int(xs.min()), int(ys.min())
                    x1, y1 = int(xs.max()), int(ys.max())
                    bw, bh = x1 - x0, y1 - y0
                    box_area = bw * bh
                    # Filters
                    if score < 0.1:
                        print(f" [{j}] SKIP low score: {score:.4f}", flush=True)
                        continue
                    if box_area < 500 or bw < 20 or bh < 20:
                        print(f" [{j}] SKIP too small: {bw}x{bh}", flush=True)
                        continue
                    if box_area > img_area * 0.90:
                        print(f" [{j}] SKIP too large", flush=True)
                        continue
                    if is_notebooklm_logo([x0, y0, x1, y1], w, h):
                        print(f" [{j}] SKIP NotebookLM logo position", flush=True)
                        continue
                    # Add padding (8% of box size, at least 8px)
                    pad_x = max(8, int(bw * 0.08))
                    pad_y = max(8, int(bh * 0.08))
                    bx0 = max(0, x0 - pad_x)
                    by0 = max(0, y0 - pad_y)
                    bx1 = min(w, x1 + pad_x)
                    by1 = min(h, y1 + pad_y)
                    all_boxes.append([bx0, by0, bx1, by1])
                    all_scores.append(score)
                    print(f" [{j}] KEPT: score={score:.4f}, box=[{bx0},{by0},{bx1},{by1}]", flush=True)
        print(f">>> Total detections: {len(all_boxes)}", flush=True)
        if not all_boxes:
            gr.Info("No visual assets found. Try a different slide with more illustrations.")
            return []
        # Deduplicate by box IoU (keep the highest score)
        order = sorted(range(len(all_boxes)), key=lambda i: all_scores[i], reverse=True)
        keep = []
        for i in order:
            dup = False
            for ki in keep:
                if box_iou(all_boxes[i], all_boxes[ki]) > 0.5:
                    dup = True
                    break
            if not dup:
                keep.append(i)
        print(f">>> After dedup: {len(keep)} unique assets", flush=True)
        # For each kept box: crop → flood-fill BG removal → upscale → save
        results = []
        global asset_counter
        for idx, ki in enumerate(keep):
            bx0, by0, bx1, by1 = all_boxes[ki]
            crop_rgb = orig_rgb[by0:by1, bx0:bx1]
            # Flood-fill background removal (preserves interior fills)
            rgba = remove_color_bg(crop_rgb, bg_color=bg_color, tolerance=30)
            # 4x upscale
            rgba = upscale_4x(rgba)
            asset_counter += 1
            lib_path = os.path.join(ASSET_LIBRARY_DIR, f"asset_{asset_counter:04d}.png")
            Image.fromarray(rgba, "RGBA").save(lib_path, format="PNG")
            results.append(lib_path)
            print(f" asset[{idx}] saved: {lib_path}", flush=True)
        print(f">>> Returning {len(results)} assets (library: {asset_counter})", flush=True)
        return results
    except Exception as e:
        print(f">>> EXCEPTION in extract_assets: {e}", flush=True)
        traceback.print_exc()
        sys.stdout.flush()
        return []

def extract_from_pdf(pdf_file, bg_color_hex="#FFFFFF", progress=gr.Progress()):
    """Process every page of a PDF through SAM 3 extraction."""
    import sys, traceback
    try:
        if pdf_file is None:
            return []
        pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
        print(f">>> PDF upload: {pdf_path}", flush=True)
        doc = fitz.open(pdf_path)
        total_pages = len(doc)
        print(f">>> PDF has {total_pages} pages", flush=True)
        all_results = []
        for page_num in progress.tqdm(range(total_pages), desc="Processing PDF pages"):
            print(f">>> Processing page {page_num + 1}/{total_pages}...", flush=True)
            page = doc[page_num]
            # Render at 2x zoom for sharper crops
            mat = fitz.Matrix(2.0, 2.0)
            pix = page.get_pixmap(matrix=mat)
            img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.h, pix.w, pix.n)
            if pix.n == 4:
                img_rgb = img_array[:, :, :3].copy()  # drop the alpha channel
            else:
                img_rgb = img_array.copy()
            page_results = extract_assets(img_rgb, bg_color_hex=bg_color_hex)
            all_results.extend(page_results)
            print(f">>> Page {page_num + 1}: extracted {len(page_results)} assets", flush=True)
        doc.close()
        print(f">>> PDF complete: {len(all_results)} total assets from {total_pages} pages", flush=True)
        return all_results
    except Exception as e:
        print(f">>> EXCEPTION in extract_from_pdf: {e}", flush=True)
        traceback.print_exc()
        sys.stdout.flush()
        return []

custom_css = """
/* ── Premium Dark Theme ───────────────────────────── */
.gradio-container {
    max-width: 1400px !important;
    margin: auto;
}
#app-title {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #f97316 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 2.2rem !important;
    font-weight: 800 !important;
    margin-bottom: 0 !important;
}
#app-subtitle {
    text-align: center;
    color: #94a3b8 !important;
    font-size: 0.95rem !important;
    margin-top: 0 !important;
}
/* Gallery with hover download */
.gallery-container {
    min-height: 650px !important;
}
.gallery-container .gallery-item {
    position: relative;
    border-radius: 12px;
    overflow: hidden;
    transition: transform 0.2s ease, box-shadow 0.2s ease;
    background: #1e293b;
}
.gallery-container .gallery-item:hover {
    transform: scale(1.03);
    box-shadow: 0 8px 32px rgba(102, 126, 234, 0.3);
}
/* Download button: hidden by default, shown on hover */
.gallery-container .gallery-item button.download {
    opacity: 0 !important;
    transition: opacity 0.25s ease !important;
    position: absolute !important;
    bottom: 8px !important;
    right: 8px !important;
    z-index: 10 !important;
    background: rgba(249, 115, 22, 0.9) !important;
    color: white !important;
    border-radius: 8px !important;
    padding: 6px 14px !important;
    font-weight: 600 !important;
    border: none !important;
    cursor: pointer !important;
}
.gallery-container .gallery-item:hover button.download {
    opacity: 1 !important;
}
/* Extract button styling */
#extract-btn {
    background: linear-gradient(135deg, #f97316 0%, #ea580c 100%) !important;
    border: none !important;
    font-weight: 700 !important;
    font-size: 1.1rem !important;
    padding: 14px 0 !important;
    border-radius: 12px !important;
    transition: all 0.3s ease !important;
}
#extract-btn:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 24px rgba(249, 115, 22, 0.4) !important;
}
/* Upload area */
#upload-area {
    border: 2px dashed #475569 !important;
    border-radius: 12px !important;
    transition: border-color 0.3s ease !important;
}
#upload-area:hover {
    border-color: #667eea !important;
}
/* Color picker label */
#bg-color-picker {
    max-width: 200px;
}
"""

app_theme = gr.themes.Soft(
    primary_hue="orange",
    secondary_hue="blue",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
)

def download_all_zip():
    """Package all extracted assets into a downloadable ZIP."""
    zip_path = os.path.join(tempfile.gettempdir(), "extracted_assets.zip")
    pngs = sorted([f for f in os.listdir(ASSET_LIBRARY_DIR) if f.endswith(".png")])
    if not pngs:
        return None
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for f in pngs:
            zf.write(os.path.join(ASSET_LIBRARY_DIR, f), f)
    return zip_path
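
# Note: the second argument to zf.write() is the arcname, so the archive
# contains flat asset_XXXX.png entries rather than absolute /tmp/... paths.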

with gr.Blocks(title="SAM 3 Asset Extractor", css=custom_css, theme=app_theme) as demo:
    gr.Markdown("# 🎨 SAM 3 Visual Asset Extractor", elem_id="app-title")
    gr.Markdown(
        "Upload a presentation slide or PDF to extract all **visual elements** "
        "(charts, diagrams, icons, illustrations) as **transparent PNGs** ready for "
        "**video editing** — powered by Meta's SAM 3 + intelligent background removal.",
        elem_id="app-subtitle"
    )
    with gr.Row(equal_height=False):
        with gr.Column(scale=1, min_width=340):
            with gr.Tabs():
                with gr.Tab("🖼️ Single Image"):
                    input_image = gr.Image(
                        label="📤 Upload Slide",
                        type="numpy",
                        elem_id="upload-area",
                        height=300,
                    )
                    submit_btn = gr.Button(
                        "🔍 Extract Visual Assets",
                        variant="primary",
                        elem_id="extract-btn",
                        size="lg",
                    )
                with gr.Tab("📄 PDF Batch"):
                    input_pdf = gr.File(
                        label="📤 Upload PDF",
                        file_types=[".pdf"],
                    )
                    pdf_btn = gr.Button(
                        "📄 Extract from All Pages",
                        variant="primary",
                        elem_id="extract-btn",
                        size="lg",
                    )
            bg_color_input = gr.Textbox(
                label="🎨 Background Color to Remove",
                value="#FFFFFF",
                elem_id="bg-color-picker",
                info="Hex color of slide background (e.g. #FFFFFF for white)",
                max_lines=1,
            )
            download_btn = gr.DownloadButton(
                "📦 Download All as ZIP",
                size="lg",
            )
            gr.Markdown(
                "**🔍 Detects:** charts · diagrams · graphs · tables · "
                "illustrations · infographics · figures · photos · "
                "icons · symbols · arrows · bars · persons · badges\n\n"
                "**🚫 Excludes:** text blocks · logos · watermarks",
                elem_id="concept-list"
            )
        with gr.Column(scale=3):
            output_gallery = gr.Gallery(
                label="🎨 Extracted Assets — Hover to download individual PNGs",
                columns=4,
                object_fit="contain",
                height=700,
                format="png",
                elem_classes=["gallery-container"],
            )
    submit_btn.click(
        fn=extract_assets,
        inputs=[input_image, bg_color_input],
        outputs=[output_gallery]
    )
    pdf_btn.click(
        fn=extract_from_pdf,
        inputs=[input_pdf, bg_color_input],
        outputs=[output_gallery]
    )
    download_btn.click(fn=download_all_zip, inputs=[], outputs=[download_btn])

if __name__ == "__main__":
    auth_user = os.environ.get("APP_USERNAME", "veurone")
    auth_pass = os.environ.get("APP_PASSWORD", "sam3extract")
    # css and theme are set on the gr.Blocks constructor above;
    # Blocks.launch() does not accept those keyword arguments.
    demo.launch(auth=(auth_user, auth_pass))