File size: 24,294 Bytes
2266dc1
 
 
 
 
 
 
405477b
a387aca
cb7cf0f
 
 
 
 
db52794
4b23049
cb7cf0f
 
4b23049
a9f1afc
 
cb7cf0f
 
 
 
 
4b23049
cb7cf0f
 
 
d10f06b
f8e0698
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db52794
d8802bd
 
 
 
 
 
77f13e1
 
 
 
 
d8802bd
 
 
 
 
 
 
 
 
 
 
 
 
 
2266dc1
 
 
 
09cf416
2266dc1
5ee8874
f6409a9
2266dc1
09cf416
582b3af
 
 
 
 
 
 
 
 
 
 
 
 
 
09cf416
2266dc1
 
 
 
 
 
 
 
 
 
 
 
 
 
c1dd13e
273261c
 
 
599b438
273261c
599b438
 
273261c
 
32691c0
2266dc1
19b33af
 
2266dc1
 
05dd7d6
 
 
 
 
2266dc1
05dd7d6
 
 
 
 
2266dc1
05dd7d6
b32ee3f
 
 
 
 
05dd7d6
 
 
 
 
 
 
 
 
 
 
2266dc1
05dd7d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488f9ce
05dd7d6
488f9ce
05dd7d6
488f9ce
05dd7d6
2266dc1
ce89ae7
05dd7d6
ce89ae7
9082a6b
956b060
ce89ae7
488f9ce
956b060
 
 
 
 
ce89ae7
05dd7d6
 
 
 
 
 
 
 
 
 
 
 
 
240b539
 
05dd7d6
240b539
ff9d89b
240b539
 
ff9d89b
240b539
05dd7d6
 
 
 
 
 
 
 
 
240b539
 
 
 
 
 
05dd7d6
240b539
 
 
 
 
05dd7d6
166869f
05dd7d6
 
240b539
7d93d9d
240b539
 
488f9ce
240b539
05dd7d6
2266dc1
32ac9f9
240b539
 
 
488f9ce
240b539
 
488f9ce
 
 
b3a18cb
488f9ce
32ac9f9
05dd7d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3a18cb
 
05dd7d6
 
 
 
 
 
 
7d93d9d
05dd7d6
240b539
05dd7d6
32ac9f9
05dd7d6
2266dc1
05dd7d6
 
240b539
32ac9f9
05dd7d6
 
240b539
 
 
 
05dd7d6
240b539
 
 
 
32ac9f9
05dd7d6
 
 
240b539
b32ee3f
c32cbba
05dd7d6
 
 
 
 
 
 
 
 
b32ee3f
 
05dd7d6
b32ee3f
05dd7d6
240b539
05dd7d6
240b539
 
 
 
 
 
 
2266dc1
05dd7d6
405477b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05dd7d6
405477b
 
 
 
 
 
 
 
 
 
 
 
 
620247d
05dd7d6
620247d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05dd7d6
620247d
 
 
 
 
 
 
 
 
 
 
 
 
 
05dd7d6
620247d
05dd7d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
620247d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05dd7d6
 
 
 
620247d
 
e66cd5c
620247d
 
 
 
e66cd5c
 
b32ee3f
 
 
 
 
 
 
 
 
 
 
e66cd5c
05dd7d6
620247d
05dd7d6
 
 
620247d
 
2266dc1
620247d
05dd7d6
405477b
 
 
 
 
 
05dd7d6
405477b
 
05dd7d6
405477b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05dd7d6
525c3fb
05dd7d6
 
 
525c3fb
 
05dd7d6
 
b32ee3f
 
 
 
620247d
05dd7d6
 
 
 
620247d
 
2266dc1
620247d
 
05dd7d6
 
620247d
05dd7d6
66564c2
620247d
 
05dd7d6
 
 
 
 
 
 
 
 
 
 
b32ee3f
2266dc1
 
650d14a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
import gradio as gr
import numpy as np
import cv2
import torch
from PIL import Image
import os
import io
import fitz  # PyMuPDF

# ── UNCONDITIONAL BFloat16 → Float16 Patch for T4 Turing GPUs ────
# CRITICAL: torch.cuda.is_bf16_supported() returns True on T4 because CUDA
# can *emulate* bfloat16 in software, but the actual kernels crash on mixed
# dtype operations (linear, conv2d). We MUST patch unconditionally.
#
# Everything below runs only when a CUDA device is visible. It:
#   1. replaces the autocast context-manager classes so a bfloat16 request
#      becomes float16;
#   2-4. wraps F.linear, F.conv2d, torchvision.ops.roi_align and F.layer_norm
#      so a floating-point input whose dtype differs from the weight (or the
#      feature map, for roi_align) is cast to match before the original op runs.
if torch.cuda.is_available():
    # 1. Intercept ALL autocast entry points to force float16
    import torch.amp.autocast_mode
    _OriginalAutocast = torch.amp.autocast_mode.autocast

    class _Fp16Autocast(_OriginalAutocast):
        """``torch.autocast`` variant that rewrites a bfloat16 request to float16."""

        def __init__(self, device_type, dtype=None, *args, **kwargs):
            # Intercept Meta's bfloat16 request and force float16 for Turing support
            if dtype == torch.bfloat16:
                dtype = torch.float16
            # Forward dtype positionally: extra positional arguments
            # (enabled, cache_enabled) then fill their own slots instead of
            # colliding with a `dtype=` keyword ("got multiple values" TypeError).
            super().__init__(device_type, dtype, *args, **kwargs)

    class _CudaFp16Autocast(_Fp16Autocast):
        """Drop-in replacement for the legacy ``torch.cuda.amp.autocast``.

        The legacy class takes no ``device_type`` (it is implicitly "cuda") and
        defaults ``dtype`` to float16, so substituting ``_Fp16Autocast`` for it
        directly would break every existing caller of that API.
        """

        def __init__(self, enabled=True, dtype=torch.float16, cache_enabled=True):
            super().__init__("cuda", dtype, enabled, cache_enabled)

    torch.autocast = _Fp16Autocast
    torch.amp.autocast_mode.autocast = _Fp16Autocast
    if hasattr(torch.amp, 'autocast'):
        torch.amp.autocast = _Fp16Autocast
    if hasattr(torch.cuda.amp, 'autocast'):
        torch.cuda.amp.autocast = _CudaFp16Autocast

    # 2. Patch core Math Kernels to deterministically auto-cast mismatched float matrices natively.
    # Each wrapper keeps the original parameter names (including `input`) so
    # keyword callers keep working, and casts only floating-point inputs.
    import torch.nn.functional as F

    orig_linear = F.linear
    def patched_linear(input, weight, bias=None):
        if input.is_floating_point() and input.dtype != weight.dtype:
            input = input.to(weight.dtype)
        return orig_linear(input, weight, bias)
    F.linear = patched_linear

    orig_conv2d = F.conv2d
    def patched_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        if input.is_floating_point() and input.dtype != weight.dtype:
            input = input.to(weight.dtype)
        return orig_conv2d(input, weight, bias, stride, padding, dilation, groups)
    F.conv2d = patched_conv2d

    # 3. Patch torchvision.ops.roi_align — Meta's geometry_encoders.py
    #    calls boxes_xyxy.float() which creates float32 while img_feats is float16.
    try:
        import torchvision.ops
        orig_roi_align = torchvision.ops.roi_align
        def patched_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
            # Handle Tensor, list, or tuple (Meta uses .unbind() which returns a tuple!)
            if isinstance(boxes, torch.Tensor):
                if input.is_floating_point() and boxes.dtype != input.dtype:
                    boxes = boxes.to(input.dtype)
            elif isinstance(boxes, (list, tuple)):
                boxes = [b.to(input.dtype) if isinstance(b, torch.Tensor) and b.dtype != input.dtype else b for b in boxes]
            return orig_roi_align(input, boxes, output_size, spatial_scale, sampling_ratio, aligned)
        torchvision.ops.roi_align = patched_roi_align
    except ImportError:
        # torchvision is optional; without it there is nothing to patch.
        pass

    # 4. Patch layer_norm / group_norm — common ViT dtype mismatch points
    #    (only layer_norm is actually wrapped here; the input follows the weight dtype).
    orig_layer_norm = F.layer_norm
    def patched_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
        if weight is not None and input.is_floating_point() and input.dtype != weight.dtype:
            input = input.to(weight.dtype)
        return orig_layer_norm(input, normalized_shape, weight, bias, eps)
    F.layer_norm = patched_layer_norm

# ── Ensure SAM 3 Checkpoint is downloaded ────────────────────────
# Runs at import time: the app cannot start without the weights file, whose
# local path ends up in `ckpt_path`.
from huggingface_hub import hf_hub_download

# ── HF Token Authentication ────────────────────────────────────────
# HF_TOKEN is read from the environment; when unset, token=None is passed.
hf_token = os.environ.get("HF_TOKEN")
print("Downloading SAM 3 model...")
ckpt_path = hf_hub_download(
    repo_id="facebook/sam3",
    filename="sam3.pt",
    token=hf_token,
)

# ── Monkey Patch SAM 3 CUDA Hardcoding Bug ───────────────────────
# Meta's SAM 3 repo hardcodes `device="cuda"` in many places!
# On a CPU-only host, a fixed set of torch tensor factories is wrapped so an
# explicit `device=` keyword whose string form starts with "cuda" is rewritten
# to "cpu". A device passed positionally is not intercepted.
if not torch.cuda.is_available():
    import functools

    def _force_cpu_device(factory):
        """Return *factory* wrapped so a CUDA `device=` keyword becomes "cpu"."""
        @functools.wraps(factory)
        def wrapper(*args, **kwargs):
            requested = kwargs.get('device')
            if requested is not None and str(requested).startswith('cuda'):
                kwargs['device'] = 'cpu'
            return factory(*args, **kwargs)
        return wrapper

    for _factory_name in ('zeros', 'arange', 'tensor', 'ones', 'empty', 'randn', 'full', 'linspace'):
        if hasattr(torch, _factory_name):
            setattr(torch, _factory_name, _force_cpu_device(getattr(torch, _factory_name)))

# ── SAM 3 Imports ────────────────────────────────────────────────
# The app still starts without the sam3 package; `processor` then stays None.
try:
    from sam3.model_builder import build_sam3_image_model
    from sam3.model.sam3_image_processor import Sam3Processor
    model_installed = True
except ImportError:
    model_installed = False
    print("SAM 3 not installed yet (will be installed by requirements.txt).")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Remains None until the model has loaded; extract_assets checks for this.
processor = None

if model_installed:
    print(f"Loading SAM 3 onto {device}...")
    model = build_sam3_image_model(checkpoint_path=ckpt_path)
    if device.type == "cuda":
        # Cast to float16 — T4 has native float16 Tensor Core acceleration.
        # bfloat16 hangs (software emulated on Turing), float32 produced zero masks.
        model.half()
    # On CPU the weights keep their loaded dtype: the dtype-reconciling
    # wrappers around F.linear / F.conv2d / F.layer_norm are installed only
    # when CUDA is available, so half-precision weights on CPU would meet
    # (presumably float32) activations with nothing to cast them.

    # Diagnostic: verify checkpoint loaded correctly
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}", flush=True)
    sample_dtype = next(model.parameters()).dtype
    print(f"Model dtype: {sample_dtype}", flush=True)

    processor = Sam3Processor(model)
    if not torch.cuda.is_available():
        processor.device = "cpu"
    print("Model loaded successfully.")

# Text prompts sent to SAM 3: parent (composite) concepts first, then child
# (individual) elements.
# Excludes 'text block' (user doesn't want text) and 'logo' (picks up watermarks).
# NOTE(review): 'label' may match text regions, at odds with the no-text goal — confirm.
PARENT_CONCEPTS = [
    "chart",
    "diagram",
    "graph",
    "table",
    "illustration",
    "infographic",
    "figure",
    "photo",
    "picture",
    "image",
]
CHILD_CONCEPTS = [
    "icon",
    "symbol",
    "arrow",
    "bar",
    "person",
    "object",
    "button",
    "badge",
    "circle",
    "label",
]
ALL_CONCEPTS = [*PARENT_CONCEPTS, *CHILD_CONCEPTS]

# Persistent asset library: every extracted PNG is written into this single
# temp directory, which nothing in this module clears.
import tempfile
import zipfile

ASSET_LIBRARY_DIR = os.path.join(tempfile.gettempdir(), "sam3_library")
os.makedirs(ASSET_LIBRARY_DIR, exist_ok=True)
# Running number used to name saved assets (asset_0001.png, ...).
asset_counter = 0

def box_iou(b1, b2):
    """Return the intersection-over-union of two ``[x0, y0, x1, y1]`` boxes.

    Disjoint boxes give 0; if the combined area is not positive (e.g. two
    zero-area boxes) the result is 0.0 rather than a division error.
    """
    ix0, iy0 = max(b1[0], b2[0]), max(b1[1], b2[1])
    ix1, iy1 = min(b1[2], b2[2]), min(b1[3], b2[3])
    overlap = max(0, ix1 - ix0) * max(0, iy1 - iy0)

    area_a = (b1[2] - b1[0]) * (b1[3] - b1[1])
    area_b = (b2[2] - b2[0]) * (b2[3] - b2[1])
    union = area_a + area_b - overlap
    if union > 0:
        return overlap / union
    return 0.0

def remove_color_bg(crop_rgb: np.ndarray, bg_color=(255, 255, 255), tolerance=30) -> np.ndarray:
    """Make the border-connected background of an RGB crop transparent.

    Only removes pixels CONNECTED to the border that match bg_color.
    White/colored areas INSIDE objects are preserved.

    A pixel is a background candidate when its Euclidean RGB distance to
    ``bg_color`` is strictly below ``tolerance``. Candidates forming a
    4-connected region that touches the crop border become transparent;
    candidate regions fully enclosed by foreground stay opaque.

    Args:
        crop_rgb: RGB crop, presumably ``(h, w, 3)`` uint8 — copying it into
            the RGBA output below requires exactly three channels.
        bg_color: RGB triple of the colour to key out.
        tolerance: distance threshold for background candidates.

    Returns:
        ``(h, w, 4)`` uint8 RGBA array. Alpha is 255 on foreground and 0 on
        removed background, except that background pixels adjacent to the
        foreground receive a small blurred alpha (3x3 Gaussian) as
        anti-aliasing. Crops under 2 px in either dimension come back fully
        opaque.
    """
    h, w = crop_rgb.shape[:2]
    rgba = np.zeros((h, w, 4), dtype=np.uint8)
    rgba[:, :, :3] = crop_rgb
    if h < 2 or w < 2:
        rgba[:, :, 3] = 255
        return rgba

    # Background candidates: pixels within `tolerance` of bg_color.
    bg = np.array(bg_color, dtype=np.float32)
    diff = np.sqrt(np.sum((crop_rgb.astype(np.float32) - bg) ** 2, axis=2))
    color_match = (diff < tolerance).astype(np.uint8)

    # Label every candidate region in a single pass (instead of one full-image
    # flood fill per border seed), then keep the regions that touch the border.
    # 4-connectivity: regions meeting only at a corner are not merged, so
    # background cannot leak diagonally through a 1-px outline.
    _, labels = cv2.connectedComponents(color_match, connectivity=4)
    border = np.concatenate((labels[0, :], labels[-1, :], labels[:, 0], labels[:, -1]))
    border_labels = np.unique(border)
    border_labels = border_labels[border_labels != 0]  # label 0 = non-candidate pixels
    bg_connected = np.isin(labels, border_labels)

    # Alpha: 255 for foreground, 0 for connected background
    alpha = np.where(bg_connected, np.uint8(0), np.uint8(255))

    # Slight edge AA: blur alpha, then restore full opacity on foreground so
    # only the transparent side of each edge is feathered.
    alpha_blur = cv2.GaussianBlur(alpha.astype(np.float32), (3, 3), sigmaX=0.8)
    alpha_aa = np.where(alpha > 240, 255.0, alpha_blur)
    rgba[:, :, 3] = alpha_aa.clip(0, 255).astype(np.uint8)
    return rgba

def upscale_4x(rgba: np.ndarray) -> np.ndarray:
    """Enlarge an image 4x with Lanczos resampling, then sharpen its first three channels.

    Sharpening is an unsharp mask: ``1.5 * rgb - 0.5 * gaussian(rgb, sigma=1)``
    (saturating), applied to channels 0-2 only. Channel 3 (alpha) keeps its
    Lanczos-resampled values.
    """
    src_h, src_w = rgba.shape[:2]
    enlarged = cv2.resize(rgba, (src_w * 4, src_h * 4), interpolation=cv2.INTER_LANCZOS4)

    color = enlarged[:, :, :3]
    softened = cv2.GaussianBlur(color, (0, 0), sigmaX=1.0)
    enlarged[:, :, :3] = cv2.addWeighted(color, 1.5, softened, -0.5, 0)
    return enlarged

def is_notebooklm_logo(box, img_w, img_h):
    """Return True when *box* sits where the NotebookLM watermark appears.

    A match requires the box to be under 80 px wide AND under 80 px tall, with
    its centre strictly past 85% of the image width and of the image height
    (the bottom-right corner).
    """
    x0, y0, x1, y1 = box
    small_enough = (x1 - x0) < 80 and (y1 - y0) < 80
    if not small_enough:
        return False
    centre_x = (x0 + x1) / 2
    centre_y = (y0 + y1) / 2
    return centre_x > img_w * 0.85 and centre_y > img_h * 0.85

def parse_hex_color(hex_str, default=(255, 255, 255)):
    """Parse a CSS-style hex colour string into an ``(r, g, b)`` tuple of ints.

    Surrounding whitespace and leading ``#`` are ignored. Three-digit shorthand
    (``"#0f8"`` -> ``"00ff88"``) is expanded; with six or more digits only the
    first six are used (an ``RRGGBBAA`` value drops its alpha). Anything else,
    including ``None`` and non-hex characters, returns ``default``.
    """
    import string
    if not isinstance(hex_str, str):
        return default
    digits = hex_str.strip().lstrip("#")
    if len(digits) == 3:
        digits = "".join(ch * 2 for ch in digits)
    digits = digits[:6]
    if len(digits) < 6 or not set(digits) <= set(string.hexdigits):
        return default
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))


def extract_assets(input_image, bg_color_hex="#FFFFFF"):
    """Detect visual elements in one slide image and save each as a transparent PNG.

    Every prompt in ``ALL_CONCEPTS`` is run through the SAM 3 processor. Each
    returned mask is reduced to its bounding box, filtered (score, size,
    NotebookLM-watermark position), padded, and de-duplicated by box IoU. Each
    survivor is cropped from the original image, background-keyed with
    ``remove_color_bg``, upscaled 4x and written into ``ASSET_LIBRARY_DIR``.

    Args:
        input_image: image array as delivered by ``gr.Image(type="numpy")``
            (assumed ``(h, w, 3)`` uint8 RGB — TODO confirm for RGBA uploads),
            or ``None``.
        bg_color_hex: hex colour of the slide background, e.g. ``"#FFFFFF"``;
            ``"#RGB"`` shorthand is accepted and unparseable values fall back
            to white.

    Returns:
        List of saved PNG paths. It is empty on missing input, unloaded model,
        no detections, or any exception; errors are printed, not raised.
    """
    import sys, traceback
    global asset_counter
    try:
        print(">>> extract_assets V2 called", flush=True)
        if input_image is None:
            gr.Info("Please upload an image first.")
            return []
        if processor is None:
            gr.Warning("Model is still loading. Please wait and try again.")
            return []

        # Parse background color ("#RGB" shorthand supported; invalid -> white)
        bg_color = parse_hex_color(bg_color_hex)
        print(f">>> Background color: {bg_color}", flush=True)

        orig_rgb = input_image
        h, w = orig_rgb.shape[:2]
        img_area = h * w
        print(f">>> Image size: {w}x{h}, area: {img_area}", flush=True)
        pil_img = Image.fromarray(orig_rgb)

        all_boxes = []
        all_scores = []

        with torch.inference_mode():
            print(">>> Running set_image...", flush=True)
            state = processor.set_image(pil_img)
            print(">>> set_image complete! Running two-pass detection...", flush=True)

            for concept in ALL_CONCEPTS:
                print(f">>> Concept: '{concept}'...", flush=True)
                out = processor.set_text_prompt(state=state, prompt=concept)

                masks = out["masks"]
                scores = out["scores"]

                if masks is None or len(masks) == 0:
                    print(f"  [{concept}] No detections", flush=True)
                    continue

                if torch.is_tensor(masks): masks = masks.float().cpu().numpy()
                if torch.is_tensor(scores): scores = scores.float().cpu().numpy()

                print(f"  [{concept}] Found {len(masks)} masks", flush=True)

                for j in range(len(masks)):
                    m = masks[j]
                    # Drop leading singleton axes down to a 2-D (h, w) mask.
                    while m.ndim > 2: m = m[0]
                    m_bool = m.astype(bool)

                    score = float(scores[j]) if scores.ndim > 0 else float(scores)

                    # Derive bounding box from mask (x1/y1 are inclusive pixel coords)
                    ys, xs = np.nonzero(m_bool)
                    if len(ys) == 0: continue
                    x0, y0 = int(xs.min()), int(ys.min())
                    x1, y1 = int(xs.max()), int(ys.max())

                    bw, bh = x1 - x0, y1 - y0
                    box_area = bw * bh

                    # Filters
                    if score < 0.1:
                        print(f"    [{j}] SKIP low score: {score:.4f}", flush=True)
                        continue
                    if box_area < 500 or bw < 20 or bh < 20:
                        print(f"    [{j}] SKIP too small: {bw}x{bh}", flush=True)
                        continue
                    if box_area > img_area * 0.90:
                        print(f"    [{j}] SKIP too large", flush=True)
                        continue
                    if is_notebooklm_logo([x0, y0, x1, y1], w, h):
                        print(f"    [{j}] SKIP NotebookLM logo position", flush=True)
                        continue

                    # Add padding (8% of box size, at least 8 px), clamped to the image
                    pad_x = max(8, int(bw * 0.08))
                    pad_y = max(8, int(bh * 0.08))
                    bx0 = max(0, x0 - pad_x)
                    by0 = max(0, y0 - pad_y)
                    bx1 = min(w, x1 + pad_x)
                    by1 = min(h, y1 + pad_y)

                    all_boxes.append([bx0, by0, bx1, by1])
                    all_scores.append(score)
                    print(f"    [{j}] KEPT: score={score:.4f}, box=[{bx0},{by0},{bx1},{by1}]", flush=True)

        print(f">>> Total detections: {len(all_boxes)}", flush=True)

        if not all_boxes:
            gr.Info("No visual assets found. Try a different slide with more illustrations.")
            return []

        # Deduplicate by box IoU: visit boxes best-score first and keep one only
        # if it does not overlap an already-kept box by more than 0.5 IoU.
        order = sorted(range(len(all_boxes)), key=lambda i: all_scores[i], reverse=True)
        keep = []
        for i in order:
            if not any(box_iou(all_boxes[i], all_boxes[k]) > 0.5 for k in keep):
                keep.append(i)

        print(f">>> After dedup: {len(keep)} unique assets", flush=True)

        # For each: crop → flood-fill BG removal → upscale → save
        results = []
        for idx, ki in enumerate(keep):
            bx0, by0, bx1, by1 = all_boxes[ki]
            crop_rgb = orig_rgb[by0:by1, bx0:bx1]

            # Border-connected background removal (preserves interior fills)
            rgba = remove_color_bg(crop_rgb, bg_color=bg_color, tolerance=30)

            # 4x upscale
            rgba = upscale_4x(rgba)

            asset_counter += 1
            lib_path = os.path.join(ASSET_LIBRARY_DIR, f"asset_{asset_counter:04d}.png")
            Image.fromarray(rgba, "RGBA").save(lib_path, format="PNG")
            results.append(lib_path)
            print(f"    asset[{idx}] saved: {lib_path}", flush=True)

        print(f">>> Returning {len(results)} assets (library: {asset_counter})", flush=True)
        return results

    except Exception as e:
        # UI boundary: log the failure and show an empty gallery instead of
        # surfacing a traceback to the user.
        print(f">>> EXCEPTION in extract_assets: {e}", flush=True)
        traceback.print_exc()
        sys.stdout.flush()
        return []

def _pixmap_to_rgb(pix):
    """Convert a PyMuPDF Pixmap into a new contiguous ``(h, w, 3)`` uint8 array.

    Grayscale pixmaps (n == 1, or 2 with alpha) are replicated to three
    channels; for n >= 3 the first three channels are kept, so a 4-channel
    pixmap is treated as RGBA (alpha dropped).
    """
    arr = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.h, pix.w, pix.n)
    if pix.n < 3:
        return np.repeat(arr[:, :, :1], 3, axis=2)
    return arr[:, :, :3].copy()


def extract_from_pdf(pdf_file, bg_color_hex="#FFFFFF", progress=gr.Progress()):
    """Process every page of a PDF through SAM 3 extraction.

    Each page is rendered at 2x zoom, converted to RGB and passed to
    ``extract_assets``; the per-page result lists are concatenated.

    Args:
        pdf_file: filepath string, or an object with a ``.name`` path (as
            uploaded through ``gr.File``); ``None`` yields ``[]``.
        bg_color_hex: background colour forwarded to ``extract_assets``.
        progress: Gradio progress tracker driving the per-page bar.

    Returns:
        List of saved PNG paths across all pages. It is empty on ``None`` input
        or any exception; errors are printed, not raised.
    """
    import sys, traceback
    try:
        if pdf_file is None:
            return []

        pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
        print(f">>> PDF upload: {pdf_path}", flush=True)

        zoom = fitz.Matrix(2.0, 2.0)  # loop-invariant render scale
        all_results = []
        # The context manager closes the document even if a page raises mid-loop.
        with fitz.open(pdf_path) as doc:
            total_pages = len(doc)
            print(f">>> PDF has {total_pages} pages", flush=True)

            for page_num in progress.tqdm(range(total_pages), desc="Processing PDF pages"):
                print(f">>> Processing page {page_num + 1}/{total_pages}...", flush=True)
                pix = doc[page_num].get_pixmap(matrix=zoom)
                img_rgb = _pixmap_to_rgb(pix)

                page_results = extract_assets(img_rgb, bg_color_hex=bg_color_hex)
                all_results.extend(page_results)
                print(f">>> Page {page_num + 1}: extracted {len(page_results)} assets", flush=True)

        print(f">>> PDF complete: {len(all_results)} total assets from {total_pages} pages", flush=True)
        return all_results

    except Exception as e:
        print(f">>> EXCEPTION in extract_from_pdf: {e}", flush=True)
        traceback.print_exc()
        sys.stdout.flush()
        return []

# Stylesheet handed to demo.launch(css=...). Its selectors target the elem_id /
# elem_classes values set on components in the Blocks layout (#app-title,
# #app-subtitle, .gallery-container, #extract-btn, #upload-area,
# #bg-color-picker). The gallery rules hide each item's download button until
# hover; `.gallery-item` and `button.download` are presumably Gradio-internal
# class names — NOTE(review): re-check them after a Gradio upgrade.
custom_css = """
/* ── Premium Dark Theme ───────────────────────────── */
.gradio-container {
    max-width: 1400px !important;
    margin: auto;
}
#app-title {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #f97316 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 2.2rem !important;
    font-weight: 800 !important;
    margin-bottom: 0 !important;
}
#app-subtitle {
    text-align: center;
    color: #94a3b8 !important;
    font-size: 0.95rem !important;
    margin-top: 0 !important;
}
/* Gallery with hover download */
.gallery-container {
    min-height: 650px !important;
}
.gallery-container .gallery-item {
    position: relative;
    border-radius: 12px;
    overflow: hidden;
    transition: transform 0.2s ease, box-shadow 0.2s ease;
    background: #1e293b;
}
.gallery-container .gallery-item:hover {
    transform: scale(1.03);
    box-shadow: 0 8px 32px rgba(102, 126, 234, 0.3);
}
/* Download button: hidden by default, shown on hover */
.gallery-container .gallery-item button.download {
    opacity: 0 !important;
    transition: opacity 0.25s ease !important;
    position: absolute !important;
    bottom: 8px !important;
    right: 8px !important;
    z-index: 10 !important;
    background: rgba(249, 115, 22, 0.9) !important;
    color: white !important;
    border-radius: 8px !important;
    padding: 6px 14px !important;
    font-weight: 600 !important;
    border: none !important;
    cursor: pointer !important;
}
.gallery-container .gallery-item:hover button.download {
    opacity: 1 !important;
}
/* Extract button styling */
#extract-btn {
    background: linear-gradient(135deg, #f97316 0%, #ea580c 100%) !important;
    border: none !important;
    font-weight: 700 !important;
    font-size: 1.1rem !important;
    padding: 14px 0 !important;
    border-radius: 12px !important;
    transition: all 0.3s ease !important;
}
#extract-btn:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 24px rgba(249, 115, 22, 0.4) !important;
}
/* Upload area */
#upload-area {
    border: 2px dashed #475569 !important;
    border-radius: 12px !important;
    transition: border-color 0.3s ease !important;
}
#upload-area:hover {
    border-color: #667eea !important;
}
/* Color picker label */
#bg-color-picker {
    max-width: 200px;
}
"""

# Gradio's built-in Soft theme (orange primary, blue secondary, slate neutral,
# Inter from Google Fonts); handed to demo.launch(theme=...).
app_theme = gr.themes.Soft(
    primary_hue="orange",
    secondary_hue="blue",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
)

def download_all_zip():
    """Bundle every PNG in the asset library into one ZIP for download.

    Returns:
        Path of ``extracted_assets.zip`` in the system temp directory, or
        ``None`` when the library holds no PNGs. The archive is rewritten on
        each call and stores the files by bare name, in sorted order.

    NOTE(review): the archive path and the library directory are shared by
    every session, so concurrent users receive each other's assets.
    """
    archive_path = os.path.join(tempfile.gettempdir(), "extracted_assets.zip")
    library_pngs = sorted(name for name in os.listdir(ASSET_LIBRARY_DIR) if name.endswith(".png"))
    if len(library_pngs) == 0:
        return None
    with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as bundle:
        for name in library_pngs:
            bundle.write(os.path.join(ASSET_LIBRARY_DIR, name), arcname=name)
    return archive_path

# UI layout: input tabs (single image / PDF batch), background-colour field
# and ZIP download on the left; results gallery on the right.
with gr.Blocks(title="SAM 3 Asset Extractor") as demo:
    gr.Markdown("# 🎨 SAM 3 Visual Asset Extractor", elem_id="app-title")
    gr.Markdown(
        "Upload a presentation slide or PDF to extract all **visual elements** "
        "(charts, diagrams, icons, illustrations) as **transparent PNGs** ready for "
        "**video editing** — powered by Meta's SAM 3 + intelligent background removal.",
        elem_id="app-subtitle"
    )

    with gr.Row(equal_height=False):
        with gr.Column(scale=1, min_width=340):
            with gr.Tabs():
                with gr.Tab("🖼️ Single Image"):
                    input_image = gr.Image(
                        label="📤 Upload Slide",
                        type="numpy",
                        elem_id="upload-area",
                        height=300,
                    )
                    submit_btn = gr.Button(
                        "🔍 Extract Visual Assets",
                        variant="primary",
                        elem_id="extract-btn",
                        size="lg",
                    )
                with gr.Tab("📄 PDF Batch"):
                    input_pdf = gr.File(
                        label="📤 Upload PDF",
                        file_types=[".pdf"],
                    )
                    # NOTE(review): shares elem_id "extract-btn" with submit_btn so the
                    # #extract-btn CSS applies to both; HTML ids should be unique —
                    # consider elem_classes plus a class selector.
                    pdf_btn = gr.Button(
                        "📄 Extract from All Pages",
                        variant="primary",
                        elem_id="extract-btn",
                        size="lg",
                    )

            bg_color_input = gr.Textbox(
                label="🎨 Background Color to Remove",
                value="#FFFFFF",
                elem_id="bg-color-picker",
                info="Hex color of slide background (e.g. #FFFFFF for white)",
                max_lines=1,
            )

            download_btn = gr.DownloadButton(
                "📦 Download All as ZIP",
                size="lg",
            )
            gr.Markdown(
                "**🔍 Detects:** charts · diagrams · graphs · tables · "
                "illustrations · infographics · figures · photos · "
                "icons · symbols · arrows · bars · persons · badges\n\n"
                "**🚫 Excludes:** text blocks · logos · watermarks",
                elem_id="concept-list"
            )

        with gr.Column(scale=3):
            output_gallery = gr.Gallery(
                label="🎨 Extracted Assets — Hover to download individual PNGs",
                columns=4,
                object_fit="contain",
                height=700,
                format="png",
                elem_classes=["gallery-container"],
            )

    # A DownloadButton downloads the file it already holds when clicked; a
    # click handler's return value only arrives afterwards. Rebuilding the ZIP
    # right after each extraction keeps the button's file current, so the
    # first click already delivers the latest assets.
    submit_btn.click(
        fn=extract_assets,
        inputs=[input_image, bg_color_input],
        outputs=[output_gallery]
    ).then(
        fn=download_all_zip,
        inputs=[],
        outputs=[download_btn],
    )
    pdf_btn.click(
        fn=extract_from_pdf,
        inputs=[input_pdf, bg_color_input],
        outputs=[output_gallery]
    ).then(
        fn=download_all_zip,
        inputs=[],
        outputs=[download_btn],
    )
    # Also refresh on click so the button picks up anything written since.
    download_btn.click(fn=download_all_zip, inputs=[], outputs=[download_btn])

if __name__ == "__main__":
    # Basic-auth credentials come from APP_USERNAME / APP_PASSWORD; the
    # literal fallbacks apply when those variables are unset.
    # NOTE(review): the fallback credentials are committed in source — always
    # set the environment variables for any publicly reachable deployment.
    credentials = (
        os.environ.get("APP_USERNAME", "veurone"),
        os.environ.get("APP_PASSWORD", "sam3extract"),
    )
    demo.launch(css=custom_css, theme=app_theme, auth=credentials)