# Source: Hugging Face file "feedback / bababa.py"
# (page metadata: uploaded by heerjtdev, commit "Update bababa.py",
#  revision ab8d2bc (verified), raw size 13.6 kB — this header is page
#  chrome from the hosting site, not part of the program)
import fitz # PyMuPDF
import numpy as np
import cv2
import torch
import torch.serialization
# Keep a handle on the pristine loader so the wrapper can delegate to it.
_original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """torch.load wrapper that always forces weights_only=False.

    Restores the "classic" full-pickle loading behavior for legacy
    checkpoints, overriding whatever the caller passed for weights_only.

    WARNING(security): weights_only=False lets pickle execute arbitrary
    code during load — only load checkpoint files from trusted sources.
    """
    forced = dict(kwargs)
    forced["weights_only"] = False
    return _original_torch_load(*args, **forced)


# Monkey-patch the global loader so every later torch.load call goes
# through the wrapper.
torch.load = patched_torch_load
import json
import argparse
import os
import re
import torch.nn as nn
from TorchCRF import CRF
# from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model, LayoutLMv3Config
from typing import List, Dict, Any, Optional, Union, Tuple
from ultralytics import YOLO
import glob
from PIL import Image
import sys
import io
import base64
import tempfile
import time
import shutil
import logging
# ============================================================================
# --- LOGGING SETUP ---
# (Original banner said "TR-OCR/ORT MODEL INITIALIZATION", but only logging
# configuration lives in this section.)
# ============================================================================
logging.basicConfig(level=logging.WARNING)
# ============================================================================
# --- CONFIGURATION AND CONSTANTS ---
# ============================================================================
# NOTE: Update these paths to match your environment before running!
WEIGHTS_PATH = 'best.pt'  # YOLO checkpoint loaded in run_single_pdf_preprocessing
# Directory configuration
OCR_JSON_OUTPUT_DIR = './ocr_json_output_final'  # NOTE(review): not referenced in this chunk — confirm use
FIGURE_EXTRACTION_DIR = './figure_extraction'  # created by run_single_pdf_preprocessing
TEMP_IMAGE_DIR = './temp_pdf_images'  # NOTE(review): not referenced in this chunk — confirm use
# Detection parameters (consumed by helpers not shown in this chunk,
# presumably preprocess_and_ocr_page — verify against the full file)
CONF_THRESHOLD = 0.2  # minimum YOLO detection confidence
TARGET_CLASSES = ['figure', 'equation']  # detection classes the pipeline extracts
IOU_MERGE_THRESHOLD = 0.4  # presumably the iou_threshold for merge_overlapping_boxes — confirm
IOA_SUPPRESSION_THRESHOLD = 0.7  # presumably fed to filter_nested_boxes — confirm
LINE_TOLERANCE = 15  # presumably a vertical tolerance for grouping words into lines — confirm
# Global counters for sequential figure/equation numbering across the entire
# PDF; reset at the start of run_single_pdf_preprocessing.
GLOBAL_FIGURE_COUNT = 0
GLOBAL_EQUATION_COUNT = 0
# ============================================================================
# --- PERFORMANCE OPTIMIZATION: OCR CACHE ---
# ============================================================================
class OCRCache:
    """Per-page cache of OCR results, keyed by "pdf_path:page_num".

    Lets the pipeline skip redundant Tesseract runs on pages it has
    already processed.
    """

    def __init__(self):
        # Maps "pdf_path:page_num" -> cached OCR word list.
        self.cache = {}

    def get_key(self, pdf_path: str, page_num: int) -> str:
        """Build the cache key for one page of one PDF."""
        return f"{pdf_path}:{page_num}"

    def has_ocr(self, pdf_path: str, page_num: int) -> bool:
        """Return True when OCR data for the page is already cached."""
        key = self.get_key(pdf_path, page_num)
        return key in self.cache

    def get_ocr(self, pdf_path: str, page_num: int) -> Optional[list]:
        """Return cached OCR data for the page, or None if absent."""
        key = self.get_key(pdf_path, page_num)
        return self.cache.get(key)

    def set_ocr(self, pdf_path: str, page_num: int, ocr_data: list):
        """Store OCR data for the page, replacing any previous entry."""
        key = self.get_key(pdf_path, page_num)
        self.cache[key] = ocr_data

    def clear(self):
        """Drop every cached page."""
        self.cache.clear()
# Module-wide OCR cache singleton, cleared at the start of each PDF run.
_ocr_cache = OCRCache()
# ============================================================================
# --- PHASE 1: YOLO/OCR PREPROCESSING FUNCTIONS ---
# ============================================================================
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two (x1, y1, x2, y2) boxes.

    Both boxes are axis-aligned; the result is 0 when the union area is
    not positive (e.g. two degenerate boxes).
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2
    # Overlap rectangle, clamped to zero extent when the boxes are disjoint.
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    intersection = overlap_w * overlap_h
    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = float(area_a + area_b - intersection)
    if union > 0:
        return intersection / union
    return 0
def calculate_ioa(box1, box2):
    """Return the intersection area divided by the area of box1.

    "Intersection over area" is asymmetric: it measures how much of box1
    lies inside box2.  Returns 0 when box1 has no positive area.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    intersection = overlap_w * overlap_h
    base_area = (ax2 - ax1) * (ay2 - ay1)
    if base_area > 0:
        return intersection / base_area
    return 0
def filter_nested_boxes(detections, ioa_threshold=0.80):
    """
    Remove detections that are (almost) fully contained inside a larger one.

    The largest box in each containment group is kept as the 'parent'
    container; any smaller box whose intersection-over-own-area with a kept
    box exceeds ``ioa_threshold`` is suppressed.

    Fixes over the previous version: the caller's list is no longer
    reordered in place (it used to be sorted as a side effect), and the
    in-code comment claiming a "> 90%" containment test contradicted the
    0.80 default threshold.  Each detection dict still gains an ``'area'``
    key, since the returned dicts carry it.

    Args:
        detections: list of dicts each holding a ``'coords'`` (x1, y1, x2, y2) tuple.
        ioa_threshold: containment fraction above which the smaller box is dropped.

    Returns:
        New list of surviving detection dicts, ordered largest-area first.
    """
    if not detections:
        return []
    # 1. Annotate each detection with its area ('area' stays on the dict —
    #    it is part of the returned records).
    for det in detections:
        x1, y1, x2, y2 = det['coords']
        det['area'] = (x2 - x1) * (y2 - y1)
    # 2. Iterate a sorted *copy* (largest to smallest) so containers are
    #    visited before their contents, without mutating the caller's list.
    ordered = sorted(detections, key=lambda d: d['area'], reverse=True)
    kept = []
    suppressed = [False] * len(ordered)
    for i, parent in enumerate(ordered):
        if suppressed[i]:
            continue
        kept.append(parent)
        ax1, ay1, ax2, ay2 = parent['coords']
        # Compare against every smaller, still-alive box.
        for j in range(i + 1, len(ordered)):
            if suppressed[j]:
                continue
            bx1, by1, bx2, by2 = ordered[j]['coords']
            overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
            overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
            intersection = overlap_w * overlap_h
            area_b = ordered[j]['area']
            # Suppress the smaller box when more than ioa_threshold of it
            # sits inside the larger one.
            if area_b > 0 and intersection / area_b > ioa_threshold:
                suppressed[j] = True
    return kept
def merge_overlapping_boxes(detections, iou_threshold):
    """
    Merge same-class detections that overlap a higher-confidence seed box.

    Detections are visited in descending confidence order; each unconsumed
    detection becomes a seed, and every later same-class box whose IoU with
    the *seed's original* coordinates (deliberately not the growing merged
    rectangle — matching the original single-pass semantics) exceeds
    ``iou_threshold`` is folded into one bounding rectangle.  The merged
    entry keeps the seed's class and confidence.

    Fix over the previous version: the caller's list is no longer sorted in
    place as a side effect.

    Args:
        detections: list of dicts with 'coords', 'class' and 'conf' keys.
        iou_threshold: minimum IoU for two boxes to be merged.

    Returns:
        List of merged dicts with keys 'coords', 'y1', 'class', 'conf'.
    """
    if not detections:
        return []
    # Sort a copy by confidence so the best box seeds each merge group,
    # leaving the caller's list order untouched.
    ordered = sorted(detections, key=lambda d: d['conf'], reverse=True)
    merged_results = []
    consumed = [False] * len(ordered)
    for i, seed in enumerate(ordered):
        if consumed[i]:
            continue
        seed_box = seed['coords']
        seed_class = seed['class']
        mx1, my1, mx2, my2 = seed_box
        for j in range(i + 1, len(ordered)):
            if consumed[j] or ordered[j]['class'] != seed_class:
                continue
            other = ordered[j]['coords']
            if calculate_iou(seed_box, other) > iou_threshold:
                # Grow the merged rectangle to enclose the absorbed box.
                mx1 = min(mx1, other[0])
                my1 = min(my1, other[1])
                mx2 = max(mx2, other[2])
                my2 = max(my2, other[3])
                consumed[j] = True
        merged_results.append({
            'coords': (mx1, my1, mx2, my2),
            'y1': my1, 'class': seed_class, 'conf': seed['conf']
        })
    return merged_results
def merge_yolo_into_word_data(raw_word_data: list, yolo_detections: list, scale_factor: float) -> list:
    """
    Replace words covered by YOLO boxes with one solid placeholder per box.

    YOLO boxes arrive in rendered-pixel space and are converted to PDF point
    space by dividing through ``scale_factor``.  Any word whose center falls
    inside a converted box is dropped, and each box is appended as a
    pseudo-word named ``BLOCK_<i>`` so the downstream column detector treats
    the region as one opaque block.

    Args:
        raw_word_data: tuples of (text, x1, y1, x2, y2, ...) in PDF points.
        yolo_detections: dicts with a 'coords' (x1, y1, x2, y2) key in pixels.
        scale_factor: pixels-per-point ratio used when the page was rendered.

    Returns:
        New word list (surviving words first, then the placeholder blocks);
        the input list is returned unchanged when there are no detections.
    """
    if not yolo_detections:
        return raw_word_data
    # 1. Pixel -> PDF point conversion for every detection box.
    pdf_boxes = [
        tuple(coord / scale_factor for coord in det['coords'])
        for det in yolo_detections
    ]

    def covered(cx, cy):
        # True when the point lies inside any converted YOLO box.
        return any(
            bx1 <= cx <= bx2 and by1 <= cy <= by2
            for bx1, by1, bx2, by2 in pdf_boxes
        )

    # 2. Keep only words whose center is outside every YOLO box.
    result = []
    for word in raw_word_data:
        wx1, wy1, wx2, wy2 = word[1], word[2], word[3], word[4]
        if not covered((wx1 + wx2) / 2, (wy1 + wy2) / 2):
            result.append(word)
    # 3. Append each YOLO box itself as an opaque pseudo-word.
    for idx, (bx1, by1, bx2, by2) in enumerate(pdf_boxes):
        result.append((f"BLOCK_{idx}", bx1, by1, bx2, by2))
    return result
# ============================================================================
# --- PHASE 1 DRIVER ---
# NOTE: pixmap_to_numpy and preprocess_and_ocr_page are referenced below but
# are NOT defined anywhere in this file.
# ============================================================================
def run_single_pdf_preprocessing(pdf_path: str, preprocessed_json_path: str) -> Optional[str]:
    """Render each page of a PDF, run YOLO + OCR on it, and save one combined JSON.

    Resets the global figure/equation counters and the OCR cache, rasterizes
    every page at 2x zoom, delegates per-page detection/OCR to
    ``preprocess_and_ocr_page``, and dumps the accumulated per-page records to
    ``preprocessed_json_path``.

    NOTE(review): ``pixmap_to_numpy`` and ``preprocess_and_ocr_page`` are not
    defined in this file — unless provided elsewhere, this function raises
    NameError at runtime. Confirm.

    Args:
        pdf_path: path to the input PDF file.
        preprocessed_json_path: destination path for the combined per-page JSON.

    Returns:
        ``preprocessed_json_path`` on success, or None on failure (missing
        input, unreadable PDF, JSON write error, or zero processed pages).
    """
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT
    # Fresh run: restart figure/equation numbering and drop stale OCR results.
    GLOBAL_FIGURE_COUNT = 0
    GLOBAL_EQUATION_COUNT = 0
    _ocr_cache.clear()
    print("\n" + "=" * 80)
    print("--- 1. STARTING OPTIMIZED YOLO/OCR PREPROCESSING PIPELINE ---")
    print("=" * 80)
    if not os.path.exists(pdf_path):
        print(f"❌ FATAL ERROR: Input PDF not found at {pdf_path}.")
        return None
    # NOTE(review): os.makedirs(os.path.dirname(...)) raises FileNotFoundError
    # when preprocessed_json_path has no directory component (dirname == '').
    os.makedirs(os.path.dirname(preprocessed_json_path), exist_ok=True)
    os.makedirs(FIGURE_EXTRACTION_DIR, exist_ok=True)
    model = YOLO(WEIGHTS_PATH)
    pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
    try:
        doc = fitz.open(pdf_path)
        print(f"βœ… Opened PDF: {pdf_name} ({doc.page_count} pages)")
    except Exception as e:
        print(f"❌ ERROR loading PDF file: {e}")
        return None
    all_pages_data = []
    total_pages_processed = 0
    # 2x zoom matrix: pages are rasterized at double resolution before detection.
    mat = fitz.Matrix(2.0, 2.0)
    print("\n[STEP 1.2: ITERATING PAGES - IN-MEMORY PROCESSING]")
    for page_num_0_based in range(doc.page_count):
        page_num = page_num_0_based + 1  # human-readable 1-based page number
        print(f" -> Processing Page {page_num}/{doc.page_count}...")
        fitz_page = doc.load_page(page_num_0_based)
        try:
            pix = fitz_page.get_pixmap(matrix=mat)
            original_img = pixmap_to_numpy(pix)  # NOTE(review): undefined in this file
        except Exception as e:
            print(f" ❌ Error converting page {page_num} to image: {e}")
            continue
        # Per-page YOLO detection + OCR; also returns the detected column split.
        final_output, page_separator_x = preprocess_and_ocr_page(
            original_img,
            model,
            pdf_path,
            page_num,
            fitz_page,
            pdf_name
        )
        if final_output is not None:
            page_data = {
                "page_number": page_num,
                "data": final_output,
                "column_separator_x": page_separator_x
            }
            all_pages_data.append(page_data)
            total_pages_processed += 1
        else:
            print(f" ❌ Skipped page {page_num} due to processing error.")
    doc.close()
    if all_pages_data:
        try:
            with open(preprocessed_json_path, 'w') as f:
                json.dump(all_pages_data, f, indent=4)
            print(f"\n βœ… Combined structured OCR JSON saved to: {os.path.basename(preprocessed_json_path)}")
        except Exception as e:
            print(f"❌ ERROR saving combined JSON output: {e}")
            return None
    else:
        print("❌ WARNING: No page data generated. Halting pipeline.")
        return None
    print("\n" + "=" * 80)
    print(f"--- YOLO/OCR PREPROCESSING COMPLETE ({total_pages_processed} pages processed) ---")
    print("=" * 80)
    return preprocessed_json_path
if __name__ == "__main__":
    # CLI entry point: parse arguments, run the full document pipeline, and
    # write the final JSON next to the current working directory.
    parser = argparse.ArgumentParser(description="Complete Pipeline")
    parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
    # NOTE(review): DEFAULT_LAYOUTLMV3_MODEL_PATH is not defined in this file;
    # evaluating this default raises NameError before parsing even starts.
    parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path")
    # Debug-only output path for raw BIO tag predictions.
    parser.add_argument("--raw_preds_path", type=str, default='BIO_debug.json',
    help="Debug path for raw BIO tag predictions (JSON).")
    args = parser.parse_args()
    pdf_name = os.path.splitext(os.path.basename(args.input_pdf))[0]
    final_output_path = os.path.abspath(f"{pdf_name}_final_output_embedded.json")
    # Raw-predictions output path calculation deliberately left disabled:
    # raw_predictions_output_path = os.path.abspath(
    # args.raw_preds_path if args.raw_preds_path else f"{pdf_name}_raw_predictions_debug.json")
    # NOTE(review): run_document_pipeline is not defined in this file — this
    # call raises NameError unless the function is supplied elsewhere. Confirm.
    final_json_data = run_document_pipeline(
    args.input_pdf,
    args.layoutlmv3_model_path)
    # Manual JSON serialization so backslash escaping (LaTeX commands inside
    # extracted text) can be post-processed before writing, if ever needed.
    if final_json_data:
        # json.dumps turns each in-memory '\\' into '\\\\' inside json_str.
        json_str = json.dumps(final_json_data, indent=2, ensure_ascii=False)
        # An aggressive un-escaping pass (collapse four literal backslashes
        # back to two) was considered and left disabled:
        # final_output_content = json_str.replace('\\\\\\\\', '\\\\')
        # Write the serialized content verbatim.
        with open(final_output_path, 'w', encoding='utf-8') as f:
            f.write(json_str)
        print(f"\nβœ… Final Data Saved: {final_output_path}")
    else:
        print("\n❌ Pipeline Failed.")
        sys.exit(1)