import fitz
import numpy as np
import cv2
import torch
import torch.serialization


_original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """Default torch.load to weights_only=False.

    PyTorch 2.6 changed the default to weights_only=True, which breaks
    unpickling checkpoints that contain custom classes (e.g. Ultralytics
    YOLO weights). Using setdefault rather than a hard override still lets
    a caller pass weights_only=True explicitly. Only use with trusted files.
    """
    kwargs.setdefault("weights_only", False)
    return _original_torch_load(*args, **kwargs)


torch.load = patched_torch_load


import json
import argparse
import os
import re

import torch.nn as nn
from TorchCRF import CRF

from typing import List, Dict, Any, Optional, Union, Tuple
from ultralytics import YOLO
import glob
from PIL import Image

import sys
import io
import base64
import tempfile
import time
import shutil

import logging

logging.basicConfig(level=logging.WARNING)


# --- Configuration ---

WEIGHTS_PATH = 'best.pt'

# ASSUMPTION: this constant is referenced by the CLI below but was never
# defined in this file; './layoutlmv3_model' is a placeholder default.
# Point it at your actual LayoutLMv3 checkpoint directory.
DEFAULT_LAYOUTLMV3_MODEL_PATH = './layoutlmv3_model'

OCR_JSON_OUTPUT_DIR = './ocr_json_output_final'
FIGURE_EXTRACTION_DIR = './figure_extraction'
TEMP_IMAGE_DIR = './temp_pdf_images'

# Detection post-processing thresholds
CONF_THRESHOLD = 0.2
TARGET_CLASSES = ['figure', 'equation']
IOU_MERGE_THRESHOLD = 0.4
IOA_SUPPRESSION_THRESHOLD = 0.7
LINE_TOLERANCE = 15

# Running counters used when naming extracted figure/equation crops
GLOBAL_FIGURE_COUNT = 0
GLOBAL_EQUATION_COUNT = 0


class OCRCache:
    """Caches OCR results per page to avoid redundant Tesseract runs."""

    def __init__(self):
        self.cache = {}

    def get_key(self, pdf_path: str, page_num: int) -> str:
        return f"{pdf_path}:{page_num}"

    def has_ocr(self, pdf_path: str, page_num: int) -> bool:
        return self.get_key(pdf_path, page_num) in self.cache

    def get_ocr(self, pdf_path: str, page_num: int) -> Optional[list]:
        return self.cache.get(self.get_key(pdf_path, page_num))

    def set_ocr(self, pdf_path: str, page_num: int, ocr_data: list):
        self.cache[self.get_key(pdf_path, page_num)] = ocr_data

    def clear(self):
        self.cache.clear()


_ocr_cache = OCRCache()
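
# Typical usage (sketch, not executed here): consult the cache before running
# OCR on a page. `run_tesseract` is a hypothetical stand-in for whatever OCR
# call preprocess_and_ocr_page actually makes.
#
#     if not _ocr_cache.has_ocr(pdf_path, page_num):
#         _ocr_cache.set_ocr(pdf_path, page_num, run_tesseract(page_image))
#     words = _ocr_cache.get_ocr(pdf_path, page_num)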


def calculate_iou(box1, box2):
    """Intersection-over-Union of two (x1, y1, x2, y2) boxes."""
    x1_a, y1_a, x2_a, y2_a = box1
    x1_b, y1_b, x2_b, y2_b = box2
    x_left = max(x1_a, x1_b)
    y_top = max(y1_a, y1_b)
    x_right = min(x2_a, x2_b)
    y_bottom = min(y2_a, y2_b)
    intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)
    box_a_area = (x2_a - x1_a) * (y2_a - y1_a)
    box_b_area = (x2_b - x1_b) * (y2_b - y1_b)
    union_area = float(box_a_area + box_b_area - intersection_area)
    return intersection_area / union_area if union_area > 0 else 0


def calculate_ioa(box1, box2):
    """Intersection-over-Area: intersection normalized by box1's area."""
    x1_a, y1_a, x2_a, y2_a = box1
    x1_b, y1_b, x2_b, y2_b = box2
    x_left = max(x1_a, x1_b)
    y_top = max(y1_a, y1_b)
    x_right = min(x2_a, x2_b)
    y_bottom = min(y2_a, y2_b)
    intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)
    box_a_area = (x2_a - x1_a) * (y2_a - y1_a)
    return intersection_area / box_a_area if box_a_area > 0 else 0
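
# Worked example (illustrative numbers): for box_a = (0, 0, 10, 10) and
# box_b = (5, 5, 15, 15), the intersection is 5 * 5 = 25 and the union is
# 100 + 100 - 25 = 175, so calculate_iou(box_a, box_b) ~ 0.143, while
# calculate_ioa(box_a, box_b) = 25 / 100 = 0.25 (normalized by box_a alone,
# which is why IoA suits nested-box suppression better than IoU).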


def filter_nested_boxes(detections, ioa_threshold=0.80):
    """
    Removes boxes that are inside larger boxes (containment check).
    Prioritizes keeping the LARGEST box (the 'parent' container).
    """
    if not detections:
        return []

    # Pre-compute each box's area
    for d in detections:
        x1, y1, x2, y2 = d['coords']
        d['area'] = (x2 - x1) * (y2 - y1)

    # Largest first, so parent containers are visited before their children
    detections.sort(key=lambda x: x['area'], reverse=True)

    keep_indices = []
    is_suppressed = [False] * len(detections)

    for i in range(len(detections)):
        if is_suppressed[i]: continue

        keep_indices.append(i)
        box_a = detections[i]['coords']

        # Suppress every smaller box that is mostly contained in box_a
        for j in range(i + 1, len(detections)):
            if is_suppressed[j]: continue

            box_b = detections[j]['coords']

            x_left = max(box_a[0], box_b[0])
            y_top = max(box_a[1], box_b[1])
            x_right = min(box_a[2], box_b[2])
            y_bottom = min(box_a[3], box_b[3])

            if x_right < x_left or y_bottom < y_top:
                intersection = 0
            else:
                intersection = (x_right - x_left) * (y_bottom - y_top)

            area_b = detections[j]['area']

            if area_b > 0:
                # Fraction of the smaller box that lies inside box_a
                ioa_small = intersection / area_b

                if ioa_small > ioa_threshold:
                    is_suppressed[j] = True

    return [detections[i] for i in keep_indices]


def merge_overlapping_boxes(detections, iou_threshold):
    """Greedy merge: the highest-confidence box of each class absorbs any
    same-class box whose IoU with it exceeds the threshold. Note that IoU
    is always measured against the original anchor box, not the growing
    union."""
    if not detections: return []
    detections.sort(key=lambda d: d['conf'], reverse=True)
    merged_detections = []
    is_merged = [False] * len(detections)
    for i in range(len(detections)):
        if is_merged[i]: continue
        current_box = detections[i]['coords']
        current_class = detections[i]['class']
        merged_x1, merged_y1, merged_x2, merged_y2 = current_box
        for j in range(i + 1, len(detections)):
            if is_merged[j] or detections[j]['class'] != current_class: continue
            other_box = detections[j]['coords']
            iou = calculate_iou(current_box, other_box)
            if iou > iou_threshold:
                merged_x1 = min(merged_x1, other_box[0])
                merged_y1 = min(merged_y1, other_box[1])
                merged_x2 = max(merged_x2, other_box[2])
                merged_y2 = max(merged_y2, other_box[3])
                is_merged[j] = True
        merged_detections.append({
            'coords': (merged_x1, merged_y1, merged_x2, merged_y2),
            'y1': merged_y1, 'class': current_class, 'conf': detections[i]['conf']
        })
    return merged_detections
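

# Minimal sketch (not part of the original pipeline) showing how the two
# post-processing helpers compose: near-duplicate detections are merged
# first, then boxes nested inside larger ones are suppressed.
def _demo_box_postprocessing():
    detections = [
        {'coords': (10, 10, 110, 60), 'y1': 10, 'class': 'figure', 'conf': 0.9},  # anchor
        {'coords': (12, 12, 112, 62), 'y1': 12, 'class': 'figure', 'conf': 0.5},  # near-duplicate (IoU ~ 0.89)
        {'coords': (30, 20, 60, 40), 'y1': 20, 'class': 'figure', 'conf': 0.8},   # fully inside the anchor
    ]
    merged = merge_overlapping_boxes(detections, IOU_MERGE_THRESHOLD)  # -> 2 boxes
    kept = filter_nested_boxes(merged, IOA_SUPPRESSION_THRESHOLD)      # -> 1 box
    return kept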


def merge_yolo_into_word_data(raw_word_data: list, yolo_detections: list, scale_factor: float) -> list:
    """
    Filters out raw words that are inside YOLO boxes and replaces them with
    a single solid 'placeholder' block for the column detector.
    """
    if not yolo_detections:
        return raw_word_data

    # Map YOLO boxes from rendered-image pixels back into PDF coordinates
    pdf_space_boxes = []
    for det in yolo_detections:
        x1, y1, x2, y2 = det['coords']
        pdf_box = (
            x1 / scale_factor,
            y1 / scale_factor,
            x2 / scale_factor,
            y2 / scale_factor
        )
        pdf_space_boxes.append(pdf_box)

    # Drop any word whose center falls inside a YOLO box
    cleaned_word_data = []
    for word_tuple in raw_word_data:
        wx1, wy1, wx2, wy2 = word_tuple[1], word_tuple[2], word_tuple[3], word_tuple[4]
        w_center_x = (wx1 + wx2) / 2
        w_center_y = (wy1 + wy2) / 2

        is_inside_yolo = False
        for px1, py1, px2, py2 in pdf_space_boxes:
            if px1 <= w_center_x <= px2 and py1 <= w_center_y <= py2:
                is_inside_yolo = True
                break

        if not is_inside_yolo:
            cleaned_word_data.append(word_tuple)

    # Add one solid placeholder entry per YOLO box
    for i, (px1, py1, px2, py2) in enumerate(pdf_space_boxes):
        dummy_entry = (f"BLOCK_{i}", px1, py1, px2, py2)
        cleaned_word_data.append(dummy_entry)

    return cleaned_word_data
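
# Example (illustrative): with the 2.0x render matrix used below, a YOLO box
# detected at pixel coords (200, 100, 400, 300) maps back to PDF space as
# (100, 50, 200, 150). Every word whose center lies in that rectangle is
# removed and replaced by a single ("BLOCK_0", 100, 50, 200, 150) tuple, so
# the downstream column detector sees the figure as one opaque block.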


def run_single_pdf_preprocessing(pdf_path: str, preprocessed_json_path: str) -> Optional[str]:
    """Render each PDF page, run YOLO + OCR on it, and write the combined
    per-page results to preprocessed_json_path. Returns that path on
    success, or None on failure."""
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT

    GLOBAL_FIGURE_COUNT = 0
    GLOBAL_EQUATION_COUNT = 0
    _ocr_cache.clear()

    print("\n" + "=" * 80)
    print("--- 1. STARTING OPTIMIZED YOLO/OCR PREPROCESSING PIPELINE ---")
    print("=" * 80)

    if not os.path.exists(pdf_path):
        print(f"❌ FATAL ERROR: Input PDF not found at {pdf_path}.")
        return None

    # dirname may be '' when a bare filename is given
    os.makedirs(os.path.dirname(preprocessed_json_path) or '.', exist_ok=True)
    os.makedirs(FIGURE_EXTRACTION_DIR, exist_ok=True)

    model = YOLO(WEIGHTS_PATH)
    pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]

    try:
        doc = fitz.open(pdf_path)
        print(f"✅ Opened PDF: {pdf_name} ({doc.page_count} pages)")
    except Exception as e:
        print(f"❌ ERROR loading PDF file: {e}")
        return None

    all_pages_data = []
    total_pages_processed = 0
    mat = fitz.Matrix(2.0, 2.0)  # render at 2x for sharper YOLO/OCR input

    print("\n[STEP 1.2: ITERATING PAGES - IN-MEMORY PROCESSING]")

    for page_num_0_based in range(doc.page_count):
        page_num = page_num_0_based + 1
        print(f"  -> Processing Page {page_num}/{doc.page_count}...")

        fitz_page = doc.load_page(page_num_0_based)

        try:
            pix = fitz_page.get_pixmap(matrix=mat)
            original_img = pixmap_to_numpy(pix)
        except Exception as e:
            print(f"  ❌ Error converting page {page_num} to image: {e}")
            continue

        final_output, page_separator_x = preprocess_and_ocr_page(
            original_img,
            model,
            pdf_path,
            page_num,
            fitz_page,
            pdf_name
        )

        if final_output is not None:
            page_data = {
                "page_number": page_num,
                "data": final_output,
                "column_separator_x": page_separator_x
            }
            all_pages_data.append(page_data)
            total_pages_processed += 1
        else:
            print(f"  ❌ Skipped page {page_num} due to processing error.")

    doc.close()

    if all_pages_data:
        try:
            with open(preprocessed_json_path, 'w') as f:
                json.dump(all_pages_data, f, indent=4)
            print(f"\n  ✅ Combined structured OCR JSON saved to: {os.path.basename(preprocessed_json_path)}")
        except Exception as e:
            print(f"❌ ERROR saving combined JSON output: {e}")
            return None
    else:
        print("❌ WARNING: No page data generated. Halting pipeline.")
        return None

    print("\n" + "=" * 80)
    print(f"--- YOLO/OCR PREPROCESSING COMPLETE ({total_pages_processed} pages processed) ---")
    print("=" * 80)

    return preprocessed_json_path
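

# Example invocation (sketch): preprocess one PDF into the combined JSON.
#
#     json_path = run_single_pdf_preprocessing(
#         'paper.pdf',
#         os.path.join(OCR_JSON_OUTPUT_DIR, 'paper_preprocessed.json'))
#     if json_path is None:
#         sys.exit(1)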


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Complete Pipeline")
    parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
    parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path")
    parser.add_argument("--raw_preds_path", type=str, default='BIO_debug.json',
                        help="Debug path for raw BIO tag predictions (JSON).")
    args = parser.parse_args()

    pdf_name = os.path.splitext(os.path.basename(args.input_pdf))[0]
    final_output_path = os.path.abspath(f"{pdf_name}_final_output_embedded.json")

    final_json_data = run_document_pipeline(
        args.input_pdf,
        args.layoutlmv3_model_path)

    if final_json_data:
        json_str = json.dumps(final_json_data, indent=2, ensure_ascii=False)

        with open(final_output_path, 'w', encoding='utf-8') as f:
            f.write(json_str)

        print(f"\n✅ Final Data Saved: {final_output_path}")
    else:
        print("\n❌ Pipeline Failed.")
        sys.exit(1)