# CodeShamza
# Hotfix: Update rembg model string from bria to bria-rmbg for correct session mapping
# 63b7bc7
import os
import io
import base64
import tempfile
import zipfile
import numpy as np
import cv2
import torch
from PIL import Image
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from rembg import remove, new_session
# ── App Setup ────────────────────────────────────────────────
app = FastAPI(title="Visual Asset Extractor API", version="3.1")
# CORS accepts every origin, method and header, and allows credentials.
# Intended client is the Netlify frontend.
# NOTE(review): if only that frontend should call this API, list its origin
# explicitly instead of "*" (wildcard origins + credentials is very permissive).
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# ── Model Loading ────────────────────────────────────────────
# Both models are loaded once at import time and shared by every request.
print("Loading Grounding DINO Tiny...", flush=True)
MODEL_ID = "IDEA-Research/grounding-dino-tiny"
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
gd_processor = AutoProcessor.from_pretrained(MODEL_ID)
gd_model = AutoModelForZeroShotObjectDetection.from_pretrained(MODEL_ID)
gd_model = gd_model.to(device)
gd_model.eval()  # inference mode for the detector (no training-time behaviour)
# Size of the parameters only (buffers are not counted), in megabytes.
model_size_mb = sum(param.numel() * param.element_size() for param in gd_model.parameters()) / 1e6
print(f"Grounding DINO loaded on {device} ({model_size_mb:.0f} MB)", flush=True)
# NOTE(review): the log line says RMBG-1.4; confirm which BRIA checkpoint
# rembg's "bria-rmbg" session actually downloads in the pinned rembg version.
print("Loading RMBG-1.4 Engine via rembg...", flush=True)
bria_session = new_session("bria-rmbg")
# ── Detection Concepts ───────────────────────────────────────
# Concepts come in two groups: composite "parent" visuals and individual
# "child" elements. Both groups are sent to Grounding DINO together in ONE
# prompt (see ALL_CONCEPTS_TEXT). Text and logos are deliberately left out:
# the extractor is focused on video-editing assets.
PARENT_CONCEPTS = [
    "chart",
    "diagram",
    "graph",
    "table",
    "illustration",
    "infographic",
    "figure",
    "photo",
    "picture",
]
CHILD_CONCEPTS = [
    "icon",
    "symbol",
    "arrow",
    "bar",
    "person",
    "object",
    "button",
    "badge",
    "circle",
]
# Grounding DINO takes a single text prompt in which every phrase is terminated
# by " .", e.g. "chart . diagram . ... . circle ."
ALL_CONCEPTS_TEXT = " ".join(f"{concept} ." for concept in PARENT_CONCEPTS + CHILD_CONCEPTS)
# ── Utility Functions ────────────────────────────────────────
def box_iou(b1, b2):
    """Return the intersection-over-union of two ``[x0, y0, x1, y1]`` boxes.

    Disjoint or merely touching boxes give 0.0; if the union area is not
    positive (e.g. two zero-area boxes) the result is 0.0 instead of a
    division error.
    """
    overlap_w = max(0, min(b1[2], b2[2]) - max(b1[0], b2[0]))
    overlap_h = max(0, min(b1[3], b2[3]) - max(b1[1], b2[1]))
    overlap = overlap_w * overlap_h
    area_a = (b1[2] - b1[0]) * (b1[3] - b1[1])
    area_b = (b2[2] - b2[0]) * (b2[3] - b2[1])
    total = area_a + area_b - overlap
    if total <= 0:
        return 0.0
    return overlap / total
def is_notebooklm_logo(box, img_w, img_h):
    """Return True when ``box`` looks like the NotebookLM watermark.

    The watermark heuristic: both box sides are under 80 px AND the box
    centre lies beyond 85% of the image width and height (bottom-right corner).
    """
    left, top, right, bottom = box
    # Guard clause: anything 80 px or larger in either direction is not the logo.
    if right - left >= 80 or bottom - top >= 80:
        return False
    centre_x = (left + right) / 2
    centre_y = (top + bottom) / 2
    return bool(centre_x > img_w * 0.85 and centre_y > img_h * 0.85)
def trim_transparent(rgba: np.ndarray) -> np.ndarray:
    """Crop away the fully transparent rows/columns surrounding the content.

    A pixel counts as visible when its alpha (channel 3) is non-zero. The
    result is a slice (view) of ``rgba`` tightly bounding every visible pixel.
    If no pixel is visible, ``rgba`` is returned unchanged.
    """
    visible = rgba[:, :, 3] != 0
    row_hits = np.flatnonzero(visible.any(axis=1))
    if row_hits.size == 0:
        return rgba  # nothing visible: keep the input as-is
    col_hits = np.flatnonzero(visible.any(axis=0))
    return rgba[row_hits[0]:row_hits[-1] + 1, col_hits[0]:col_hits[-1] + 1]
def upscale_crisp(rgba: np.ndarray) -> np.ndarray:
    """Upscale a small RGBA asset with Lanczos resampling plus a mild RGB sharpen.

    Scale selection, based on the longer side ``max_dim``:
      * ``max_dim >= 1200``: no upscale, ``rgba`` itself is returned;
      * otherwise the largest factor among 4x, 3x, 2x that keeps the longer
        side <= 1200 px;
      * if even 2x would exceed 1200 px, a fractional factor that brings the
        longer side to ~1200 px.
    Because the factor is capped at 4x, very small assets stay well below
    1200 px (e.g. a 100 px icon becomes 400 px).

    Args:
        rgba: (H, W, 4) image array; presumably uint8 as produced by rembg —
            TODO confirm for other callers.

    Returns:
        A new upscaled array, or ``rgba`` unchanged when no upscale applies
        (already >= 1200 px, or an empty image with a zero-length side).
    """
    h, w = rgba.shape[:2]
    # Empty input: cv2.resize would fail on a zero-sized target, so return as-is.
    if h == 0 or w == 0:
        return rgba
    max_dim = max(h, w)
    if max_dim >= 1200:
        return rgba
    if max_dim * 4 <= 1200:
        scale = 4.0
    elif max_dim * 3 <= 1200:
        scale = 3.0
    elif max_dim * 2 <= 1200:
        scale = 2.0
    else:
        # 600 < max_dim < 1200: fractional factor in (1, 2) reaching ~1200 px.
        scale = 1200.0 / max_dim
    new_w = int(w * scale)
    new_h = int(h * scale)
    upscaled = cv2.resize(rgba, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
    # Keep alpha within [0, 255] (expected to be a no-op for uint8 input).
    upscaled[:, :, 3] = np.clip(upscaled[:, :, 3], 0, 255)
    # Moderate unsharp mask on RGB only, so the anti-aliased alpha edge stays smooth.
    rgb = upscaled[:, :, :3]
    blurred = cv2.GaussianBlur(rgb, (0, 0), sigmaX=0.8)
    upscaled[:, :, :3] = cv2.addWeighted(rgb, 1.4, blurred, -0.4, 0)
    return upscaled
def rgba_to_base64_png(rgba: np.ndarray) -> str:
    """Encode an RGBA pixel array as PNG and return the bytes as base64 text."""
    with io.BytesIO() as stream:
        Image.fromarray(rgba, "RGBA").save(stream, format="PNG")
        png_bytes = stream.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
def _filter_detections(boxes, scores, labels, w, h):
    """Clip, filter and pad raw detections; return (padded_boxes, scores) of survivors.

    A detection is dropped when its score is < 0.20, its clipped box is tiny
    (area < 500 px or a side < 20 px), it covers more than 90% of the image,
    or it matches the NotebookLM watermark heuristic. Survivors are padded by
    10% of their size on each side (at least 10 px), clamped to the image.
    """
    img_area = h * w
    kept_boxes = []
    kept_scores = []
    for i, (raw_box, raw_score) in enumerate(zip(boxes, scores)):
        x0, y0, x1, y1 = raw_box
        # Clip to the image and snap to integer pixel coordinates.
        x0, y0 = int(max(0, x0)), int(max(0, y0))
        x1, y1 = int(min(w, x1)), int(min(h, y1))
        bw, bh = x1 - x0, y1 - y0
        box_area = bw * bh
        score = float(raw_score)
        if score < 0.20:
            continue
        if box_area < 500 or bw < 20 or bh < 20:
            continue
        if box_area > img_area * 0.90:
            continue
        if is_notebooklm_logo([x0, y0, x1, y1], w, h):
            print(f" [{i}] SKIP NotebookLM logo", flush=True)
            continue
        pad_x = max(10, int(bw * 0.10))
        pad_y = max(10, int(bh * 0.10))
        bx0 = max(0, x0 - pad_x)
        by0 = max(0, y0 - pad_y)
        bx1 = min(w, x1 + pad_x)
        by1 = min(h, y1 + pad_y)
        kept_boxes.append([bx0, by0, bx1, by1])
        kept_scores.append(score)
        print(f" [{i}] KEPT: {labels[i]} score={score:.3f} box=[{bx0},{by0},{bx1},{by1}]", flush=True)
    return kept_boxes, kept_scores


def _dedup_by_iou(boxes, scores, iou_thresh=0.5):
    """Greedy de-duplication: return indices of kept boxes, highest score first.

    A box is dropped when its IoU with an already-kept (higher-scoring) box
    exceeds ``iou_thresh``.
    """
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        if all(box_iou(boxes[i], boxes[k]) <= iou_thresh for k in keep):
            keep.append(i)
    return keep


def _cutout_asset(image_rgb, box):
    """Crop ``box`` from the image, remove its background with rembg, then trim and upscale."""
    bx0, by0, bx1, by1 = box
    crop_rgb = image_rgb[by0:by1, bx0:bx1]
    # rembg returns the same container type it is given (numpy in -> numpy out);
    # for an (H, W, 3) crop the result is treated as an (H, W, 4) RGBA array.
    rgba = remove(crop_rgb, session=bria_session)
    return upscale_crisp(trim_transparent(rgba))


def detect_and_extract(image_rgb: np.ndarray, bg_color=(255, 255, 255), tolerance=30):
    """Detect visual assets with Grounding DINO and cut each one out with rembg.

    Pipeline: zero-shot detection over ALL_CONCEPTS_TEXT -> box filtering and
    padding -> IoU de-duplication (IoU > 0.5 counts as duplicate) -> rembg
    (bria session) background removal on each padded crop -> transparent-border
    trim -> upscale -> base64 PNG.

    Args:
        image_rgb: (H, W, 3) RGB image array.
        bg_color: currently unused (background removal is done by rembg);
            kept so existing callers keep working.
        tolerance: currently unused; kept for the same reason.

    Returns:
        List of base64-encoded PNG strings, ordered by detection score
        (highest first). Empty if nothing survives filtering.
    """
    h, w = image_rgb.shape[:2]
    pil_img = Image.fromarray(image_rgb)
    inputs = gd_processor(images=pil_img, text=ALL_CONCEPTS_TEXT, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = gd_model(**inputs)
    results = gd_processor.post_process_grounded_object_detection(
        outputs,
        inputs["input_ids"],
        threshold=0.20,
        text_threshold=0.20,
        target_sizes=[(h, w)],
    )[0]
    boxes = results["boxes"].cpu().numpy()  # [N, 4] as x0, y0, x1, y1
    scores = results["scores"].cpu().numpy()
    # Prefer "text_labels" (string phrases): recent transformers releases deprecate
    # "labels" for this output. Fall back for older versions. Used for logging only.
    labels = results.get("text_labels")
    if labels is None:
        labels = results.get("labels")
    if labels is None:
        labels = [""] * len(boxes)
    print(f" Grounding DINO: {len(boxes)} raw detections", flush=True)

    kept_boxes, kept_scores = _filter_detections(boxes, scores, labels, w, h)
    if not kept_boxes:
        return []

    keep = _dedup_by_iou(kept_boxes, kept_scores)
    print(f" After dedup: {len(keep)} unique assets", flush=True)

    results_b64 = []
    for idx, ki in enumerate(keep):
        rgba = _cutout_asset(image_rgb, kept_boxes[ki])
        results_b64.append(rgba_to_base64_png(rgba))
        print(f" asset[{idx}] done ({rgba.shape[1]}x{rgba.shape[0]})", flush=True)
    return results_b64
# ── API Endpoints ────────────────────────────────────────────
@app.get("/health")
async def health():
    """Liveness probe: report status, the detection model id and the torch device in use."""
    return dict(status="ok", model=MODEL_ID, device=device)
@app.post("/extract")
async def extract(
    image: UploadFile = File(...),
    bg_color: str = Form("#FFFFFF"),
    tolerance: int = Form(30),
):
    """Extract visual assets from a single uploaded image.

    Form fields:
        image: the image file (opened with PIL and converted to RGB).
        bg_color: hex colour such as "#FFFFFF"; malformed values fall back to white.
        tolerance: forwarded to detect_and_extract.

    Returns:
        JSON ``{"assets": [base64 PNG, ...], "count": int}``.

    Raises:
        HTTPException: status 500 on any processing error, with the error text as detail.
    """
    # NOTE(review): detect_and_extract is synchronous, so this async handler blocks
    # the event loop (including /health) for the whole inference. Consider
    # offloading to a worker thread once concurrent model use has been evaluated.
    try:
        bg_hex = bg_color.lstrip("#")
        try:
            bg_rgb = tuple(int(bg_hex[i:i+2], 16) for i in (0, 2, 4))
        except ValueError:
            # Short or non-hex input (e.g. "#FFF", "red") -> default to white.
            bg_rgb = (255, 255, 255)
        contents = await image.read()
        pil_img = Image.open(io.BytesIO(contents)).convert("RGB")
        img_np = np.array(pil_img)
        print(f">>> /extract: {img_np.shape[1]}x{img_np.shape[0]}, bg={bg_rgb}", flush=True)
        assets = detect_and_extract(img_np, bg_color=bg_rgb, tolerance=tolerance)
        print(f">>> Returning {len(assets)} assets", flush=True)
        return JSONResponse({"assets": assets, "count": len(assets)})
    except Exception as e:
        print(f">>> ERROR in /extract: {e}", flush=True)
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/extract-pdf")
async def extract_pdf(
    pdf: UploadFile = File(...),
    bg_color: str = Form("#FFFFFF"),
    tolerance: int = Form(30),
):
    """Extract visual assets from every page of an uploaded PDF.

    Each page is rendered at 2x zoom (144 DPI) and passed through
    detect_and_extract; assets from all pages are concatenated in page order.

    Form fields:
        pdf: the PDF file.
        bg_color: hex colour such as "#FFFFFF"; malformed values fall back to white.
        tolerance: forwarded to detect_and_extract.

    Returns:
        JSON ``{"assets": [base64 PNG, ...], "count": int}``.

    Raises:
        HTTPException: status 500 on any processing error, with the error text as detail.
    """
    # NOTE(review): rendering + inference run synchronously inside this async
    # handler and block the event loop for the whole document.
    try:
        import fitz  # PyMuPDF, imported lazily since only this endpoint needs it

        bg_hex = bg_color.lstrip("#")
        try:
            bg_rgb = tuple(int(bg_hex[i:i+2], 16) for i in (0, 2, 4))
        except ValueError:
            # Short or non-hex input (e.g. "#FFF", "red") -> default to white.
            bg_rgb = (255, 255, 255)
        contents = await pdf.read()
        all_assets = []
        # The context manager closes the document even if a page fails;
        # an explicit doc.close() after the loop would be skipped on errors.
        with fitz.open(stream=contents, filetype="pdf") as doc:
            total_pages = len(doc)
            print(f">>> /extract-pdf: {total_pages} pages, bg={bg_rgb}", flush=True)
            mat = fitz.Matrix(2.0, 2.0)  # 144 DPI
            for page_num, page in enumerate(doc):
                pix = page.get_pixmap(matrix=mat)
                img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.h, pix.w, pix.n)
                # Drop an alpha channel when present; .copy() detaches from the pixmap buffer.
                if pix.n == 4:
                    img_rgb = img_array[:, :, :3].copy()
                else:
                    img_rgb = img_array.copy()
                page_assets = detect_and_extract(img_rgb, bg_color=bg_rgb, tolerance=tolerance)
                all_assets.extend(page_assets)
                print(f" Page {page_num + 1}/{total_pages}: {len(page_assets)} assets", flush=True)
        print(f">>> PDF complete: {len(all_assets)} total assets", flush=True)
        return JSONResponse({"assets": all_assets, "count": len(all_assets)})
    except Exception as e:
        print(f">>> ERROR in /extract-pdf: {e}", flush=True)
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/download-zip")
async def download_zip(
    image: UploadFile = File(...),
    bg_color: str = Form("#FFFFFF"),
    tolerance: int = Form(30),
):
    """Extract assets from an uploaded image and return them as a ZIP download.

    The archive contains one PNG per asset, named asset_0001.png, asset_0002.png, ...

    Form fields:
        image: the image file (opened with PIL and converted to RGB).
        bg_color: hex colour such as "#FFFFFF"; malformed values fall back to white.
        tolerance: forwarded to detect_and_extract.

    Raises:
        HTTPException: status 500 on any processing error, with the error text as detail.
    """
    # NOTE(review): detect_and_extract is synchronous and blocks the event loop here.
    try:
        bg_hex = bg_color.lstrip("#")
        try:
            bg_rgb = tuple(int(bg_hex[i:i+2], 16) for i in (0, 2, 4))
        except ValueError:
            # Short or non-hex input (e.g. "#FFF", "red") -> default to white.
            bg_rgb = (255, 255, 255)
        contents = await image.read()
        pil_img = Image.open(io.BytesIO(contents)).convert("RGB")
        img_np = np.array(pil_img)
        assets_b64 = detect_and_extract(img_np, bg_color=bg_rgb, tolerance=tolerance)
        # Build the ZIP fully in memory, then rewind it for streaming.
        zip_buf = io.BytesIO()
        with zipfile.ZipFile(zip_buf, "w", zipfile.ZIP_DEFLATED) as zf:
            for i, b64 in enumerate(assets_b64):
                zf.writestr(f"asset_{i+1:04d}.png", base64.b64decode(b64))
        zip_buf.seek(0)
        return StreamingResponse(
            zip_buf,
            media_type="application/zip",
            headers={"Content-Disposition": "attachment; filename=extracted_assets.zip"},
        )
    except Exception as e:
        print(f">>> ERROR in /download-zip: {e}", flush=True)
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e