import gradio as gr
import numpy as np
import cv2
import torch
from PIL import Image
import os
import io
import fitz # PyMuPDF
# ── UNCONDITIONAL BFloat16 → Float16 Patch for T4 Turing GPUs ────
# CRITICAL: torch.cuda.is_bf16_supported() returns True on T4 because CUDA
# can *emulate* bfloat16 in software, but the actual kernels crash on mixed
# dtype operations (linear, conv2d). We MUST patch unconditionally.
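# After the patch below, any caller requesting bfloat16 transparently runs in
# float16 instead — e.g. (a sketch; `model_forward` is a hypothetical stand-in):
#
#     with torch.autocast("cuda", dtype=torch.bfloat16):  # silently becomes float16
#         out = model_forward(x)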
if torch.cuda.is_available():
    # 1. Intercept ALL autocast entry points to force float16.
    import torch.amp.autocast_mode

    _OriginalAutocast = torch.amp.autocast_mode.autocast

    class _Fp16Autocast(_OriginalAutocast):
        def __init__(self, device_type, dtype=None, *args, **kwargs):
            # Intercept a bfloat16 request and force float16, which Turing
            # Tensor Cores support natively.
            if dtype == torch.bfloat16:
                dtype = torch.float16
            super().__init__(device_type, dtype=dtype, *args, **kwargs)

    torch.autocast = _Fp16Autocast
    torch.amp.autocast_mode.autocast = _Fp16Autocast
    if hasattr(torch.amp, 'autocast'):
        torch.amp.autocast = _Fp16Autocast
    if hasattr(torch.cuda.amp, 'autocast'):
        torch.cuda.amp.autocast = _Fp16Autocast

    # 2. Patch core math kernels to cast mismatched floating-point inputs to
    # the weight dtype, so dtype mismatches deep inside transformer blocks
    # cannot crash even where autocast state is lost.
    import torch.nn.functional as F

    orig_linear = F.linear

    def patched_linear(input, weight, bias=None):
        if input.is_floating_point() and input.dtype != weight.dtype:
            input = input.to(weight.dtype)
        return orig_linear(input, weight, bias)

    F.linear = patched_linear

    orig_conv2d = F.conv2d

    def patched_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        if input.is_floating_point() and input.dtype != weight.dtype:
            input = input.to(weight.dtype)
        return orig_conv2d(input, weight, bias, stride, padding, dilation, groups)

    F.conv2d = patched_conv2d

    # 3. Patch torchvision.ops.roi_align — Meta's geometry_encoders.py calls
    # boxes_xyxy.float(), which yields float32 boxes while img_feats is float16.
    try:
        import torchvision.ops

        orig_roi_align = torchvision.ops.roi_align

        def patched_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
            # Handle Tensor, list, or tuple (Meta uses .unbind(), which returns a tuple).
            if isinstance(boxes, torch.Tensor):
                if input.is_floating_point() and boxes.dtype != input.dtype:
                    boxes = boxes.to(input.dtype)
            elif isinstance(boxes, (list, tuple)):
                boxes = [b.to(input.dtype) if isinstance(b, torch.Tensor) and b.dtype != input.dtype else b
                         for b in boxes]
            return orig_roi_align(input, boxes, output_size, spatial_scale, sampling_ratio, aligned)

        torchvision.ops.roi_align = patched_roi_align
    except ImportError:
        pass

    # 4. Patch layer_norm — a common ViT dtype-mismatch point.
    orig_layer_norm = F.layer_norm

    def patched_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
        if weight is not None and input.is_floating_point() and input.dtype != weight.dtype:
            input = input.to(weight.dtype)
        return orig_layer_norm(input, normalized_shape, weight, bias, eps)

    F.layer_norm = patched_layer_norm
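    # A minimal sanity check of the kernel patches (a sketch; run on a CUDA box):
    #
    #     lin = torch.nn.Linear(4, 4).cuda().half()
    #     x32 = torch.randn(1, 4, device="cuda")    # float32 input
    #     y = F.linear(x32, lin.weight, lin.bias)   # patched path casts to float16
    #     assert y.dtype == torch.float16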
# ── Ensure SAM 3 Checkpoint is downloaded ────────────────────────
# (HuggingFace Spaces can use the hf_hub_download mechanism)
from huggingface_hub import hf_hub_download
# ── HF Token Authentication ────────────────────────────────────────
print("Downloading SAM 3 model...")
hf_token = os.environ.get("HF_TOKEN")
ckpt_path = hf_hub_download(repo_id="facebook/sam3", filename="sam3.pt", token=hf_token)
# ── Monkey-Patch SAM 3 CUDA Hardcoding Bug ───────────────────────
# Meta's SAM 3 repo hardcodes `device="cuda"` in many places. This intercepts
# common PyTorch tensor constructors and rewrites the device to "cpu" when no
# GPU is available.
if not torch.cuda.is_available():
    import functools

    patch_funcs = ['zeros', 'arange', 'tensor', 'ones', 'empty', 'randn', 'full', 'linspace']
    for name in patch_funcs:
        if hasattr(torch, name):
            orig_fn = getattr(torch, name)

            # Bind orig_fn as a default argument so each wrapper keeps its own
            # original function (avoids the late-binding closure pitfall).
            @functools.wraps(orig_fn)
            def patched_fn(*args, __orig_fn=orig_fn, **kwargs):
                if 'device' in kwargs and str(kwargs['device']).startswith('cuda'):
                    kwargs['device'] = 'cpu'
                return __orig_fn(*args, **kwargs)

            setattr(torch, name, patched_fn)
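    # Effect of the patch (a sketch): SAM 3 code that does
    #     pos = torch.arange(100, device="cuda")
    # now lands on the CPU instead of raising. Note this only covers these
    # constructors; `.to("cuda")` and `Tensor.cuda()` calls are not intercepted.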
# ── SAM 3 Imports ────────────────────────────────────────────────
try:
    from sam3.model_builder import build_sam3_image_model
    from sam3.model.sam3_image_processor import Sam3Processor
    model_installed = True
except ImportError:
    model_installed = False
    print("SAM 3 not installed yet (will be installed by requirements.txt).")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = None
if model_installed:
    print(f"Loading SAM 3 onto {device}...")
    model = build_sam3_image_model(checkpoint_path=ckpt_path)
    # Cast to float16 — the T4 has native float16 Tensor Core acceleration.
    # bfloat16 hangs (software-emulated on Turing); float32 produced all-zero masks.
    model.half()
    # Diagnostics: verify the checkpoint loaded correctly.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}", flush=True)
    sample_dtype = next(model.parameters()).dtype
    print(f"Model dtype: {sample_dtype}", flush=True)
    processor = Sam3Processor(model)
    if not torch.cuda.is_available():
        processor.device = "cpu"
    print("Model loaded successfully.")
# Two-pass concept detection: parent (composite) + child (individual) elements
# Excludes 'text block' (user doesn't want text) and 'logo' (picks up watermarks)
PARENT_CONCEPTS = [
    "chart", "diagram", "graph", "table", "illustration",
    "infographic", "figure", "photo", "picture", "image",
]
CHILD_CONCEPTS = [
    "icon", "symbol", "arrow", "bar", "person",
    "object", "button", "badge", "circle", "label",
]
ALL_CONCEPTS = PARENT_CONCEPTS + CHILD_CONCEPTS
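# Example of why both passes matter: for a slide with a bar chart, the parent
# pass finds the whole "chart" while the child pass finds each "bar" and "label";
# the IoU dedup below keeps all of them because a child box barely overlaps
# its parent's much larger box.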
# Persistent asset library
import tempfile, zipfile
ASSET_LIBRARY_DIR = os.path.join(tempfile.gettempdir(), "sam3_library")
os.makedirs(ASSET_LIBRARY_DIR, exist_ok=True)
asset_counter = 0
def box_iou(b1, b2):
    """IoU between two boxes [x0, y0, x1, y1]."""
    x0 = max(b1[0], b2[0])
    y0 = max(b1[1], b2[1])
    x1 = min(b1[2], b2[2])
    y1 = min(b1[3], b2[3])
    inter = max(0, x1 - x0) * max(0, y1 - y0)
    a1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
    a2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
    union = a1 + a2 - inter
    return inter / union if union > 0 else 0.0
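# Worked example: b1 = [0, 0, 10, 10], b2 = [5, 5, 15, 15] →
# intersection = 5 * 5 = 25, union = 100 + 100 - 25 = 175, IoU ≈ 0.143,
# which is below the 0.5 dedup threshold, so both boxes are kept.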
def remove_color_bg(crop_rgb: np.ndarray, bg_color=(255, 255, 255), tolerance=30) -> np.ndarray:
    """Remove the background by flood-filling from the edges.

    Only removes pixels CONNECTED to the border that match bg_color;
    white/colored areas INSIDE objects are preserved.
    """
    h, w = crop_rgb.shape[:2]
    if h < 2 or w < 2:
        rgba = np.zeros((h, w, 4), dtype=np.uint8)
        rgba[:, :, :3] = crop_rgb
        rgba[:, :, 3] = 255
        return rgba

    # Mask of pixels matching the background color within the Euclidean tolerance.
    bg = np.array(bg_color, dtype=np.float32)
    diff = np.sqrt(np.sum((crop_rgb.astype(np.float32) - bg) ** 2, axis=2))
    color_match = (diff < tolerance).astype(np.uint8) * 255

    # Flood-fill from all border pixels to find the CONNECTED background.
    # floodFill requires a mask padded by 1 pixel on each side.
    flood_mask = np.zeros((h + 2, w + 2), dtype=np.uint8)
    bg_connected = np.zeros((h, w), dtype=np.uint8)

    # Seed from every border pixel that matches the background color.
    border_seeds = []
    for x in range(w):
        if color_match[0, x]: border_seeds.append((x, 0))
        if color_match[h - 1, x]: border_seeds.append((x, h - 1))
    for y in range(h):
        if color_match[y, 0]: border_seeds.append((0, y))
        if color_match[y, w - 1]: border_seeds.append((w - 1, y))

    # Flood-fill from each border seed not already covered by a previous fill.
    for sx, sy in border_seeds:
        if bg_connected[sy, sx] == 0 and color_match[sy, sx]:
            flood_mask[:] = 0
            cv2.floodFill(color_match.copy(), flood_mask, (sx, sy), 128,
                          loDiff=0, upDiff=0, flags=cv2.FLOODFILL_MASK_ONLY | (8 << 8))
            # flood_mask is nonzero wherever the fill reached (inside the 1-px padding).
            bg_connected |= flood_mask[1:-1, 1:-1]

    # Alpha: 255 for foreground, 0 for connected background.
    alpha = np.where(bg_connected > 0, np.uint8(0), np.uint8(255))

    # Slight edge anti-aliasing: blur the alpha, then re-clamp the interior.
    alpha_f = alpha.astype(np.float32)
    alpha_blur = cv2.GaussianBlur(alpha_f, (3, 3), sigmaX=0.8)
    interior = alpha > 240
    alpha_aa = np.where(interior, 255.0, alpha_blur)
    alpha = alpha_aa.clip(0, 255).astype(np.uint8)

    # Build RGBA.
    rgba = np.zeros((h, w, 4), dtype=np.uint8)
    rgba[:, :, :3] = crop_rgb
    rgba[:, :, 3] = alpha
    return rgba
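# Example of the connectivity rule: for a white ring on a white slide, the
# surrounding white is border-connected and becomes transparent, while the
# white hole inside the ring is enclosed, never reached by the flood fill,
# and stays opaque.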
def upscale_4x(rgba: np.ndarray) -> np.ndarray:
    """4x Lanczos upscale with unsharp masking."""
    h, w = rgba.shape[:2]
    new_w, new_h = w * 4, h * 4
    upscaled = cv2.resize(rgba, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
    # Unsharp mask on the RGB channels only (alpha stays untouched).
    rgb = upscaled[:, :, :3]
    blurred = cv2.GaussianBlur(rgb, (0, 0), sigmaX=1.0)
    rgb_sharp = cv2.addWeighted(rgb, 1.5, blurred, -0.5, 0)
    upscaled[:, :, :3] = rgb_sharp
    return upscaled
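# The addWeighted call above computes 1.5*rgb - 0.5*blur = rgb + 0.5*(rgb - blur),
# i.e. a classic unsharp mask with amount 0.5 at the radius set by sigmaX.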
def is_notebooklm_logo(box, img_w, img_h):
    """Filter out small detections in the bottom-right corner (NotebookLM watermark)."""
    x0, y0, x1, y1 = box
    bw, bh = x1 - x0, y1 - y0
    # Skip if small AND centered in the bottom-right 15% of the image.
    if bw < 80 and bh < 80:
        center_x = (x0 + x1) / 2
        center_y = (y0 + y1) / 2
        if center_x > img_w * 0.85 and center_y > img_h * 0.85:
            return True
    return False
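# e.g. a 60x60 watermark centered at (0.92*w, 0.93*h) is flagged, while the
# same 60x60 icon in the middle of the slide passes through untouched.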
def extract_assets(input_image, bg_color_hex="#FFFFFF"):
    import sys, traceback
    try:
        print(">>> extract_assets V2 called", flush=True)
        if input_image is None:
            gr.Info("Please upload an image first.")
            return []
        if processor is None:
            gr.Warning("Model is still loading. Please wait and try again.")
            return []

        # Parse the background color from hex; fall back to white.
        bg_hex = bg_color_hex.lstrip("#")
        try:
            bg_color = tuple(int(bg_hex[i:i + 2], 16) for i in (0, 2, 4))
        except ValueError:
            bg_color = (255, 255, 255)
        print(f">>> Background color: {bg_color}", flush=True)

        orig_rgb = input_image
        h, w = orig_rgb.shape[:2]
        img_area = h * w
        print(f">>> Image size: {w}x{h}, area: {img_area}", flush=True)
        pil_img = Image.fromarray(orig_rgb)

        all_boxes = []
        all_scores = []
        with torch.inference_mode():
            print(">>> Running set_image...", flush=True)
            state = processor.set_image(pil_img)
            print(">>> set_image complete! Running two-pass detection...", flush=True)
            for concept in ALL_CONCEPTS:
                print(f">>> Concept: '{concept}'...", flush=True)
                out = processor.set_text_prompt(state=state, prompt=concept)
                masks = out["masks"]
                scores = out["scores"]
                if masks is None or len(masks) == 0:
                    print(f" [{concept}] No detections", flush=True)
                    continue
                if torch.is_tensor(masks):
                    masks = masks.float().cpu().numpy()
                if torch.is_tensor(scores):
                    scores = scores.float().cpu().numpy()
                print(f" [{concept}] Found {len(masks)} masks", flush=True)
                for j in range(len(masks)):
                    m = masks[j]
                    while m.ndim > 2:
                        m = m[0]
                    m_bool = m.astype(bool)
                    score = float(scores[j]) if scores.ndim > 0 else float(scores)
                    # Derive the bounding box from the mask.
                    ys, xs = np.nonzero(m_bool)
                    if len(ys) == 0:
                        continue
                    x0, y0 = int(xs.min()), int(ys.min())
                    x1, y1 = int(xs.max()), int(ys.max())
                    bw, bh = x1 - x0, y1 - y0
                    box_area = bw * bh
                    # Filters
                    if score < 0.1:
                        print(f" [{j}] SKIP low score: {score:.4f}", flush=True)
                        continue
                    if box_area < 500 or bw < 20 or bh < 20:
                        print(f" [{j}] SKIP too small: {bw}x{bh}", flush=True)
                        continue
                    if box_area > img_area * 0.90:
                        print(f" [{j}] SKIP too large", flush=True)
                        continue
                    if is_notebooklm_logo([x0, y0, x1, y1], w, h):
                        print(f" [{j}] SKIP NotebookLM logo position", flush=True)
                        continue
                    # Add padding (8% of box size, at least 8 px).
                    pad_x = max(8, int(bw * 0.08))
                    pad_y = max(8, int(bh * 0.08))
                    bx0 = max(0, x0 - pad_x)
                    by0 = max(0, y0 - pad_y)
                    bx1 = min(w, x1 + pad_x)
                    by1 = min(h, y1 + pad_y)
                    all_boxes.append([bx0, by0, bx1, by1])
                    all_scores.append(score)
                    print(f" [{j}] KEPT: score={score:.4f}, box=[{bx0},{by0},{bx1},{by1}]", flush=True)

        print(f">>> Total detections: {len(all_boxes)}", flush=True)
        if not all_boxes:
            gr.Info("No visual assets found. Try a different slide with more illustrations.")
            return []

        # Deduplicate by box IoU, keeping the highest-scoring box of each cluster.
        order = sorted(range(len(all_boxes)), key=lambda i: all_scores[i], reverse=True)
        keep = []
        for i in order:
            dup = False
            for ki in keep:
                if box_iou(all_boxes[i], all_boxes[ki]) > 0.5:
                    dup = True
                    break
            if not dup:
                keep.append(i)
        print(f">>> After dedup: {len(keep)} unique assets", flush=True)

        # For each kept box: crop → flood-fill BG removal → upscale → save.
        results = []
        global asset_counter
        for idx, ki in enumerate(keep):
            bx0, by0, bx1, by1 = all_boxes[ki]
            crop_rgb = orig_rgb[by0:by1, bx0:bx1]
            # Flood-fill background removal (preserves interior fills).
            rgba = remove_color_bg(crop_rgb, bg_color=bg_color, tolerance=30)
            # 4x upscale
            rgba = upscale_4x(rgba)
            asset_counter += 1
            lib_path = os.path.join(ASSET_LIBRARY_DIR, f"asset_{asset_counter:04d}.png")
            Image.fromarray(rgba, "RGBA").save(lib_path, format="PNG")
            results.append(lib_path)
            print(f" asset[{idx}] saved: {lib_path}", flush=True)

        print(f">>> Returning {len(results)} assets (library: {asset_counter})", flush=True)
        return results
    except Exception as e:
        print(f">>> EXCEPTION in extract_assets: {e}", flush=True)
        traceback.print_exc()
        sys.stdout.flush()
        return []
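# Scale check for the area filter: on a 1920x1080 slide (area 2,073,600 px),
# any box over 0.90 * area ≈ 1,866,240 px is treated as the slide background
# itself and skipped, while the 500 px floor drops stray speckles.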
def extract_from_pdf(pdf_file, bg_color_hex="#FFFFFF", progress=gr.Progress()):
    """Process every page of a PDF through SAM 3 extraction."""
    import sys, traceback
    try:
        if pdf_file is None:
            return []
        pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
        print(f">>> PDF upload: {pdf_path}", flush=True)
        doc = fitz.open(pdf_path)
        total_pages = len(doc)
        print(f">>> PDF has {total_pages} pages", flush=True)
        all_results = []
        for page_num in progress.tqdm(range(total_pages), desc="Processing PDF pages"):
            print(f">>> Processing page {page_num + 1}/{total_pages}...", flush=True)
            page = doc[page_num]
            # Render at 2x zoom (144 DPI) for enough detail on small icons.
            mat = fitz.Matrix(2.0, 2.0)
            pix = page.get_pixmap(matrix=mat)
            img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.h, pix.w, pix.n)
            # Drop the alpha channel if present; SAM 3 expects RGB.
            if pix.n == 4:
                img_rgb = img_array[:, :, :3].copy()
            else:
                img_rgb = img_array.copy()
            page_results = extract_assets(img_rgb, bg_color_hex=bg_color_hex)
            all_results.extend(page_results)
            print(f">>> Page {page_num + 1}: extracted {len(page_results)} assets", flush=True)
        doc.close()
        print(f">>> PDF complete: {len(all_results)} total assets from {total_pages} pages", flush=True)
        return all_results
    except Exception as e:
        print(f">>> EXCEPTION in extract_from_pdf: {e}", flush=True)
        traceback.print_exc()
        sys.stdout.flush()
        return []
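# Note on the pixmap decode above: pix.samples is a flat bytes buffer of
# pix.h * pix.w * pix.n values in row-major order, so the reshape recovers an
# (H, W, C) array directly.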
custom_css = """
/* ── Premium Dark Theme ───────────────────────────── */
.gradio-container {
    max-width: 1400px !important;
    margin: auto;
}
#app-title {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #f97316 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 2.2rem !important;
    font-weight: 800 !important;
    margin-bottom: 0 !important;
}
#app-subtitle {
    text-align: center;
    color: #94a3b8 !important;
    font-size: 0.95rem !important;
    margin-top: 0 !important;
}
/* Gallery with hover download */
.gallery-container {
    min-height: 650px !important;
}
.gallery-container .gallery-item {
    position: relative;
    border-radius: 12px;
    overflow: hidden;
    transition: transform 0.2s ease, box-shadow 0.2s ease;
    background: #1e293b;
}
.gallery-container .gallery-item:hover {
    transform: scale(1.03);
    box-shadow: 0 8px 32px rgba(102, 126, 234, 0.3);
}
/* Download button: hidden by default, shown on hover */
.gallery-container .gallery-item button.download {
    opacity: 0 !important;
    transition: opacity 0.25s ease !important;
    position: absolute !important;
    bottom: 8px !important;
    right: 8px !important;
    z-index: 10 !important;
    background: rgba(249, 115, 22, 0.9) !important;
    color: white !important;
    border-radius: 8px !important;
    padding: 6px 14px !important;
    font-weight: 600 !important;
    border: none !important;
    cursor: pointer !important;
}
.gallery-container .gallery-item:hover button.download {
    opacity: 1 !important;
}
/* Extract button styling */
#extract-btn {
    background: linear-gradient(135deg, #f97316 0%, #ea580c 100%) !important;
    border: none !important;
    font-weight: 700 !important;
    font-size: 1.1rem !important;
    padding: 14px 0 !important;
    border-radius: 12px !important;
    transition: all 0.3s ease !important;
}
#extract-btn:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 24px rgba(249, 115, 22, 0.4) !important;
}
/* Upload area */
#upload-area {
    border: 2px dashed #475569 !important;
    border-radius: 12px !important;
    transition: border-color 0.3s ease !important;
}
#upload-area:hover {
    border-color: #667eea !important;
}
/* Background color input (a Textbox since the ColorPicker compat fix) */
#bg-color-picker {
    max-width: 200px;
}
"""
app_theme = gr.themes.Soft(
    primary_hue="orange",
    secondary_hue="blue",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
)
def download_all_zip():
    """Package all extracted assets into a downloadable ZIP."""
    zip_path = os.path.join(tempfile.gettempdir(), "extracted_assets.zip")
    pngs = sorted(f for f in os.listdir(ASSET_LIBRARY_DIR) if f.endswith(".png"))
    if not pngs:
        return None
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for f in pngs:
            zf.write(os.path.join(ASSET_LIBRARY_DIR, f), f)
    return zip_path
# css and theme belong on gr.Blocks, not on demo.launch(), which rejects them.
with gr.Blocks(title="SAM 3 Asset Extractor", theme=app_theme, css=custom_css) as demo:
    gr.Markdown("# 🎨 SAM 3 Visual Asset Extractor", elem_id="app-title")
    gr.Markdown(
        "Upload a presentation slide or PDF to extract all **visual elements** "
        "(charts, diagrams, icons, illustrations) as **transparent PNGs** ready for "
        "**video editing** — powered by Meta's SAM 3 + intelligent background removal.",
        elem_id="app-subtitle"
    )
    with gr.Row(equal_height=False):
        with gr.Column(scale=1, min_width=340):
            with gr.Tabs():
                with gr.Tab("🖼️ Single Image"):
                    input_image = gr.Image(
                        label="📤 Upload Slide",
                        type="numpy",
                        elem_id="upload-area",
                        height=300,
                    )
                    submit_btn = gr.Button(
                        "🔍 Extract Visual Assets",
                        variant="primary",
                        elem_id="extract-btn",
                        size="lg",
                    )
                with gr.Tab("📄 PDF Batch"):
                    input_pdf = gr.File(
                        label="📤 Upload PDF",
                        file_types=[".pdf"],
                    )
                    pdf_btn = gr.Button(
                        "📄 Extract from All Pages",
                        variant="primary",
                        elem_id="extract-btn",
                        size="lg",
                    )
            bg_color_input = gr.Textbox(
                label="🎨 Background Color to Remove",
                value="#FFFFFF",
                elem_id="bg-color-picker",
                info="Hex color of the slide background (e.g. #FFFFFF for white)",
                max_lines=1,
            )
            download_btn = gr.DownloadButton(
                "📦 Download All as ZIP",
                size="lg",
            )
            gr.Markdown(
                "**🔍 Detects:** charts · diagrams · graphs · tables · "
                "illustrations · infographics · figures · photos · "
                "icons · symbols · arrows · bars · persons · badges\n\n"
                "**🚫 Excludes:** text blocks · logos · watermarks",
                elem_id="concept-list"
            )
        with gr.Column(scale=3):
            output_gallery = gr.Gallery(
                label="🎨 Extracted Assets — Hover to download individual PNGs",
                columns=4,
                object_fit="contain",
                height=700,
                format="png",
                elem_classes=["gallery-container"],
            )

    submit_btn.click(
        fn=extract_assets,
        inputs=[input_image, bg_color_input],
        outputs=[output_gallery]
    )
    pdf_btn.click(
        fn=extract_from_pdf,
        inputs=[input_pdf, bg_color_input],
        outputs=[output_gallery]
    )
    download_btn.click(fn=download_all_zip, inputs=[], outputs=[download_btn])
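    # Note: gr.DownloadButton serves whatever file path its value holds at click
    # time; this binding refreshes the ZIP on each click, so the very first
    # click may only prepare the archive (behavior is Gradio-version dependent).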
if __name__ == "__main__":
    auth_user = os.environ.get("APP_USERNAME", "veurone")
    auth_pass = os.environ.get("APP_PASSWORD", "sam3extract")
    demo.launch(auth=(auth_user, auth_pass))