VolarisLLC commited on
Commit
0e28d43
·
verified ·
1 Parent(s): 72c550c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +772 -1262
main.py CHANGED
@@ -1,1324 +1,834 @@
1
- import os
2
  import json
3
- import signal
4
- import sys
 
 
 
 
 
5
  from pathlib import Path
6
- from typing import List, Dict, Tuple, Optional, Sequence, Set, Any
7
- from multiprocessing import Pool, cpu_count
8
- from functools import partial
9
-
10
- import fitz # PyMuPDF (Still needed for drawing output PDF)
11
- import pypdfium2 as pdfium
 
 
 
 
 
 
12
  import torch
13
- from doclayout_yolo import YOLOv10
14
- from huggingface_hub import hf_hub_download
15
- from loguru import logger
16
- from PIL import Image
17
- import numpy as np
18
-
19
- try:
20
- import pymupdf4llm # type: ignore
21
- except ImportError: # pragma: no cover - optional dependency
22
- pymupdf4llm = None # type: ignore
23
-
24
- try:
25
- import spaces
26
- except ImportError:
27
- # Mock spaces for local execution
28
- class spaces:
29
- @staticmethod
30
- def GPU(func):
31
- return func
32
-
33
- # ----------------------------------------------------------------------
34
- # CONFIGURATION
35
- # ----------------------------------------------------------------------
36
- # DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Removed for ZeroGPU compatibility (lazy check instead)
37
-
38
- # Model options
39
- MODEL_SIZE = 1024
40
- REPO_ID = "juliozhao/DocLayout-YOLO-DocStructBench"
41
- WEIGHTS_FILE = f"doclayout_yolo_docstructbench_imgsz{MODEL_SIZE}.pt"
42
-
43
- # Detection settings
44
- CONF_THRESHOLD = 0.25
45
-
46
- # Multiprocessing settings
47
- NUM_WORKERS = None # None = auto (cpu_count - 1), or set to specific number like 4
48
- USE_MULTIPROCESSING = False # Set to False to disable parallel processing entirely (Required for ZeroGPU)
49
-
50
- # ----------------------------------------------------------------------
51
- # Color map for the layout classes
52
- # ----------------------------------------------------------------------
53
- CLASS_COLORS = {
54
- "text": (0, 128, 0), # Dark Green
55
- "title": (192, 0, 0), # Dark Red
56
- "figure": (0, 0, 192), # Dark Blue
57
- "table": (218, 165, 32), # Goldenrod (Dark Yellow)
58
- "list": (128, 0, 128), # Purple
59
- "header": (0, 128, 128), # Teal
60
- "footer": (100, 100, 100), # Dark Gray
61
- "figure_caption": (0, 0, 128), # Navy
62
- "table_caption": (139, 69, 19), # Saddle Brown
63
- "table_footnote": (128, 0, 128), # Purple
64
- }
65
-
66
- # Global model instance (will be None in worker processes until loaded)
67
- _model = None
68
- _shutdown_requested = False
69
-
70
- # ----------------------------------------------------------------------
71
- # Signal handler for graceful shutdown
72
- # ----------------------------------------------------------------------
73
- def signal_handler(signum, frame):
74
- """Handle interrupt signals gracefully."""
75
- global _shutdown_requested
76
- if not _shutdown_requested:
77
- _shutdown_requested = True
78
- logger.warning("\n⚠️ Interrupt received! Finishing current page and shutting down gracefully...")
79
- logger.warning("Press Ctrl+C again to force quit (may leave incomplete files)")
80
- else:
81
- logger.error("\n❌ Force quit requested. Exiting immediately.")
82
- sys.exit(1)
83
-
84
- def setup_signal_handlers():
85
- """Setup signal handlers for graceful shutdown."""
86
- signal.signal(signal.SIGINT, signal_handler)
87
- signal.signal(signal.SIGTERM, signal_handler)
88
-
89
- # ----------------------------------------------------------------------
90
- # Model loader function
91
- # ----------------------------------------------------------------------
92
- def get_model():
93
- """Lazy load the model (only once per process)."""
94
- global _model
95
- if _model is None:
96
- weights_path = hf_hub_download(repo_id=REPO_ID, filename=WEIGHTS_FILE)
97
- _model = YOLOv10(weights_path)
98
- logger.info(f"✓ Model loaded in worker process (PID: {os.getpid()})")
99
- return _model
100
 
101
- # ----------------------------------------------------------------------
102
- # Worker initialization function
103
- # ----------------------------------------------------------------------
104
- def init_worker():
105
- """Initialize worker process - loads model once at startup."""
106
- try:
107
- get_model()
108
- logger.success(f"Worker {os.getpid()} ready")
109
- except Exception as e:
110
- logger.error(f"Failed to initialize worker {os.getpid()}: {e}")
111
- raise
112
-
113
- # ----------------------------------------------------------------------
114
- # Run layout detection on a single page image (YOLO)
115
- # ----------------------------------------------------------------------
116
- @spaces.GPU
117
- def detect_page(pil_img: Image.Image) -> List[dict]:
118
- """Detect layout elements using YOLO model."""
119
- # Re-check device availability inside the decorated function (ZeroGPU context)
120
- device = "cuda" if torch.cuda.is_available() else "cpu"
121
- model = get_model() # Will return already-loaded model in worker
122
- img_cv = np.array(pil_img)
123
- results = model.predict(
124
- img_cv,
125
- imgsz=MODEL_SIZE,
126
- conf=CONF_THRESHOLD,
127
- device=device,
128
- verbose=False
129
- )
130
- dets = []
131
- for i, box in enumerate(results[0].boxes):
132
- cls_id = int(box.cls.item())
133
- name = results[0].names[cls_id]
134
- conf = float(box.conf.item())
135
- x0, y0, x1, y1 = box.xyxy[0].cpu().numpy().tolist()
136
- dets.append({
137
- "name": name,
138
- "bbox": [x0, y0, x1, y1],
139
- "conf": conf,
140
- "source": "yolo",
141
- "index": i
142
- })
143
- return dets
144
-
145
- # ----------------------------------------------------------------------
146
- # Crop & save figure/table regions (with captions)
147
- # ----------------------------------------------------------------------
148
- def get_union_box(box1: List[float], box2: List[float]) -> List[float]:
149
- """Get the bounding box enclosing two boxes."""
150
- x0 = min(box1[0], box2[0])
151
- y0 = min(box1[1], box2[1])
152
- x1 = max(box1[2], box2[2])
153
- y1 = max(box1[3], box2[3])
154
- return [x0, y0, x1, y1]
155
-
156
- def collect_caption_elements(
157
- element: Dict,
158
- all_dets: List[Dict],
159
- target_name: str,
160
- max_vertical_gap: float = 60.0,
161
- min_overlap: float = 0.25,
162
- ) -> List[Dict]:
163
- """
164
- Collect contiguous caption detections directly below a figure/table.
165
- """
166
- base_box = element["bbox"]
167
- base_bottom = base_box[3]
168
- selected: List[Dict] = []
169
- last_bottom = base_bottom
170
-
171
- relevant = [
172
- d for d in all_dets
173
- if d["name"] == target_name and d["bbox"][1] >= base_bottom - 5
174
- ]
175
-
176
- relevant.sort(key=lambda d: d["bbox"][1])
177
-
178
- for cand in relevant:
179
- cand_box = cand["bbox"]
180
- top = cand_box[1]
181
- if selected and top - last_bottom > max_vertical_gap:
182
- break
183
 
184
- if selected:
185
- overlap = _horizontal_overlap_ratio(selected[-1]["bbox"], cand_box)
186
- else:
187
- overlap = _horizontal_overlap_ratio(base_box, cand_box)
188
 
189
- if overlap < min_overlap:
190
- continue
 
191
 
192
- selected.append(cand)
193
- last_bottom = cand_box[3]
194
 
195
- return selected
 
 
 
 
 
196
 
197
 
198
- def collect_title_and_text_segments(
199
- element: Dict,
200
- all_dets: List[Dict],
201
- processed_indices: Set[int],
202
- settings: Optional[Dict[str, float]] = None,
203
- ) -> Tuple[List[Dict], List[Dict]]:
204
  """
205
- Locate a title below the element and any contiguous text blocks directly beneath it.
 
206
  """
207
- if settings is None:
208
- settings = TITLE_TEXT_ASSOCIATION
209
-
210
- if not element.get("bbox"):
211
- return [], []
212
 
213
- figure_box = element["bbox"]
214
- figure_bottom = figure_box[3]
215
 
216
- candidates = [
217
- d for d in all_dets
218
- if d.get("bbox") and d["index"] not in processed_indices
219
- ]
220
- candidates.sort(key=lambda d: d["bbox"][1])
221
-
222
- titles: List[Dict] = []
223
- texts: List[Dict] = []
224
-
225
- for idx, det in enumerate(candidates):
226
- if det["name"] != "title":
227
- continue
228
-
229
- title_box = det["bbox"]
230
- if title_box[1] < figure_bottom - 5:
231
- continue
232
-
233
- vertical_gap = title_box[1] - figure_bottom
234
- if vertical_gap > settings["max_title_gap"]:
235
- break
236
-
237
- overlap = _horizontal_overlap_ratio(figure_box, title_box)
238
- if overlap < settings["min_overlap"]:
239
- continue
240
-
241
- titles.append(det)
242
- last_bottom = title_box[3]
243
-
244
- for follower in candidates[idx + 1 :]:
245
- if follower["name"] == "title":
246
- break
247
- if follower["name"] != "text":
248
- continue
249
- text_box = follower["bbox"]
250
- if text_box[1] < title_box[1]:
251
- continue
252
-
253
- gap = text_box[1] - last_bottom
254
- if gap > settings["max_text_gap"]:
255
- break
256
-
257
- if _horizontal_overlap_ratio(title_box, text_box) < settings["min_overlap"]:
258
- continue
259
-
260
- texts.append(follower)
261
- last_bottom = text_box[3]
262
-
263
- break
264
-
265
- return titles, texts
266
 
 
 
 
 
 
 
 
 
267
 
268
- def save_layout_elements(pil_img: Image.Image, page_num: int,
269
- dets: List[dict], out_dir: Path) -> List[dict]:
270
- """Save figure and table crops, merging captions."""
271
- fig_dir = out_dir / "figures"
272
- tab_dir = out_dir / "tables"
273
- os.makedirs(fig_dir, exist_ok=True)
274
- os.makedirs(tab_dir, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
- infos = []
277
- fig_count = 0
278
- tab_count = 0
279
 
280
- processed_indices = set()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
- for i, d in enumerate(dets):
283
- if d["index"] in processed_indices:
284
- continue
 
 
 
 
 
 
 
 
285
 
286
- name = d["name"].lower()
287
- final_box = d["bbox"]
288
- caption_segments: List[Dict] = []
289
- title_segments: List[Dict] = []
290
- text_segments: List[Dict] = []
291
 
292
- if name == "figure":
293
- elem_type = "figure"
294
- path_template = fig_dir / f"page_{page_num + 1}_fig_{fig_count}.png"
295
- fig_count += 1
296
- caption_segments = collect_caption_elements(d, dets, "figure_caption")
297
- for cap in caption_segments:
298
- final_box = get_union_box(final_box, cap["bbox"])
299
- processed_indices.add(cap["index"])
300
- title_segments, text_segments = collect_title_and_text_segments(
301
- d, dets, processed_indices
302
- )
303
- for seg in title_segments + text_segments:
304
- final_box = get_union_box(final_box, seg["bbox"])
305
- processed_indices.add(seg["index"])
306
 
307
- elif name == "table":
308
- elem_type = "table"
309
- path_template = tab_dir / f"page_{page_num + 1}_tab_{tab_count}.png"
310
- tab_count += 1
311
- caption_segments = collect_caption_elements(d, dets, "table_caption")
312
- for cap in caption_segments:
313
- final_box = get_union_box(final_box, cap["bbox"])
314
- processed_indices.add(cap["index"])
315
- else:
316
- continue
317
-
318
- x0, y0, x1, y1 = map(int, final_box)
319
- crop = pil_img.crop((x0, y0, x1, y1))
320
 
321
- if crop.mode == "CMYK":
322
- crop = crop.convert("RGB")
323
-
324
- crop.save(path_template)
325
 
326
- info_data = {
327
- "type": elem_type,
328
- "page": page_num + 1,
329
- "bbox_pixels": final_box,
330
- "conf": d["conf"],
331
- "source": d.get("source", "yolo"),
332
- "image_path": str(path_template.relative_to(out_dir)),
333
- "width": int(x1 - x0),
334
- "height": int(y1 - y0),
335
- "page_width": pil_img.width,
336
- "page_height": pil_img.height,
337
- }
338
- if caption_segments:
339
- info_data["captions"] = [
340
- {
341
- "bbox": cap["bbox"],
342
- "conf": cap.get("conf"),
343
- "index": cap["index"],
344
- "source": cap.get("source"),
345
- "page": page_num + 1,
346
- }
347
- for cap in caption_segments
348
- ]
349
- if title_segments:
350
- info_data["titles"] = [
351
- {
352
- "bbox": seg["bbox"],
353
- "conf": seg.get("conf"),
354
- "index": seg["index"],
355
- "source": seg.get("source"),
356
- "page": page_num + 1,
357
- }
358
- for seg in title_segments
359
- ]
360
- if text_segments:
361
- info_data["texts"] = [
362
- {
363
- "bbox": seg["bbox"],
364
- "conf": seg.get("conf"),
365
- "index": seg["index"],
366
- "source": seg.get("source"),
367
- "page": page_num + 1,
368
- }
369
- for seg in text_segments
370
- ]
371
 
372
- infos.append(info_data)
373
-
374
- return infos
375
-
376
-
377
- TABLE_STITCH_TOLERANCES = {
378
- "x_tol": 60,
379
- "y_tol": 60,
380
- "width_tol": 120,
381
- "height_tol": 120,
382
- }
383
-
384
- CROSS_PAGE_CAPTION_THRESHOLDS = {
385
- "max_top_ratio": 0.35,
386
- "max_top_pixels": 220,
387
- "x_tol": 120,
388
- "width_tol": 200,
389
- "min_overlap": 0.05,
390
- }
391
-
392
- TITLE_TEXT_ASSOCIATION = {
393
- "max_title_gap": 220,
394
- "max_text_gap": 160,
395
- "min_overlap": 0.2,
396
- }
397
-
398
-
399
- def _horizontal_overlap_ratio(box1: List[float], box2: List[float]) -> float:
400
- """Compute horizontal overlap ratio between two bounding boxes."""
401
- x_left = max(box1[0], box2[0])
402
- x_right = min(box1[2], box2[2])
403
- overlap = max(0.0, x_right - x_left)
404
- if overlap <= 0:
405
- return 0.0
406
- width_union = max(box1[2], box2[2]) - min(box1[0], box2[0])
407
- if width_union <= 0:
408
- return 0.0
409
- return overlap / width_union
410
-
411
-
412
- def _bbox_to_rect(bbox: List[float]) -> Tuple[int, int, int, int]:
413
- """Convert [x0, y0, x1, y1] into (x, y, w, h)."""
414
- x0, y0, x1, y1 = bbox
415
- return int(x0), int(y0), int(x1 - x0), int(y1 - y0)
416
-
417
-
418
- def _open_table_image(elem: Dict, out_dir: Path) -> Optional[Image.Image]:
419
- """Open a table image relative to the output directory."""
420
- image_path = out_dir / elem["image_path"]
421
- if not image_path.exists():
422
- logger.warning(f"Missing table crop for stitching: {image_path}")
423
- return None
424
- img = Image.open(image_path)
425
- if img.mode != "RGB":
426
- img = img.convert("RGB")
427
- return img
428
-
429
-
430
- def _pad_width(img: Image.Image, target_width: int) -> Image.Image:
431
- if img.width >= target_width:
432
- return img
433
- canvas = Image.new("RGB", (target_width, img.height), color=(255, 255, 255))
434
- canvas.paste(img, (0, 0))
435
- return canvas
436
-
437
-
438
- def _pad_height(img: Image.Image, target_height: int) -> Image.Image:
439
- if img.height >= target_height:
440
- return img
441
- canvas = Image.new("RGB", (img.width, target_height), color=(255, 255, 255))
442
- canvas.paste(img, (0, 0))
443
- return canvas
444
-
445
-
446
- def _append_segment_image(
447
- base_img: Image.Image,
448
- segment_img: Image.Image,
449
- resize_to_base: bool = False,
450
- ) -> Image.Image:
451
- """Append segment image below base image with optional width alignment."""
452
- if base_img.mode != "RGB":
453
- base_img = base_img.convert("RGB")
454
- if segment_img.mode != "RGB":
455
- segment_img = segment_img.convert("RGB")
456
-
457
- if resize_to_base and segment_img.width > 0 and base_img.width > 0:
458
- segment_img = segment_img.resize(
459
- (
460
- base_img.width,
461
- max(1, int(segment_img.height * (base_img.width / segment_img.width))),
462
- ),
463
- Image.Resampling.LANCZOS,
464
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
 
466
- target_width = max(base_img.width, segment_img.width)
467
- base_img = _pad_width(base_img, target_width)
468
- segment_img = _pad_width(segment_img, target_width)
469
-
470
- stitched = Image.new(
471
- "RGB",
472
- (target_width, base_img.height + segment_img.height),
473
- color=(255, 255, 255),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
  )
475
- stitched.paste(base_img, (0, 0))
476
- stitched.paste(segment_img, (0, base_img.height))
477
- return stitched
478
 
479
 
480
- def _render_pdf_page(
481
- pdf_doc: pdfium.PdfDocument,
482
- page_index: int,
483
- scale: float,
484
- cache: Dict[int, Image.Image],
485
- ) -> Optional[Image.Image]:
486
- """Render a PDF page to a PIL image with caching."""
487
- if page_index in cache:
488
- return cache[page_index]
489
 
490
- try:
491
- page = pdf_doc[page_index]
492
- bitmap = page.render(scale=scale)
493
- pil_img = bitmap.to_pil()
494
- page.close()
495
- except Exception as exc:
496
- logger.error(f"Failed to render page {page_index + 1} for caption stitching: {exc}")
497
- return None
498
-
499
- cache[page_index] = pil_img
500
- return pil_img
501
-
502
-
503
- def _crop_pdf_region(
504
- page_img: Optional[Image.Image], bbox: List[float]
505
- ) -> Optional[Image.Image]:
506
- """Crop a region from a rendered PDF page."""
507
- if page_img is None:
508
- return None
509
-
510
- x0, y0, x1, y1 = map(int, bbox)
511
- x0 = max(0, x0)
512
- y0 = max(0, y0)
513
- x1 = min(page_img.width, max(x0 + 1, x1))
514
- y1 = min(page_img.height, max(y0 + 1, y1))
515
 
516
- if x0 >= x1 or y0 >= y1:
517
- return None
518
-
519
- crop = page_img.crop((x0, y0, x1, y1))
520
- if crop.mode == "CMYK":
521
- crop = crop.convert("RGB")
522
- return crop
523
-
524
-
525
- def write_markdown_document(pdf_path: Path, out_dir: Path) -> Optional[Path]:
526
  """
527
- Extract markdown text from a PDF using PyMuPDF4LLM and write it to disk.
528
  """
529
- if pymupdf4llm is None:
530
- logger.warning(
531
- "Skipping markdown extraction for %s because pymupdf4llm is not installed.",
532
- pdf_path.name,
533
- )
534
- return None
535
 
536
- try:
537
- markdown_content = pymupdf4llm.to_markdown(str(pdf_path))
538
- except Exception as exc:
539
- logger.error(f" Failed to create markdown for {pdf_path.name}: {exc}")
540
- return None
541
-
542
- if isinstance(markdown_content, list):
543
- markdown_content = "\n\n".join(
544
- part for part in markdown_content if isinstance(part, str)
545
- )
546
-
547
- if not isinstance(markdown_content, str):
548
- logger.error(
549
- f" Unexpected markdown output type {type(markdown_content)} for {pdf_path.name}"
550
- )
551
- return None
552
-
553
- markdown_content = markdown_content.strip()
554
- if not markdown_content:
555
- logger.warning(f" No textual content extracted from {pdf_path.name}")
556
- return None
557
-
558
- if not markdown_content.endswith("\n"):
559
- markdown_content += "\n"
560
-
561
- md_path = out_dir / f"{pdf_path.stem}.md"
562
- md_path.write_text(markdown_content, encoding="utf-8")
563
- logger.info(f" Saved markdown to {md_path.name}")
564
- return md_path
565
-
566
-
567
- def _collect_text_under_title_cross_page(
568
- title_det: Dict,
569
- sorted_dets: List[Dict],
570
- start_idx: int,
571
- page_idx: int,
572
- used_indices: Set[Tuple[int, int]],
573
- settings: Optional[Dict[str, float]] = None,
574
- ) -> List[Dict]:
575
- """Collect text elements directly below a title on the next page."""
576
- if settings is None:
577
- settings = TITLE_TEXT_ASSOCIATION
578
- texts: List[Dict] = []
579
- title_box = title_det["bbox"]
580
- last_bottom = title_box[3]
581
-
582
- for follower in sorted_dets[start_idx + 1 :]:
583
- det_index = follower.get("index")
584
- if det_index is None or (page_idx, det_index) in used_indices:
585
- continue
586
-
587
- if follower["name"] == "title":
588
- break
589
-
590
- if follower["name"] != "text":
591
- continue
592
-
593
- text_box = follower["bbox"]
594
- if text_box[1] < title_box[1]:
595
- continue
596
-
597
- gap = text_box[1] - last_bottom
598
- if gap > settings["max_text_gap"]:
599
- break
600
-
601
- if _horizontal_overlap_ratio(title_box, text_box) < settings["min_overlap"]:
602
- continue
603
-
604
- texts.append(follower)
605
- last_bottom = text_box[3]
606
-
607
- return texts
608
-
609
-
610
- def attach_cross_page_figure_captions(
611
- elements: List[Dict],
612
- all_dets: Sequence[Optional[List[Dict[str, Any]]]],
613
- pdf_bytes: bytes,
614
- out_dir: Path,
615
- scale: float,
616
- ) -> List[Dict]:
617
- """
618
- If a figure caption appears on the next page, stitch it to the prior figure.
619
- """
620
- figures = [elem for elem in elements if elem.get("type") == "figure"]
621
- if not figures or not all_dets:
622
- return elements
623
 
624
- try:
625
- pdf_doc = pdfium.PdfDocument(pdf_bytes)
626
- except Exception as exc:
627
- logger.error(f"Unable to reopen PDF for figure caption stitching: {exc}")
628
- return elements
629
-
630
- page_cache: Dict[int, Image.Image] = {}
631
- used_following_ids: Set[Tuple[int, int]] = set()
632
-
633
- # Mark existing caption/title/text detections as used
634
- for elem in figures:
635
- for key in ("captions", "titles", "texts"):
636
- for seg in elem.get(key, []) or []:
637
- idx = seg.get("index")
638
- page_no = seg.get("page")
639
- if idx is None or page_no is None:
640
- continue
641
- used_following_ids.add((page_no - 1, idx))
642
-
643
- for elem in figures:
644
- page_no = elem.get("page")
645
- bbox = elem.get("bbox_pixels")
646
- if page_no is None or bbox is None:
647
- continue
648
-
649
- current_idx = page_no - 1
650
- next_idx = current_idx + 1
651
- if next_idx >= len(all_dets):
652
- continue
653
-
654
- next_dets = all_dets[next_idx]
655
- if not next_dets:
656
- continue
657
-
658
- fig_width = bbox[2] - bbox[0]
659
- page_img = _render_pdf_page(pdf_doc, next_idx, scale, page_cache)
660
- if page_img is None:
661
- continue
662
-
663
- next_page_height = page_img.height
664
- max_top_allowed = min(
665
- CROSS_PAGE_CAPTION_THRESHOLDS["max_top_pixels"],
666
- int(next_page_height * CROSS_PAGE_CAPTION_THRESHOLDS["max_top_ratio"]),
667
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
668
 
669
- sorted_next = sorted(
670
- [det for det in next_dets if det.get("bbox")],
671
- key=lambda det: det["bbox"][1],
672
- )
673
 
674
- caption_candidate: Optional[Tuple[Dict, int]] = None
675
- caption_candidates = []
676
- for det in sorted_next:
677
- if det.get("name") != "figure_caption":
678
- continue
679
- det_index = det.get("index")
680
- if det_index is None or (next_idx, det_index) in used_following_ids:
681
- continue
682
-
683
- det_bbox = det.get("bbox")
684
- if not det_bbox or det_bbox[1] > max_top_allowed:
685
- continue
686
-
687
- overlap = _horizontal_overlap_ratio(bbox, det_bbox)
688
- x_diff = abs(bbox[0] - det_bbox[0])
689
- width_diff = abs((bbox[2] - bbox[0]) - (det_bbox[2] - det_bbox[0]))
690
-
691
- if overlap < CROSS_PAGE_CAPTION_THRESHOLDS["min_overlap"]:
692
- if (
693
- x_diff > CROSS_PAGE_CAPTION_THRESHOLDS["x_tol"]
694
- or width_diff > CROSS_PAGE_CAPTION_THRESHOLDS["width_tol"]
695
- ):
696
- continue
697
-
698
- score = width_diff + 0.5 * x_diff
699
- caption_candidates.append((score, det, det_index))
700
-
701
- if caption_candidates:
702
- caption_candidates.sort(key=lambda item: item[0])
703
- _, best_det, best_index = caption_candidates[0]
704
- caption_candidate = (best_det, best_index)
705
-
706
- title_candidate: Optional[Tuple[Dict, int]] = None
707
- title_texts: List[Dict] = []
708
- for idx_sorted, det in enumerate(sorted_next):
709
- if det.get("name") != "title":
710
- continue
711
- det_index = det.get("index")
712
- if det_index is None or (next_idx, det_index) in used_following_ids:
713
- continue
714
-
715
- det_bbox = det.get("bbox")
716
- if not det_bbox or det_bbox[1] > max_top_allowed:
717
- continue
718
-
719
- overlap = _horizontal_overlap_ratio(bbox, det_bbox)
720
- x_diff = abs(bbox[0] - det_bbox[0])
721
- if (
722
- overlap < TITLE_TEXT_ASSOCIATION["min_overlap"]
723
- and x_diff > CROSS_PAGE_CAPTION_THRESHOLDS["x_tol"]
724
- ):
725
- continue
726
-
727
- title_candidate = (det, det_index)
728
- title_texts = _collect_text_under_title_cross_page(
729
- det, sorted_next, idx_sorted, next_idx, used_following_ids
730
- )
731
- break
732
-
733
- if not caption_candidate and not title_candidate and not title_texts:
734
- continue
735
-
736
- figure_path = out_dir / elem["image_path"]
737
- if not figure_path.exists():
738
- continue
739
-
740
- figure_img = Image.open(figure_path)
741
- if figure_img.mode == "CMYK":
742
- figure_img = figure_img.convert("RGB")
743
-
744
- segments_added = False
745
-
746
- if caption_candidate:
747
- cap_det, cap_index = caption_candidate
748
- caption_crop = _crop_pdf_region(page_img, cap_det["bbox"])
749
- if caption_crop is not None:
750
- figure_img = _append_segment_image(
751
- figure_img, caption_crop, resize_to_base=True
752
- )
753
- elem.setdefault("captions", [])
754
- elem["captions"].append(
755
- {
756
- "bbox": cap_det["bbox"],
757
- "conf": cap_det.get("conf"),
758
- "index": cap_index,
759
- "source": cap_det.get("source"),
760
- "page": next_idx + 1,
761
- }
762
- )
763
- used_following_ids.add((next_idx, cap_index))
764
- segments_added = True
765
-
766
- if title_candidate:
767
- title_det, title_index = title_candidate
768
- title_crop = _crop_pdf_region(page_img, title_det["bbox"])
769
- if title_crop is not None:
770
- figure_img = _append_segment_image(figure_img, title_crop)
771
- elem.setdefault("titles", [])
772
- elem["titles"].append(
773
- {
774
- "bbox": title_det["bbox"],
775
- "conf": title_det.get("conf"),
776
- "index": title_index,
777
- "source": title_det.get("source"),
778
- "page": next_idx + 1,
779
- }
780
- )
781
- used_following_ids.add((next_idx, title_index))
782
- segments_added = True
783
-
784
- for text_det in title_texts:
785
- text_index = text_det.get("index")
786
- text_crop = _crop_pdf_region(page_img, text_det["bbox"])
787
- if text_crop is None:
788
- continue
789
- figure_img = _append_segment_image(figure_img, text_crop)
790
- elem.setdefault("texts", [])
791
- elem["texts"].append(
792
- {
793
- "bbox": text_det["bbox"],
794
- "conf": text_det.get("conf"),
795
- "index": text_index,
796
- "source": text_det.get("source"),
797
- "page": next_idx + 1,
798
- }
799
- )
800
- if text_index is not None:
801
- used_following_ids.add((next_idx, text_index))
802
- segments_added = True
803
-
804
- if not segments_added:
805
- continue
806
-
807
- figure_img.save(figure_path)
808
- elem["width"] = figure_img.width
809
- elem["height"] = figure_img.height
810
-
811
- span = elem.get("page_span")
812
- if span:
813
- if next_idx + 1 not in span:
814
- span.append(next_idx + 1)
815
- else:
816
- base_page = elem.get("page")
817
- new_span = [page for page in (base_page, next_idx + 1) if page is not None]
818
- elem["page_span"] = new_span
819
-
820
- pdf_doc.close()
821
- return elements
822
-
823
-
824
- def _stitch_table_pair(
825
- base_elem: Dict,
826
- candidate_elem: Dict,
827
- out_dir: Path,
828
- merge_index: int,
829
- stitch_type: str,
830
- ) -> Optional[Dict]:
831
- """Stitch two table crops either vertically or horizontally."""
832
- base_img = _open_table_image(base_elem, out_dir)
833
- candidate_img = _open_table_image(candidate_elem, out_dir)
834
- if base_img is None or candidate_img is None:
835
- return None
836
-
837
- tables_dir = out_dir / "tables"
838
- tables_dir.mkdir(parents=True, exist_ok=True)
839
-
840
- if stitch_type == "vertical":
841
- target_width = max(base_img.width, candidate_img.width)
842
- base_img = _pad_width(base_img, target_width)
843
- candidate_img = _pad_width(candidate_img, target_width)
844
- merged_height = base_img.height + candidate_img.height
845
- stitched = Image.new("RGB", (target_width, merged_height), color=(255, 255, 255))
846
- stitched.paste(base_img, (0, 0))
847
- stitched.paste(candidate_img, (0, base_img.height))
848
- else:
849
- target_height = max(base_img.height, candidate_img.height)
850
- base_img = _pad_height(base_img, target_height)
851
- candidate_img = _pad_height(candidate_img, target_height)
852
- merged_width = base_img.width + candidate_img.width
853
- stitched = Image.new("RGB", (merged_width, target_height), color=(255, 255, 255))
854
- stitched.paste(base_img, (0, 0))
855
- stitched.paste(candidate_img, (base_img.width, 0))
856
-
857
- merged_name = (
858
- f"page_{base_elem['page']}_to_{candidate_elem['page']}_"
859
- f"table_merged_{merge_index}.png"
860
- )
861
- merged_path = tables_dir / merged_name
862
- stitched.save(merged_path)
863
-
864
- # Remove original partial crops to avoid duplicates
865
- (out_dir / base_elem["image_path"]).unlink(missing_ok=True)
866
- (out_dir / candidate_elem["image_path"]).unlink(missing_ok=True)
867
-
868
- new_bbox = [
869
- min(base_elem["bbox_pixels"][0], candidate_elem["bbox_pixels"][0]),
870
- min(base_elem["bbox_pixels"][1], candidate_elem["bbox_pixels"][1]),
871
- max(base_elem["bbox_pixels"][2], candidate_elem["bbox_pixels"][2]),
872
- max(base_elem["bbox_pixels"][3], candidate_elem["bbox_pixels"][3]),
873
- ]
874
-
875
- merged_elem = base_elem.copy()
876
- merged_elem["page_span"] = [base_elem["page"], candidate_elem["page"]]
877
- merged_elem["box_refs"] = [
878
- {"page": base_elem["page"], "image_path": base_elem["image_path"]},
879
- {"page": candidate_elem["page"], "image_path": candidate_elem["image_path"]},
880
- ]
881
- merged_elem["bbox_pixels"] = new_bbox
882
- merged_elem["image_path"] = str(merged_path.relative_to(out_dir))
883
- merged_elem["width"] = stitched.width
884
- merged_elem["height"] = stitched.height
885
- merged_elem["page_height"] = stitched.height
886
- merged_elem["conf"] = min(
887
- base_elem.get("conf", 1.0), candidate_elem.get("conf", 1.0)
888
- )
889
- return merged_elem
890
 
891
 
892
- def merge_spanning_tables(elements: List[Dict], out_dir: Path) -> List[Dict]:
893
- """
894
- Stitch table crops that continue across adjacent pages using the heuristic
895
- from the legacy OpenCV-based extractor.
896
- """
897
- if not elements:
898
- return elements
899
-
900
- tables_by_page: Dict[int, List[Dict]] = {}
901
- non_tables: List[Dict] = []
902
-
903
- for elem in elements:
904
- if elem.get("type") != "table":
905
- non_tables.append(elem)
906
- continue
907
- page = elem.get("page")
908
- if not isinstance(page, int):
909
- non_tables.append(elem)
910
- continue
911
- tables_by_page.setdefault(page, []).append(elem)
912
-
913
- merged_results: List[Dict] = []
914
- used_next: Dict[int, set[int]] = {}
915
- merge_counter = 0
916
-
917
- for page in sorted(tables_by_page.keys()):
918
- current_tables = tables_by_page.get(page, [])
919
- next_page_tables = tables_by_page.get(page + 1, [])
920
- next_used_indices = used_next.get(page + 1, set())
921
- current_used_indices = used_next.get(page, set())
922
-
923
- for idx_current, table_elem in enumerate(current_tables):
924
- if idx_current in current_used_indices:
925
- continue
926
-
927
- if not next_page_tables:
928
- merged_results.append(table_elem)
929
- continue
930
-
931
- x, y, w, h = _bbox_to_rect(table_elem["bbox_pixels"])
932
- matched = False
933
-
934
- for idx, candidate in enumerate(next_page_tables):
935
- if idx in next_used_indices:
936
- continue
937
- if candidate.get("type") != "table":
938
- continue
939
-
940
- cx, cy, cw, ch = _bbox_to_rect(candidate["bbox_pixels"])
941
-
942
- vertical_match = (
943
- abs(x - cx) <= TABLE_STITCH_TOLERANCES["x_tol"]
944
- and abs((x + w) - (cx + cw)) <= TABLE_STITCH_TOLERANCES["width_tol"]
945
- )
946
- horizontal_match = (
947
- abs(y - cy) <= TABLE_STITCH_TOLERANCES["y_tol"]
948
- and abs((y + h) - (cy + ch))
949
- <= TABLE_STITCH_TOLERANCES["height_tol"]
950
- )
951
-
952
- stitch_type = "vertical" if vertical_match else None
953
- if not stitch_type and horizontal_match:
954
- stitch_type = "horizontal"
955
-
956
- if not stitch_type:
957
- continue
958
-
959
- merge_counter += 1
960
- merged_elem = _stitch_table_pair(
961
- table_elem, candidate, out_dir, merge_counter, stitch_type
962
- )
963
- if merged_elem is None:
964
- continue
965
-
966
- merged_results.append(merged_elem)
967
- next_used_indices.add(idx)
968
- matched = True
969
- break
970
-
971
- if not matched:
972
- merged_results.append(table_elem)
973
-
974
- used_next[page + 1] = next_used_indices
975
-
976
- merged_results.extend(non_tables)
977
- return merged_results
978
-
979
-
980
-
981
- # ----------------------------------------------------------------------
982
- # Draw layout boxes on the original PDF
983
- # ----------------------------------------------------------------------
984
- def draw_layout_pdf(pdf_bytes: bytes, all_dets: List[List[dict]],
985
- scale: float, out_path: Path):
986
- """Annotate PDF with semi-transparent bounding boxes and labels."""
987
- doc = fitz.open(stream=pdf_bytes, filetype="pdf")
988
-
989
- for page_no, dets in enumerate(all_dets):
990
- page = doc[page_no]
991
-
992
- for d in dets:
993
- rgb = CLASS_COLORS.get(d["name"], (0, 0, 0))
994
- rect = fitz.Rect([c / scale for c in d["bbox"]])
995
-
996
- border_color = [c / 255 for c in rgb]
997
- fill_color = [c / 255 for c in rgb]
998
- fill_opacity = 0.15
999
- border_width = 1.5
1000
-
1001
- page.draw_rect(
1002
- rect,
1003
- color=border_color,
1004
- fill=fill_color,
1005
- width=border_width,
1006
- overlay=True,
1007
- fill_opacity=fill_opacity
1008
- )
1009
-
1010
- label = f"{d['name']} {d['conf']:.2f}"
1011
- if d.get("source"):
1012
- label += f" [{d['source'][0].upper()}]"
1013
-
1014
- text_bg = fitz.Rect(rect.x0, rect.y0 - 10, rect.x0 + 60, rect.y0)
1015
- page.draw_rect(text_bg, color=None, fill=(1, 1, 1, 0.6), overlay=True)
1016
-
1017
- page.insert_text(
1018
- (rect.x0 + 2, rect.y0 - 8),
1019
- label,
1020
- fontsize=6.5,
1021
- color=border_color,
1022
- overlay=True
1023
- )
1024
-
1025
- doc.save(str(out_path))
1026
- doc.close()
1027
-
1028
- # ----------------------------------------------------------------------
1029
- # Process a single PDF Page (for parallel execution)
1030
- # ----------------------------------------------------------------------
1031
- def process_page(task_data: Tuple[int, bytes, float, Path, str]) -> Optional[Tuple[int, List[dict], List[dict]]]:
1032
  """
1033
- Process a single page of a PDF in a worker process.
1034
- Returns: (page_number, detections, elements) or None on failure
1035
  """
1036
- pno, pdf_bytes, scale, out_dir, pdf_name = task_data
 
 
 
 
 
 
 
 
1037
 
1038
- if _shutdown_requested:
1039
- return None
 
 
1040
 
1041
- pdf_pdfium = None
1042
  try:
1043
- pdf_pdfium = pdfium.PdfDocument(pdf_bytes)
 
 
1044
 
1045
- page = pdf_pdfium[pno]
1046
- bitmap = page.render(scale=scale)
1047
- pil = bitmap.to_pil()
1048
-
1049
- dets = detect_page(pil)
1050
- elements = save_layout_elements(pil, pno, dets, out_dir)
 
 
1051
 
1052
- page_figures = len([d for d in dets if d['name'] == 'figure'])
1053
- page_tables = len([d for d in dets if d['name'] == 'table'])
1054
- logger.info(f" [{pdf_name}] Page {pno + 1}: {page_figures} figs, {page_tables} tables")
1055
-
1056
- page.close()
1057
- pdf_pdfium.close()
 
 
 
 
 
 
 
 
 
 
1058
 
1059
- return (pno, dets, elements)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1060
 
1061
  except Exception as e:
1062
- logger.error(f"Failed to process page {pno + 1} of {pdf_name}: {e}")
1063
- if pdf_pdfium:
1064
- pdf_pdfium.close()
1065
- return None
1066
-
1067
- # ----------------------------------------------------------------------
1068
- # Process a full PDF using the persistent worker pool
1069
- # ----------------------------------------------------------------------
1070
- def process_pdf_with_pool(
1071
- pdf_path: Path,
1072
- out_dir: Path,
1073
- pool: Optional[Pool] = None,
1074
- *,
1075
- extract_images: bool = True,
1076
- extract_markdown: bool = True,
1077
- ):
1078
- """
1079
- Main processing pipeline for a PDF file.
1080
- If pool is provided, uses it. Otherwise processes serially.
1081
- """
1082
-
1083
- if _shutdown_requested:
1084
- logger.warning(f"Skipping {pdf_path.name} due to shutdown request")
1085
- return
1086
-
1087
- stem = pdf_path.stem
1088
- logger.info(f"Processing {pdf_path.name}")
1089
 
1090
- pdf_bytes = pdf_path.read_bytes()
 
 
 
 
 
 
 
 
 
1091
 
1092
- doc = None
 
 
 
 
 
 
1093
  try:
1094
- doc = pdfium.PdfDocument(pdf_bytes)
1095
- page_count = len(doc)
1096
  except Exception as e:
1097
- logger.error(f"Failed to open PDF {pdf_path.name}: {e}. Skipping.")
1098
- return
1099
- finally:
1100
- if doc is not None:
1101
- doc.close()
1102
-
1103
- scale = 2.0
1104
- all_elements: List[Dict] = []
1105
- filtered_dets: List[List[dict]] = []
1106
-
1107
- if extract_images:
1108
- all_dets: List[Optional[List[dict]]] = [None] * page_count
1109
-
1110
- if pool is not None and USE_MULTIPROCESSING:
1111
- logger.info(f" Using worker pool for {page_count} pages...")
1112
-
1113
- tasks = [
1114
- (pno, pdf_bytes, scale, out_dir, pdf_path.name)
1115
- for pno in range(page_count)
1116
- ]
1117
-
1118
- try:
1119
- results = pool.map(process_page, tasks)
1120
-
1121
- for res in results:
1122
- if res:
1123
- pno, dets, elements = res
1124
- all_dets[pno] = dets
1125
- all_elements.extend(elements)
1126
-
1127
- except KeyboardInterrupt:
1128
- logger.warning("Processing interrupted during parallel execution")
1129
- raise
1130
-
1131
- else:
1132
- logger.info("Using serial processing...")
1133
-
1134
- try:
1135
- pdf_pdfium = pdfium.PdfDocument(pdf_bytes)
1136
-
1137
- for pno in range(page_count):
1138
- if _shutdown_requested:
1139
- logger.warning(
1140
- f"Stopping at page {pno + 1}/{page_count} due to shutdown request"
1141
- )
1142
- break
1143
-
1144
- try:
1145
- logger.info(f" Processing page {pno + 1}/{page_count}")
1146
-
1147
- page = pdf_pdfium[pno]
1148
- bitmap = page.render(scale=scale)
1149
- pil = bitmap.to_pil()
1150
-
1151
- dets = detect_page(pil)
1152
- all_dets[pno] = dets
1153
-
1154
- elements = save_layout_elements(pil, pno, dets, out_dir)
1155
- all_elements.extend(elements)
1156
-
1157
- page_figures = len([d for d in dets if d["name"] == "figure"])
1158
- page_tables = len([d for d in dets if d["name"] == "table"])
1159
- logger.info(
1160
- f" Found {page_figures} figures and {page_tables} tables"
1161
- )
1162
-
1163
- page.close()
1164
-
1165
- except Exception as e:
1166
- logger.error(f"Failed to process page {pno + 1}: {e}. Skipping page.")
1167
-
1168
- pdf_pdfium.close()
1169
-
1170
- except Exception as e:
1171
- logger.error(f"Fatal error processing {pdf_path.name}: {e}")
1172
- if "pdf_pdfium" in locals() and pdf_pdfium:
1173
- pdf_pdfium.close()
1174
- return
1175
-
1176
- dets_per_page: List[Optional[List[Dict[str, Any]]]] = [
1177
- det if det is not None else None for det in all_dets
1178
- ]
1179
-
1180
- filtered_dets = [d for d in all_dets if d is not None]
1181
-
1182
- if all_elements:
1183
- all_elements = merge_spanning_tables(all_elements, out_dir)
1184
- all_elements = attach_cross_page_figure_captions(
1185
- all_elements, dets_per_page, pdf_bytes, out_dir, scale
1186
- )
1187
-
1188
- if all_elements:
1189
- content_list_path = out_dir / f"{stem}_content_list.json"
1190
- with open(content_list_path, "w", encoding="utf-8") as f:
1191
- json.dump(all_elements, f, ensure_ascii=False, indent=4)
1192
- logger.info(f" Saved {len(all_elements)} elements to JSON")
1193
-
1194
- if filtered_dets:
1195
- draw_layout_pdf(
1196
- pdf_bytes, filtered_dets, scale, out_dir / f"{stem}_layout.pdf"
1197
- )
1198
- logger.info(" Generated annotated PDF")
1199
- else:
1200
- logger.warning(f"No detections found for {stem}. Skipping layout PDF.")
1201
 
1202
- else:
1203
- logger.info(" Image extraction skipped per configuration.")
 
1204
 
1205
- markdown_path = None
1206
- if extract_markdown:
1207
- markdown_path = write_markdown_document(pdf_path, out_dir)
1208
- if markdown_path is None:
1209
- logger.warning(f" Markdown extraction yielded no content for {stem}.")
1210
-
1211
- if _shutdown_requested:
1212
- logger.warning(f"⚠️ Partial results saved for {stem} → {out_dir}")
1213
- else:
1214
- if extract_images:
1215
- logger.success(
1216
- f"✓ {stem} → {out_dir} ({len(all_elements)} elements extracted)"
1217
- )
1218
- else:
1219
- logger.success(f"✓ {stem} → {out_dir} (image extraction skipped)")
1220
-
1221
- # ----------------------------------------------------------------------
1222
- # Main
1223
- # ----------------------------------------------------------------------
1224
- if __name__ == "__main__":
1225
- # Important for multiprocessing on Windows/macOS
1226
- torch.multiprocessing.set_start_method('spawn', force=True)
1227
 
1228
- # Setup signal handlers for graceful shutdown
1229
- setup_signal_handlers()
 
 
 
 
 
 
 
1230
 
1231
- INPUT_DIR = Path("./pdfs")
1232
- OUTPUT_DIR = Path("./output")
1233
 
1234
- os.makedirs(INPUT_DIR, exist_ok=True)
1235
- os.makedirs(OUTPUT_DIR, exist_ok=True)
1236
 
1237
- pdf_files = list(INPUT_DIR.glob("*.pdf"))
1238
- if not pdf_files:
1239
- logger.warning("No PDF files found in ./pdfs")
1240
- logger.info("Please add PDF files to the ./pdfs directory")
1241
- logger.info("The script will exit gracefully. No errors occurred.")
1242
- sys.exit(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1243
 
1244
- logger.info(f"Found {len(pdf_files)} PDF file(s) to process")
1245
- logger.info(f"Settings: MODEL_SIZE={MODEL_SIZE}, CONF={CONF_THRESHOLD}")
1246
-
1247
- # Determine worker count
1248
- total_cpus = cpu_count()
1249
- if NUM_WORKERS is None:
1250
- num_workers = max(1, total_cpus - 1)
1251
- else:
1252
- num_workers = max(1, min(NUM_WORKERS, total_cpus))
1253
-
1254
- # Decide whether to use multiprocessing
1255
- # Local device check
1256
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
1257
-
1258
- use_pool = USE_MULTIPROCESSING and DEVICE == "cpu" and total_cpus >= 4
1259
-
1260
- if use_pool:
1261
- logger.info(f"🚀 Creating persistent worker pool with {num_workers} workers...")
1262
- else:
1263
- if not USE_MULTIPROCESSING:
1264
- logger.info("Multiprocessing disabled by configuration")
1265
- elif DEVICE != "cpu":
1266
- logger.info(f"Using serial GPU processing (device: {DEVICE})")
1267
- else:
1268
- logger.info(f"Using serial CPU processing (CPU count {total_cpus} too low)")
1269
-
1270
- pool = None
1271
- try:
1272
- # Create persistent pool ONCE for all PDFs
1273
- if use_pool:
1274
- pool = Pool(processes=num_workers, initializer=init_worker)
1275
- logger.success(f"✓ Worker pool ready with {num_workers} workers\n")
1276
- else:
1277
- # Load model in main process for serial execution
1278
- logger.info("Initializing model in main process...")
1279
- get_model()
1280
- logger.success(f"✓ Model loaded (device: {DEVICE})\n")
1281
-
1282
- # Process all PDFs using the same pool
1283
- for i, pdf_path in enumerate(pdf_files, 1):
1284
- if _shutdown_requested:
1285
- logger.warning(f"\nShutdown requested. Processed {i-1}/{len(pdf_files)} files.")
1286
- break
1287
-
1288
- logger.info(f"\n{'='*60}")
1289
- logger.info(f"📄 File {i}/{len(pdf_files)}: {pdf_path.name}")
1290
- logger.info(f"{'='*60}")
1291
-
1292
- sub_out = OUTPUT_DIR / pdf_path.stem
1293
- os.makedirs(sub_out, exist_ok=True)
1294
-
1295
- try:
1296
- process_pdf_with_pool(pdf_path, sub_out, pool)
1297
- except KeyboardInterrupt:
1298
- logger.warning(f"\nInterrupted while processing {pdf_path.name}")
1299
- break
1300
- except Exception as e:
1301
- logger.error(f"Error processing {pdf_path.name}: {e}")
1302
- if _shutdown_requested:
1303
- break
1304
- logger.info("Continuing with next file...")
1305
- continue
1306
-
1307
- if _shutdown_requested:
1308
- logger.warning(f"\n⚠️ Processing interrupted. Partial results saved in {OUTPUT_DIR}")
1309
- else:
1310
- logger.success(f"\n✨ All done! Results are in {OUTPUT_DIR}")
1311
-
1312
- except KeyboardInterrupt:
1313
- logger.error("\n❌ Processing interrupted by user")
1314
- sys.exit(1)
1315
- except Exception as e:
1316
- logger.error(f"\n❌ Fatal error: {e}")
1317
- sys.exit(1)
1318
- finally:
1319
- # Clean up pool if it exists
1320
- if pool is not None:
1321
- logger.info("\n🧹 Shutting down worker pool...")
1322
- pool.close()
1323
- pool.join()
1324
- logger.success("✓ Worker pool closed cleanly")
 
 
1
  import json
2
+ import os
3
+ import shutil
4
+ import shutil
5
+ import threading
6
+ import uuid
7
+ import time
8
+ import multiprocessing
9
  from pathlib import Path
10
+ from typing import Dict, List, Optional, Any
11
+ from enum import Enum
12
+ from contextlib import asynccontextmanager
13
+
14
+ from fastapi import FastAPI, Request, File, UploadFile, Form, BackgroundTasks, HTTPException
15
+ from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
16
+ from fastapi.staticfiles import StaticFiles
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from pydantic import BaseModel, Field
19
+ import re
20
+ import gradio as gr
21
+ # from werkzeug.utils import secure_filename # Removed dependency
22
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ import main as extractor
25
+ from loguru import logger
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ # --------------------------------------------------------------------------------
28
+ # CONFIGURATION
29
+ # --------------------------------------------------------------------------------
 
30
 
31
+ MAX_CONTENT_LENGTH = 500 * 1024 * 1024 # Not strictly enforced by FastAPI by default, but good to know
32
+ UPLOAD_FOLDER = Path('./uploads')
33
+ OUTPUT_FOLDER = Path('./output')
34
 
35
+ UPLOAD_FOLDER.mkdir(parents=True, exist_ok=True)
36
+ OUTPUT_FOLDER.mkdir(parents=True, exist_ok=True)
37
 
38
+ # Global model instance
39
+ _model = None
40
+ _progress_tracker: Dict[str, Dict] = {}
41
+ _progress_lock = threading.RLock()
42
+ # Global process pool
43
+ _pool = None
44
 
45
 
46
+ def secure_filename(filename: str) -> str:
 
 
 
 
 
47
  """
48
+ Sanitize filename to prevent directory traversal and special chars.
49
+ Simplistic implementation to replace werkzeug.
50
  """
51
+ filename = Path(filename).name
52
+ # Keep only alphanumeric, dots, hyphens, and underscores
53
+ filename = re.sub(r'[^a-zA-Z0-9_.-]', '_', filename)
54
+ return filename
 
55
 
 
 
56
 
57
+ def get_device_info() -> Dict[str, Any]:
58
+ """Get information about GPU/CPU availability."""
59
+ cuda_available = torch.cuda.is_available()
60
+ device = "cuda" if cuda_available else "cpu"
61
+
62
+ info = {
63
+ "device": device,
64
+ "cuda_available": cuda_available,
65
+ "device_name": None,
66
+ "device_count": 0,
67
+ }
68
+
69
+ if cuda_available:
70
+ info["device_name"] = torch.cuda.get_device_name(0)
71
+ info["device_count"] = torch.cuda.device_count()
72
+
73
+ return info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+ def load_model_once():
76
+ """Load the model once and cache it."""
77
+ global _model
78
+ if _model is None:
79
+ logger.info("Loading DocLayout-YOLO model...")
80
+ _model = extractor.get_model()
81
+ logger.info("Model loaded successfully")
82
+ return _model
83
 
84
+ @asynccontextmanager
85
+ async def lifespan(app: FastAPI):
86
+ """
87
+ Life span context manager for startup and shutdown events.
88
+ Initializes the multiprocessing pool for non-blocking CPU tasks.
89
+ """
90
+ global _pool
91
+ logger.info("Starting up PDF Layout Extractor...")
92
+
93
+ # Configure multiprocessing for PyTorch/CUDA
94
+ try:
95
+ multiprocessing.set_start_method('spawn', force=True)
96
+ except RuntimeError:
97
+ pass # Already set
98
+
99
+ # Initialize worker pool
100
+ # On ZeroGPU / Spaces, multiprocessing prevents GPU access and causes crashes.
101
+ # We will disable it globally as requested.
102
+ logger.info("Multiprocessing disabled for ZeroGPU compatibility.")
103
+ _pool = None
104
 
105
+ yield
 
 
106
 
107
+ # Shutdown
108
+ logger.info("Shutting down PDF Layout Extractor...")
109
+
110
+ app = FastAPI(
111
+ title="PDF Layout Extractor API",
112
+ description="A polished API for extracting layout information (text, tables, figures) from PDFs using DocLayout-YOLO.",
113
+ version="1.0.0",
114
+ lifespan=lifespan
115
+ )
116
+
117
+ # Enable CORS
118
+ app.add_middleware(
119
+ CORSMiddleware,
120
+ allow_origins=["*"],
121
+ allow_credentials=True,
122
+ allow_methods=["*"],
123
+ allow_headers=["*"],
124
+ )
125
+
126
+ # Mount Static Files
127
+ # Mount Output as Static for easy access to generated images/PDFs
128
+ app.mount("/output", StaticFiles(directory="output"), name="output")
129
+
130
+
131
+ # --------------------------------------------------------------------------------
132
+ # Pydantic Models for Response Documentation
133
+ # --------------------------------------------------------------------------------
134
+
135
+ class DeviceInfo(BaseModel):
136
+ device: str = Field(..., description="Compute device being used (e.g., 'cuda' or 'cpu').")
137
+ cuda_available: bool = Field(..., description="Whether CUDA GPU acceleration is available.")
138
+ device_name: Optional[str] = Field(None, description="Name of the GPU if available.")
139
+ device_count: int = Field(..., description="Number of GPU devices detected.")
140
+
141
+ class TaskStartResponse(BaseModel):
142
+ task_id: str = Field(..., description="Unique identifier for the background processing task.")
143
+ message: str = Field(..., description="Status message confirming start.")
144
+ total_files: int = Field(..., description="Number of PDF files accepted for processing.")
145
+
146
+ class ProcessingResult(BaseModel):
147
+ filename: str = Field(..., description="Name of the processed file.")
148
+ stem: Optional[str] = Field(None, description="Filename without extension.")
149
+ output_dir: Optional[str] = Field(None, description="Relative path to the output directory.")
150
+ figures_count: Optional[int] = Field(0, description="Total figures detected.")
151
+ tables_count: Optional[int] = Field(0, description="Total tables detected.")
152
+ elements_count: Optional[int] = Field(0, description="Total layout elements (text, tables, figures).")
153
+ annotated_pdf: Optional[str] = Field(None, description="Path to the PDF with layout bounding boxes drawn.")
154
+ markdown_path: Optional[str] = Field(None, description="Path to the extracted markdown file.")
155
+ # Extended URLs
156
+ annotated_pdf_url: Optional[str] = Field(None, description="Full URL to access the annotated PDF.")
157
+ markdown_url: Optional[str] = Field(None, description="Full URL to access the extracted markdown.")
158
+ figure_urls: Optional[List[Dict[str, Any]]] = Field(None, description="List of URLs for extracted figure images.")
159
+ table_urls: Optional[List[Dict[str, Any]]] = Field(None, description="List of URLs for extracted table images.")
160
+ error: Optional[str] = Field(None, description="Error message if processing failed.")
161
+
162
+ class ExtractionMode(str, Enum):
163
+ images = "images"
164
+ markdown = "markdown"
165
+ both = "both"
166
+
167
+ class ProgressResponse(BaseModel):
168
+ status: str = Field(..., description="Current status of the task (e.g., 'processing', 'completed').")
169
+ progress: int = Field(..., description="Overall progress percentage (0-100).")
170
+ message: str = Field(..., description="Current status message.")
171
+ results: List[ProcessingResult] = Field([], description="List of results for processed files.")
172
+ file_progress: Optional[Dict[str, int]] = Field(None, description="Progress percentage per file.")
173
+
174
+ class PDFInfo(BaseModel):
175
+ stem: str = Field(..., description="Unique identifier/stem of the PDF.")
176
+ output_dir: str = Field(..., description="Directory where results are stored.")
177
+
178
+ class PDFListResponse(BaseModel):
179
+ pdfs: List[PDFInfo] = Field(..., description="List of processed PDFs available on the server.")
180
+
181
+ # --------------------------------------------------------------------------------
182
+ # Helper Functions
183
+ # --------------------------------------------------------------------------------
184
+
185
+ def _update_task_progress(task_id: str, filename: str, file_progress: int, message: str):
186
+ """Update progress for a specific file and calculate overall progress."""
187
+ with _progress_lock:
188
+ if task_id not in _progress_tracker:
189
+ return
190
+
191
+ # Update file-specific progress
192
+ if 'file_progress' not in _progress_tracker[task_id]:
193
+ _progress_tracker[task_id]['file_progress'] = {}
194
+ _progress_tracker[task_id]['file_progress'][filename] = file_progress
195
+
196
+ # Calculate overall progress (average of all files)
197
+ file_progresses = _progress_tracker[task_id]['file_progress']
198
+ if file_progresses:
199
+ total_progress = sum(file_progresses.values()) / len(file_progresses)
200
+ _progress_tracker[task_id]['progress'] = int(total_progress)
201
+
202
+ _progress_tracker[task_id]['message'] = message
203
 
204
+ def process_file_background_task(task_id: str, file_data: bytes, filename: str, extraction_mode: str):
205
+ """
206
+ Process a single file in the background (runs in a thread pool inside FastAPI/Starlette).
207
+ Note: For heavy CPU/GPU tasks, prefer running in a separate process or queue (like Celery),
208
+ but consistent with the request to 'use FastAPI' and the previous design, this is fine
209
+ since `fastapi.BackgroundTasks` runs in a thread pool.
210
+ """
211
+ filename = secure_filename(filename)
212
+
213
+ try:
214
+ _update_task_progress(task_id, filename, 5, f'Processing {filename}...')
215
 
216
+ stem = Path(filename).stem
217
+ include_images = extraction_mode != 'markdown'
218
+ include_markdown = extraction_mode != 'images'
 
 
219
 
220
+ # Ensure upload directory exists
221
+ upload_dir = UPLOAD_FOLDER
222
+ upload_path = upload_dir / filename
223
+ upload_path.write_bytes(file_data)
 
 
 
 
 
 
 
 
 
 
224
 
225
+ _update_task_progress(task_id, filename, 15, f'Saved {filename}, preparing output...')
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
+ # Prepare output directory
228
+ output_dir = OUTPUT_FOLDER / stem
229
+ output_dir.mkdir(parents=True, exist_ok=True)
 
230
 
231
+ # Copy PDF to output directory
232
+ pdf_path = output_dir / filename
233
+ # shutil.copy caused permissions issues in some envs, renaming/moving is safer if fresh upload
234
+ # But here we might want to keep the original in uploads?
235
+ # The original code did `upload_path.rename(pdf_path)`, so let's stick to that semantics:
236
+ # Move from temp upload to output dir
237
+ if pdf_path.exists():
238
+ pdf_path.unlink()
239
+ upload_path.rename(pdf_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
+ _update_task_progress(task_id, filename, 25, f'Loading model and processing {filename}...')
242
+
243
+ # Process PDF
244
+ # Disable multiprocessing for ZeroGPU compatibility
245
+ extractor.USE_MULTIPROCESSING = False
246
+ logger.info(f"Processing {filename} (images={include_images}, markdown={include_markdown})")
247
+
248
+ # Note: When using a pool, we don't strictly need to load the model in THIS process
249
+ # unless we fallback to serial.
250
+ # But 'init_worker' loaded it in workers.
251
+
252
+ _update_task_progress(task_id, filename, 30, f'Extracting content from {filename}...')
253
+
254
+ # Use the global pool
255
+ # If _pool is None (initialization failed), main.py will fallback to serial (blocking this thread, but working)
256
+ extractor.process_pdf_with_pool(
257
+ pdf_path,
258
+ output_dir,
259
+ pool=_pool,
260
+ extract_images=include_images,
261
+ extract_markdown=include_markdown,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  )
263
+
264
+ _update_task_progress(task_id, filename, 85, f'Collecting results for {filename}...')
265
+
266
+ # Collect results
267
+ json_path = output_dir / f"{stem}_content_list.json"
268
+ elements = []
269
+ if include_images and json_path.exists():
270
+ text_content = json_path.read_text(encoding='utf-8')
271
+ if text_content.strip():
272
+ elements = json.loads(text_content)
273
+
274
+ annotated_pdf = None
275
+ if include_images:
276
+ candidate_pdf = output_dir / f"{stem}_layout.pdf"
277
+ if candidate_pdf.exists():
278
+ annotated_pdf = str(candidate_pdf.relative_to(OUTPUT_FOLDER))
279
+
280
+ markdown_path = None
281
+ if include_markdown:
282
+ candidate_md = output_dir / f"{stem}.md"
283
+ if candidate_md.exists():
284
+ markdown_path = str(candidate_md.relative_to(OUTPUT_FOLDER))
285
+
286
+ figures = [e for e in elements if e.get('type') == 'figure']
287
+ tables = [e for e in elements if e.get('type') == 'table']
288
+
289
+ result = {
290
+ 'filename': filename,
291
+ 'stem': stem,
292
+ 'output_dir': str(output_dir.relative_to(OUTPUT_FOLDER)),
293
+ 'figures_count': len(figures),
294
+ 'tables_count': len(tables),
295
+ 'elements_count': len(elements),
296
+ 'annotated_pdf': annotated_pdf,
297
+ 'markdown_path': markdown_path,
298
+ 'include_images': include_images,
299
+ 'include_markdown': include_markdown,
300
+ }
301
+
302
+ with _progress_lock:
303
+ if 'file_progress' not in _progress_tracker[task_id]:
304
+ _progress_tracker[task_id]['file_progress'] = {}
305
+ _progress_tracker[task_id]['file_progress'][filename] = 100
306
+
307
+ # Recalculate total
308
+ file_progresses = _progress_tracker[task_id]['file_progress']
309
+ if file_progresses:
310
+ total_prog = sum(file_progresses.values()) / len(file_progresses)
311
+ _progress_tracker[task_id]['progress'] = int(total_prog)
312
+
313
+ _progress_tracker[task_id]['results'].append(result)
314
+ _progress_tracker[task_id]['message'] = f'Completed processing {filename}'
315
+
316
+ # Check completion
317
+ total_files = _progress_tracker[task_id].get('total_files', 1)
318
+ completed_count = len([r for r in _progress_tracker[task_id]['results'] if 'error' not in r])
319
+ error_count = len([r for r in _progress_tracker[task_id]['results'] if 'error' in r])
320
+
321
+ if completed_count + error_count >= total_files:
322
+ _progress_tracker[task_id]['status'] = 'completed'
323
+ _progress_tracker[task_id]['progress'] = 100
324
+ _progress_tracker[task_id]['message'] = f'All {total_files} file(s) processed.'
325
 
326
+ except Exception as e:
327
+ logger.error(f"Error processing {filename}: {e}")
328
+ import traceback
329
+ logger.error(traceback.format_exc())
330
+ with _progress_lock:
331
+ _progress_tracker[task_id]['results'].append({
332
+ 'filename': filename,
333
+ 'error': str(e)
334
+ })
335
+ # Check if this was the last file
336
+ total_files = _progress_tracker[task_id].get('total_files', 1)
337
+ if len(_progress_tracker[task_id]['results']) >= total_files:
338
+ _progress_tracker[task_id]['status'] = 'completed' # Mark done even if error, so frontend stops polling
339
+ _progress_tracker[task_id]['message'] = f'Finished with errors.'
340
+
341
+
342
+ # --------------------------------------------------------------------------------
343
+ # Routes
344
+ # --------------------------------------------------------------------------------
345
+
346
+ @app.get("/api/docs", response_class=HTMLResponse, tags=["UI"], include_in_schema=False)
347
+ async def api_docs_redirect():
348
+ """Redirect legacy /api/docs to Swagger UI."""
349
+ return HTMLResponse(
350
+ """
351
+ <html>
352
+ <head>
353
+ <meta http-equiv="refresh" content="0; url=/docs" />
354
+ </head>
355
+ <body>
356
+ <p>Redirecting to <a href="/docs">/docs</a>...</p>
357
+ </body>
358
+ </html>
359
+ """
360
  )
 
 
 
361
 
362
 
363
+ @app.get("/api/device-info", response_model=DeviceInfo, tags=["System"])
364
+ async def device_info_endpoint():
365
+ """Get information about the processing device (CPU/GPU)."""
366
+ return get_device_info()
 
 
 
 
 
367
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
 
369
+ @app.post("/api/upload", response_model=TaskStartResponse, tags=["Processing"])
370
+ async def upload_files(
371
+ background_tasks: BackgroundTasks,
372
+ files: List[UploadFile] = File(...),
373
+ extraction_mode: ExtractionMode = Form(ExtractionMode.images, description="Select extraction mode: 'images' (figures/tables), 'markdown' (text), or 'both'.")
374
+ ):
 
 
 
 
375
  """
376
+ Upload one or more PDF files for background processing.
377
  """
378
+ if not files:
379
+ raise HTTPException(status_code=400, detail="No files provided")
 
 
 
 
380
 
381
+ pdf_files = [f for f in files if f.filename.lower().endswith('.pdf')]
382
+ if not pdf_files:
383
+ raise HTTPException(status_code=400, detail="No valid PDF files selected")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
385
+ task_id = str(uuid.uuid4())
386
+
387
+ with _progress_lock:
388
+ _progress_tracker[task_id] = {
389
+ 'status': 'processing',
390
+ 'progress': 0,
391
+ 'message': 'Starting upload...',
392
+ 'results': [],
393
+ 'total_files': len(pdf_files)
394
+ }
395
+
396
+ # Read files into memory to pass to background task (UploadFile is a stream)
397
+ # Be careful with RAM here for huge files. If too big, save to temp disk first.
398
+ # Given the original code read into RAM, we'll do the same for consistency but simpler.
399
+ for file in pdf_files:
400
+ content = await file.read()
401
+ background_tasks.add_task(
402
+ process_file_background_task,
403
+ task_id,
404
+ content,
405
+ file.filename,
406
+ extraction_mode
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
  )
408
+
409
+ return {
410
+ "task_id": task_id,
411
+ "message": "Processing started",
412
+ "total_files": len(pdf_files)
413
+ }
414
+
415
+
416
+ @app.get("/api/progress/{task_id}", response_model=ProgressResponse, tags=["Processing"])
417
+ async def get_progress(task_id: str, request: Request):
418
+ """Check the progress of a processing task."""
419
+ with _progress_lock:
420
+ progress = _progress_tracker.get(task_id)
421
+ if not progress:
422
+ raise HTTPException(status_code=404, detail="Task not found")
423
+
424
+ # Deep copy to modify for response (adding URLs) without changing state
425
+ # Or just build the response object.
426
+ # Since we are adding computed URLs, we shouldn't modify the stored state every time.
427
+ response_data = progress.copy()
428
+
429
+ # Use request.base_url for absolute URLs
430
+ base_url = str(request.base_url).rstrip('/')
431
+ if 'hf.space' in base_url or request.headers.get("x-forwarded-proto") == "https":
432
+ base_url = base_url.replace("http://", "https://")
433
+
434
+ # Process results to add URLs
435
+ results_with_urls = []
436
+ for res in response_data.get('results', []):
437
+ res_copy = res.copy()
438
+
439
+ # Helper to make url
440
+ def make_url(rel_path):
441
+ if not rel_path: return None
442
+ # Clean windows paths to forward slashes for URLs
443
+ clean_path = str(rel_path).replace('\\', '/')
444
+ return f"{base_url}/output/{clean_path}"
445
+
446
+ res_copy['annotated_pdf_url'] = make_url(res.get('annotated_pdf'))
447
+ res_copy['markdown_url'] = make_url(res.get('markdown_path'))
448
+
449
+ # Figures and Tables URLs need to be discovered from disk if not stored
450
+ # The original code loaded JSON every time. That's a bit heavy but ensures freshness.
451
+ # Let's try to do it if stem is present.
452
+ stem = res.get('stem')
453
+ if stem:
454
+ output_dir = OUTPUT_FOLDER / stem
455
+ if output_dir.exists():
456
+ json_files = list(output_dir.glob('*_content_list.json'))
457
+ if json_files:
458
+ try:
459
+ elements = json.loads(json_files[0].read_text(encoding='utf-8'))
460
+ figures = [e for e in elements if e.get('type') == 'figure']
461
+ tables = [e for e in elements if e.get('type') == 'table']
462
+
463
+ fig_urls = []
464
+ for fig in figures:
465
+ if fig.get('image_path'):
466
+ path = Path(fig['image_path']) # relative to unique output folder usually?
467
+ # Actually in main.py it saves relative to out_dir
468
+ # so image_path is like "figures/page_1_fig_0.png"
469
+ # We need relative to "output" folder for URL
470
+ # output_dir is "output/stem_timestamp"
471
+ # so full path is "output/stem_timestamp/figures/..."
472
+ # The URL mount is /output/ -> output/
473
+
474
+ # "image_path" in JSON is relative to the specific STEM folder (implied by main.py logic)
475
+ # Wait, main.py says: "image_path": str(path_template.relative_to(out_dir))
476
+ # So yes, it is "figures/..."
477
+
478
+ full_rel_path = f"{stem}/{fig['image_path']}"
479
+ fig_urls.append({
480
+ "page": fig.get('page'),
481
+ "url": make_url(full_rel_path),
482
+ "path": full_rel_path
483
+ })
484
+ res_copy['figure_urls'] = fig_urls
485
+
486
+ tab_urls = []
487
+ for tab in tables:
488
+ if tab.get('image_path'):
489
+ full_rel_path = f"{stem}/{tab['image_path']}"
490
+ tab_urls.append({
491
+ "page": tab.get('page'),
492
+ "url": make_url(full_rel_path),
493
+ "path": full_rel_path
494
+ })
495
+ res_copy['table_urls'] = tab_urls
496
+
497
+ except Exception as e:
498
+ logger.error(f"Error reading details for {stem}: {e}")
499
+
500
+ results_with_urls.append(res_copy)
501
+
502
+ response_data['results'] = results_with_urls
503
+ return response_data
504
 
 
 
 
 
505
 
506
+ @app.get("/api/pdf-list", response_model=PDFListResponse, tags=["Retrieval"])
507
+ async def pdf_list():
508
+ """List previously processed PDFs."""
509
+ output_dir = OUTPUT_FOLDER
510
+ pdfs = []
511
+
512
+ if output_dir.exists():
513
+ for item in output_dir.iterdir():
514
+ if item.is_dir():
515
+ # Check for indicators of success
516
+ if list(item.glob('*_content_list.json')) or list(item.glob('*.md')):
517
+ pdfs.append({
518
+ 'stem': item.name,
519
+ 'output_dir': item.name # returning the name as relative dir
520
+ })
521
+ return {'pdfs': pdfs}
522
+
523
+
524
+ @app.get("/api/pdf-details/{pdf_stem}", tags=["Retrieval"])
525
+ async def pdf_details(pdf_stem: str, request: Request):
526
+ """Get detailed information about a processed PDF."""
527
+ output_dir = OUTPUT_FOLDER / pdf_stem
528
+
529
+ if not output_dir.exists():
530
+ raise HTTPException(status_code=404, detail="PDF not found")
531
+
532
+ base_url = str(request.base_url).rstrip('/')
533
+ if 'hf.space' in base_url or request.headers.get("x-forwarded-proto") == "https":
534
+ base_url = base_url.replace("http://", "https://")
535
+
536
+ def make_url(rel_path):
537
+ if not rel_path: return None
538
+ clean_path = str(rel_path).replace('\\', '/')
539
+ return f"{base_url}/output/{clean_path}"
540
+
541
+ # Load content list
542
+ json_files = list(output_dir.glob('*_content_list.json'))
543
+ elements = []
544
+ if json_files:
545
+ elements = json.loads(json_files[0].read_text(encoding='utf-8'))
546
+
547
+ figures = [e for e in elements if e.get('type') == 'figure']
548
+ tables = [e for e in elements if e.get('type') == 'table']
549
+
550
+ # PDF Layout
551
+ annotated_pdf = None
552
+ pdf_files = list(output_dir.glob('*_layout.pdf'))
553
+ if pdf_files:
554
+ annotated_pdf = f"{pdf_stem}/{pdf_files[0].name}"
555
+
556
+ # Markdown
557
+ markdown_path = None
558
+ md_files = list(output_dir.glob('*.md'))
559
+ if md_files:
560
+ markdown_path = f"{pdf_stem}/{md_files[0].name}"
561
+
562
+ # Image lists
563
+ figure_images = []
564
+ fig_dir = output_dir / 'figures'
565
+ if fig_dir.exists():
566
+ figure_images = [f"{pdf_stem}/figures/{f.name}" for f in sorted(fig_dir.glob('*.png'))]
567
+
568
+ table_images = []
569
+ tab_dir = output_dir / 'tables'
570
+ if tab_dir.exists():
571
+ table_images = [f"{pdf_stem}/tables/{f.name}" for f in sorted(tab_dir.glob('*.png'))]
572
+
573
+ return {
574
+ 'stem': pdf_stem,
575
+ 'figures': figures,
576
+ 'tables': tables,
577
+ 'figures_count': len(figures),
578
+ 'tables_count': len(tables),
579
+ 'elements_count': len(elements),
580
+ 'annotated_pdf': annotated_pdf,
581
+ 'markdown_path': markdown_path,
582
+ 'figure_images': figure_images,
583
+ 'table_images': table_images,
584
+ 'urls': {
585
+ 'annotated_pdf': make_url(annotated_pdf),
586
+ 'markdown': make_url(markdown_path),
587
+ 'figures': [make_url(img) for img in figure_images],
588
+ 'tables': [make_url(img) for img in table_images],
589
+ }
590
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
591
 
592
 
593
+ @app.post("/api/predict", tags=["Legacy"], include_in_schema=True)
594
+ async def predict(
595
+ file: UploadFile = File(...),
596
+ request: Request = None
597
+ ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
598
  """
599
+ Direct API endpoint for extracting text/tables/figures from a single PDF.
600
+ Waits for completion and returns JSON result.
601
  """
602
+ if not file.filename.lower().endswith('.pdf'):
603
+ raise HTTPException(status_code=400, detail="Invalid file type. Please upload a PDF.")
604
+
605
+ # Create unique output directory
606
+ filename = secure_filename(file.filename)
607
+ stem = Path(filename).stem
608
+ unique_id = f"{stem}_{int(time.time())}"
609
+ output_dir = OUTPUT_FOLDER / unique_id
610
+ output_dir.mkdir(parents=True, exist_ok=True)
611
 
612
+ # Save file
613
+ pdf_path = output_dir / filename
614
+ content = await file.read()
615
+ pdf_path.write_bytes(content)
616
 
 
617
  try:
618
+ # Load model logic (sync call to stay simple for this endpoint)
619
+ load_model_once()
620
+ extractor.USE_MULTIPROCESSING = False
621
 
622
+ # Process
623
+ extractor.process_pdf_with_pool(
624
+ pdf_path,
625
+ output_dir,
626
+ pool=None,
627
+ extract_images=True,
628
+ extract_markdown=True,
629
+ )
630
 
631
+ # Build Result
632
+ base_url = str(request.base_url).rstrip('/')
633
+ if 'hf.space' in base_url or request.headers.get("x-forwarded-proto") == "https":
634
+ base_url = base_url.replace("http://", "https://")
635
+
636
+ def make_url(rel_path):
637
+ return f"{base_url}/output/{unique_id}/{rel_path}"
638
+
639
+ result = {
640
+ "status": "success",
641
+ "filename": filename,
642
+ "text": "",
643
+ "tables": [],
644
+ "figures": [],
645
+ "summary": {}
646
+ }
647
 
648
+ # Text
649
+ md_path = output_dir / f"{stem}.md"
650
+ if md_path.exists():
651
+ result['text'] = md_path.read_text(encoding='utf-8')
652
+
653
+ # JSON content
654
+ json_path = output_dir / f"{stem}_content_list.json"
655
+ if json_path.exists():
656
+ elements = json.loads(json_path.read_text(encoding='utf-8'))
657
+
658
+ figures = [e for e in elements if e.get('type') == 'figure']
659
+ result['figures'] = [{
660
+ **fig,
661
+ 'image_url': make_url(fig.get('image_path')) if fig.get('image_path') else None
662
+ } for fig in figures]
663
+
664
+ tables = [e for e in elements if e.get('type') == 'table']
665
+ result['tables'] = [{
666
+ **tab,
667
+ 'image_url': make_url(tab.get('image_path')) if tab.get('image_path') else None
668
+ } for tab in tables]
669
+
670
+ result['summary'] = {
671
+ 'figures_count': len(figures),
672
+ 'tables_count': len(tables),
673
+ 'elements_count': len(elements)
674
+ }
675
+
676
+ return result
677
 
678
  except Exception as e:
679
+ logger.error(f"Error in predict: {e}")
680
+ import traceback
681
+ logger.error(traceback.format_exc())
682
+ raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
683
 
684
+
685
+ @app.post("/api/delete", tags=["Processing"])
686
+ async def delete_pdf(stem: str = Form(...)):
687
+ """Delete a processed PDF and its output directory."""
688
+ if not stem:
689
+ raise HTTPException(status_code=400, detail="Missing stem")
690
+
691
+ # Resolve output directory safely
692
+ output_root = OUTPUT_FOLDER.resolve()
693
+ target_dir = (output_root / stem).resolve()
694
 
695
+ # Prevent path traversal
696
+ if output_root not in target_dir.parents and target_dir != output_root:
697
+ raise HTTPException(status_code=400, detail="Invalid stem path")
698
+
699
+ if not target_dir.exists() or not target_dir.is_dir():
700
+ raise HTTPException(status_code=404, detail="Not found")
701
+
702
  try:
703
+ shutil.rmtree(target_dir)
704
+ return {"status": "success", "message": f"Deleted {stem}"}
705
  except Exception as e:
706
+ # Try to fix read-only files (common on Windows)
707
+ try:
708
+ import stat
709
+ def on_rm_error(func, path, exc_info):
710
+ os.chmod(path, stat.S_IWRITE)
711
+ func(path)
712
+ shutil.rmtree(target_dir, onerror=on_rm_error)
713
+ return {"status": "success", "message": f"Deleted {stem}"}
714
+ except Exception as e2:
715
+ logger.error(f"Error deleting {stem}: {e2}")
716
+ raise HTTPException(status_code=500, detail=f"Failed to delete: {str(e2)}")
717
+
718
+
719
+ # --------------------------------------------------------------------------------
720
+ # Gradio Interface
721
+ # --------------------------------------------------------------------------------
722
+
723
+ def gradio_process(pdf_file, mode_str):
724
+ """
725
+ Wrapper for Gradio to call the extractor logic.
726
+ """
727
+ if pdf_file is None:
728
+ return None, None, None, "No file uploaded."
729
+
730
+ try:
731
+ # Create unique directory
732
+ filename = secure_filename(Path(pdf_file.name).name)
733
+ stem = Path(filename).stem
734
+ unique_id = f"{stem}_{int(time.time())}"
735
+ output_dir = OUTPUT_FOLDER / unique_id
736
+ output_dir.mkdir(parents=True, exist_ok=True)
737
+
738
+ # Copy file
739
+ dest_path = output_dir / filename
740
+ shutil.copy(pdf_file.name, dest_path)
741
+
742
+ # Determine flags
743
+ include_images = (mode_str != "markdown")
744
+ include_markdown = (mode_str != "images")
745
+
746
+ # Process using the multiprocessing pool for speed
747
+ # The global pool is already initialized in lifespan
748
+ extractor.USE_MULTIPROCESSING = False
749
+
750
+ extractor.process_pdf_with_pool(
751
+ dest_path,
752
+ output_dir,
753
+ pool=None, # Use the global pool instead of None
754
+ extract_images=include_images,
755
+ extract_markdown=include_markdown
756
+ )
757
+
758
+ # Collect outputs
759
+ md_text = ""
760
+ md_path = output_dir / f"{stem}.md"
761
+ if md_path.exists():
762
+ md_text = md_path.read_text(encoding='utf-8')
763
+
764
+ annotated_pdf = None
765
+ pdf_layout_path = output_dir / f"{stem}_layout.pdf"
766
+ if pdf_layout_path.exists():
767
+ annotated_pdf = str(pdf_layout_path)
768
+
769
+ gallery = []
770
+ if include_images:
771
+ fig_dir = output_dir / 'figures'
772
+ if fig_dir.exists():
773
+ gallery.extend([str(p) for p in fig_dir.glob('*.png')])
774
+ tab_dir = output_dir / 'tables'
775
+ if tab_dir.exists():
776
+ gallery.extend([str(p) for p in tab_dir.glob('*.png')])
777
+
778
+ return md_text, gallery, annotated_pdf, f"Processed {filename} successfully."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
779
 
780
+ except Exception as e:
781
+ logger.error(f"Gradio Error: {e}")
782
+ return str(e), None, None, f"Error: {e}"
783
 
784
+ # Define Gradio App
785
+ with gr.Blocks(title="PDF Layout Extractor") as demo:
786
+ gr.Markdown("# PDF Layout Extractor")
787
+ gr.Markdown("Upload a PDF to extract text (Markdown), figures, tables, and visualization.")
788
+
789
+ with gr.Row():
790
+ with gr.Column():
791
+ input_pdf = gr.File(label="Upload PDF", file_types=[".pdf"])
792
+ mode_input = gr.Radio(["both", "images", "markdown"], label="Extraction Mode", value="both")
793
+ process_btn = gr.Button("Extract Layout", variant="primary")
794
+
795
+ with gr.Column():
796
+ status_msg = gr.Textbox(label="Status", interactive=False)
797
+ output_md = gr.Code(label="Extracted Simple Markdown", language="markdown")
 
 
 
 
 
 
 
 
798
 
799
+ with gr.Row():
800
+ output_pdf = gr.File(label="Annotated PDF Layout")
801
+ output_gallery = gr.Gallery(label="Extracted Images (Figures/Tables)")
802
+
803
+ process_btn.click(
804
+ fn=gradio_process,
805
+ inputs=[input_pdf, mode_input],
806
+ outputs=[output_md, output_gallery, output_pdf, status_msg]
807
+ )
808
 
809
+ # Enable queueing for better stability and performance on Spaces
810
+ demo.queue(default_concurrency_limit=5)
811
 
 
 
812
 
813
+ # --------------------------------------------------------------------------------
814
+ # Integrate Gradio with FastAPI
815
+ # --------------------------------------------------------------------------------
816
+ # Mount Gradio at /gradio path (this ensures static files work correctly)
817
+ app = gr.mount_gradio_app(
818
+ app,
819
+ demo,
820
+ path="/gradio",
821
+ allowed_paths=["./output", "./uploads"],
822
+ ssr_mode=False
823
+ )
824
+
825
+ # Redirect root to Gradio interface
826
+ @app.get("/", response_class=HTMLResponse, include_in_schema=False)
827
+ async def root_redirect():
828
+ """Redirect to Gradio interface."""
829
+ return HTMLResponse('<meta http-equiv="refresh" content="0; url=/gradio/" />')
830
+
831
+ if __name__ == "__main__":
832
+ import uvicorn
833
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
834