diff --git "a/bababa.py" "b/bababa.py" new file mode 100644--- /dev/null +++ "b/bababa.py" @@ -0,0 +1,2802 @@ +import fitz # PyMuPDF +import numpy as np +import cv2 +import torch +import torch.serialization + +_original_torch_load = torch.load + + +def patched_torch_load(*args, **kwargs): + # FORCE classic behavior + kwargs["weights_only"] = False + return _original_torch_load(*args, **kwargs) + + +torch.load = patched_torch_load + +import json +import argparse +import os +import re + +import torch.nn as nn +from TorchCRF import CRF +# from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model, LayoutLMv3Config +from transformers import LayoutLMv3Tokenizer, LayoutLMv3Model, LayoutLMv3Config +from typing import List, Dict, Any, Optional, Union, Tuple +from ultralytics import YOLO +import glob +import pytesseract +from PIL import Image +from scipy.signal import find_peaks +from scipy.ndimage import gaussian_filter1d +import sys +import io +import base64 +import tempfile +import time +import shutil +from sklearn.feature_extraction.text import CountVectorizer +from sklearn.metrics.pairwise import cosine_similarity +import logging +from transformers import TrOCRProcessor +from optimum.onnxruntime import ORTModelForVision2Seq + +# ============================================================================ +# --- TR-OCR/ORT MODEL INITIALIZATION --- +# ============================================================================ + +logging.basicConfig(level=logging.WARNING) + +processor = None +ort_model = None + +try: + MODEL_NAME = 'breezedeus/pix2text-mfr-1.5' + processor = TrOCRProcessor.from_pretrained(MODEL_NAME) + + # Initialize the model for ONNX Runtime + # NOTE: Set use_cache=False to avoid caching warnings/issues if reloading + ort_model = ORTModelForVision2Seq.from_pretrained(MODEL_NAME, use_cache=False) + + print("✅ ORTModelForVision2Seq and TrOCRProcessor initialized successfully for equation conversion.") +except Exception as e: + print(f"❌ Error initializing TrOCR/ORT model. Equations will not be converted: {e}") + processor = None + ort_model = None + +from typing import Optional + + +def sanitize_text(text: Optional[str]) -> str: + """Removes surrogate characters and other invalid code points that cause UTF-8 encoding errors.""" + if not isinstance(text, str) or text is None: + return "" + + # Matches all surrogates (\ud800-\udfff) and common non-characters (\ufffe, \uffff). + # This specifically removes '\udefd' which is causing your error. + surrogates_and_nonchars = re.compile(r'[\ud800-\udfff\ufffe\uffff]') + + # Replace the invalid characters with a standard space. + # We strip afterward in the calling function. + return surrogates_and_nonchars.sub(' ', text) + + +def get_latex_from_base64(base64_string: str) -> str: + """ + Decodes a Base64 image string and uses the pre-initialized TrOCR/ORT model + to recognize the formula. It cleans the output by removing spaces and + crucially, replacing double backslashes with single backslashes for correct LaTeX. + """ + if ort_model is None or processor is None: + return "[MODEL_ERROR: Model not initialized]" + + try: + # 1. Decode Base64 to Image + image_data = base64.b64decode(base64_string) + # We must ensure the image is RGB format for the model input + image = Image.open(io.BytesIO(image_data)).convert('RGB') + + # 2. Preprocess the image + pixel_values = processor(images=image, return_tensors="pt").pixel_values + + # 3. Text Generation (OCR) + generated_ids = ort_model.generate(pixel_values) + raw_generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True) + + if not raw_generated_text: + return "[OCR_WARNING: No formula found]" + + latex_string = raw_generated_text[0] + + # --- 4. Post-processing and Cleanup --- + + # # A. Remove all spaces/line breaks + # cleaned_latex = re.sub(r'\s+', '', latex_string) + cleaned_latex = re.sub(r'[\r\n]+', '', latex_string) + + # B. CRITICAL FIX: Replace double backslashes (\\) with single backslashes (\). + # This corrects model output that already over-escaped the LaTeX commands. + # Python literal: '\\\\' is replaced with '\\'. + # cleaned_latex = cleaned_latex.replace('\\\\', '\\') + + return cleaned_latex + + + except Exception as e: + # Catch any unexpected errors + print(f" ❌ TR-OCR Recognition failed: {e}") + return f"[TR_OCR_ERROR: Recognition failed: {e}]" + + +# def get_latex_from_base64(base64_string: str) -> str: +# """ +# Decodes a Base64 image string and uses the pre-initialized TrOCR/ORT model +# to recognize the formula. It cleans the output by removing spaces and +# crucially, replacing double backslashes with single backslashes for correct LaTeX. +# """ +# if ort_model is None or processor is None: +# return "[MODEL_ERROR: Model not initialized]" + +# try: +# # 1. Decode Base64 to Image +# image_data = base64.b64decode(base64_string) +# # We must ensure the image is RGB format for the model input +# image = Image.open(io.BytesIO(image_data)).convert('RGB') + +# # 2. Preprocess the image +# pixel_values = processor(images=image, return_tensors="pt").pixel_values + +# # 3. Text Generation (OCR) +# generated_ids = ort_model.generate(pixel_values) +# raw_generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True) + +# if not raw_generated_text: +# return "[OCR_WARNING: No formula found]" + +# latex_string = raw_generated_text[0] + +# # ============================================================================== +# # --- DEBUGGING BLOCK: CHECK TrOCR RAW OUTPUT --- +# # ============================================================================== +# print(f"[DEBUG] TrOCR Raw Output: '{latex_string}'") +# # ============================================================================== + +# # --- 4. Post-processing and Cleanup --- + +# # # A. Remove all spaces/line breaks +# # cleaned_latex = re.sub(r'\s+', '', latex_string) +# cleaned_latex = re.sub(r'[\r\n]+', '', latex_string) + +# # B. CRITICAL FIX: Replace double backslashes (\\) with single backslashes (\). +# # This corrects model output that already over-escaped the LaTeX commands. +# # Python literal: '\\\\' is replaced with '\\'. +# #cleaned_latex = cleaned_latex.replace('\\\\', '\\') + +# return cleaned_latex + + +# except Exception as e: +# # Catch any unexpected errors +# print(f" ❌ TR-OCR Recognition failed: {e}") +# return f"[TR_OCR_ERROR: Recognition failed: {e}]" + + +# ============================================================================ +# --- CONFIGURATION AND CONSTANTS --- +# ============================================================================ + + +# NOTE: Update these paths to match your environment before running! +WEIGHTS_PATH = 'best.pt' +DEFAULT_LAYOUTLMV3_MODEL_PATH = "98.pth" + +# DIRECTORY CONFIGURATION +OCR_JSON_OUTPUT_DIR = './ocr_json_output_final' +FIGURE_EXTRACTION_DIR = './figure_extraction' +TEMP_IMAGE_DIR = './temp_pdf_images' + +# Detection parameters +CONF_THRESHOLD = 0.2 +TARGET_CLASSES = ['figure', 'equation'] +IOU_MERGE_THRESHOLD = 0.4 +IOA_SUPPRESSION_THRESHOLD = 0.7 +LINE_TOLERANCE = 15 + +# Similarity +SIMILARITY_THRESHOLD = 0.10 +RESOLUTION_MARGIN = 0.05 + +# Global counters for sequential numbering across the entire PDF +GLOBAL_FIGURE_COUNT = 0 +GLOBAL_EQUATION_COUNT = 0 + +# LayoutLMv3 Labels +ID_TO_LABEL = { + 0: "O", + 1: "B-QUESTION", 2: "I-QUESTION", + 3: "B-OPTION", 4: "I-OPTION", + 5: "B-ANSWER", 6: "I-ANSWER", + 7: "B-SECTION_HEADING", 8: "I-SECTION_HEADING", + 9: "B-PASSAGE", 10: "I-PASSAGE" +} +NUM_LABELS = len(ID_TO_LABEL) + + +# ============================================================================ +# --- PERFORMANCE OPTIMIZATION: OCR CACHE --- +# ============================================================================ + +class OCRCache: + """Caches OCR results per page to avoid redundant Tesseract runs.""" + + def __init__(self): + self.cache = {} + + def get_key(self, pdf_path: str, page_num: int) -> str: + return f"{pdf_path}:{page_num}" + + def has_ocr(self, pdf_path: str, page_num: int) -> bool: + return self.get_key(pdf_path, page_num) in self.cache + + def get_ocr(self, pdf_path: str, page_num: int) -> Optional[list]: + return self.cache.get(self.get_key(pdf_path, page_num)) + + def set_ocr(self, pdf_path: str, page_num: int, ocr_data: list): + self.cache[self.get_key(pdf_path, page_num)] = ocr_data + + def clear(self): + self.cache.clear() + + +# Global OCR cache instance +_ocr_cache = OCRCache() + + +# ============================================================================ +# --- PHASE 1: YOLO/OCR PREPROCESSING FUNCTIONS --- +# ============================================================================ + +def calculate_iou(box1, box2): + x1_a, y1_a, x2_a, y2_a = box1 + x1_b, y1_b, x2_b, y2_b = box2 + x_left = max(x1_a, x1_b) + y_top = max(y1_a, y1_b) + x_right = min(x2_a, x2_b) + y_bottom = min(y2_a, y2_b) + intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top) + box_a_area = (x2_a - x1_a) * (y2_a - y1_a) + box_b_area = (x2_b - x1_b) * (y2_b - y1_b) + union_area = float(box_a_area + box_b_area - intersection_area) + return intersection_area / union_area if union_area > 0 else 0 + + +def calculate_ioa(box1, box2): + x1_a, y1_a, x2_a, y2_a = box1 + x1_b, y1_b, x2_b, y2_b = box2 + x_left = max(x1_a, x1_b) + y_top = max(y1_a, y1_b) + x_right = min(x2_a, x2_b) + y_bottom = min(y2_a, y2_b) + intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top) + box_a_area = (x2_a - x1_a) * (y2_a - y1_a) + return intersection_area / box_a_area if box_a_area > 0 else 0 + + +def filter_nested_boxes(detections, ioa_threshold=0.80): + """ + Removes boxes that are inside larger boxes (Containment Check). + Prioritizes keeping the LARGEST box (the 'parent' container). + """ + if not detections: + return [] + + # 1. Calculate Area for all detections + for d in detections: + x1, y1, x2, y2 = d['coords'] + d['area'] = (x2 - x1) * (y2 - y1) + + # 2. Sort by Area Descending (Largest to Smallest) + # This ensures we process the 'container' first + detections.sort(key=lambda x: x['area'], reverse=True) + + keep_indices = [] + is_suppressed = [False] * len(detections) + + for i in range(len(detections)): + if is_suppressed[i]: continue + + keep_indices.append(i) + box_a = detections[i]['coords'] + + # Compare with all smaller boxes + for j in range(i + 1, len(detections)): + if is_suppressed[j]: continue + + box_b = detections[j]['coords'] + + # Calculate Intersection + x_left = max(box_a[0], box_b[0]) + y_top = max(box_a[1], box_b[1]) + x_right = min(box_a[2], box_b[2]) + y_bottom = min(box_a[3], box_b[3]) + + if x_right < x_left or y_bottom < y_top: + intersection = 0 + else: + intersection = (x_right - x_left) * (y_bottom - y_top) + + # Calculate IoA (Intersection over Area of the SMALLER box) + area_b = detections[j]['area'] + + if area_b > 0: + ioa_small = intersection / area_b + + # If the small box is > 90% inside the big box, suppress the small one. + if ioa_small > ioa_threshold: + is_suppressed[j] = True + # print(f" [Suppress] Removed nested object inside larger '{detections[i]['class']}'") + + return [detections[i] for i in keep_indices] + + +def merge_overlapping_boxes(detections, iou_threshold): + if not detections: return [] + detections.sort(key=lambda d: d['conf'], reverse=True) + merged_detections = [] + is_merged = [False] * len(detections) + for i in range(len(detections)): + if is_merged[i]: continue + current_box = detections[i]['coords'] + current_class = detections[i]['class'] + merged_x1, merged_y1, merged_x2, merged_y2 = current_box + for j in range(i + 1, len(detections)): + if is_merged[j] or detections[j]['class'] != current_class: continue + other_box = detections[j]['coords'] + iou = calculate_iou(current_box, other_box) + if iou > iou_threshold: + merged_x1 = min(merged_x1, other_box[0]) + merged_y1 = min(merged_y1, other_box[1]) + merged_x2 = max(merged_x2, other_box[2]) + merged_y2 = max(merged_y2, other_box[3]) + is_merged[j] = True + merged_detections.append({ + 'coords': (merged_x1, merged_y1, merged_x2, merged_y2), + 'y1': merged_y1, 'class': current_class, 'conf': detections[i]['conf'] + }) + return merged_detections + + +def merge_yolo_into_word_data(raw_word_data: list, yolo_detections: list, scale_factor: float) -> list: + """ + Filters out raw words that are inside YOLO boxes and replaces them with + a single solid 'placeholder' block for the column detector. + """ + if not yolo_detections: + return raw_word_data + + # 1. Convert YOLO boxes (Pixels) to PDF Coordinates (Points) + pdf_space_boxes = [] + for det in yolo_detections: + x1, y1, x2, y2 = det['coords'] + pdf_box = ( + x1 / scale_factor, + y1 / scale_factor, + x2 / scale_factor, + y2 / scale_factor + ) + pdf_space_boxes.append(pdf_box) + + # 2. Filter out raw words that are inside YOLO boxes + cleaned_word_data = [] + for word_tuple in raw_word_data: + wx1, wy1, wx2, wy2 = word_tuple[1], word_tuple[2], word_tuple[3], word_tuple[4] + w_center_x = (wx1 + wx2) / 2 + w_center_y = (wy1 + wy2) / 2 + + is_inside_yolo = False + for px1, py1, px2, py2 in pdf_space_boxes: + if px1 <= w_center_x <= px2 and py1 <= w_center_y <= py2: + is_inside_yolo = True + break + + if not is_inside_yolo: + cleaned_word_data.append(word_tuple) + + # 3. Add the YOLO boxes themselves as "Solid Words" + for i, (px1, py1, px2, py2) in enumerate(pdf_space_boxes): + dummy_entry = (f"BLOCK_{i}", px1, py1, px2, py2) + cleaned_word_data.append(dummy_entry) + + return cleaned_word_data + + +# ============================================================================ +# --- MISSING HELPER FUNCTION --- +# ============================================================================ + +def preprocess_image_for_ocr(img_np): + """ + Converts image to grayscale and applies Otsu's Binarization + to separate text from background clearly. + """ + # 1. Convert to Grayscale if needed + if len(img_np.shape) == 3: + gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) + else: + gray = img_np + + # 2. Apply Otsu's Thresholding (Automatic binary threshold) + # This makes text solid black and background solid white + _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + + return thresh + + +def calculate_vertical_gap_coverage(word_data: list, sep_x: int, page_height: float, gutter_width: int = 10) -> float: + """ + Calculates what percentage of the page's vertical text span is 'cleanly split' by the separator. + A valid column split should split > 65% of the page verticality. + """ + if not word_data: + return 0.0 + + # Determine the vertical span of the actual text content + y_coords = [w[2] for w in word_data] + [w[4] for w in word_data] # y1 and y2 + min_y, max_y = min(y_coords), max(y_coords) + total_text_height = max_y - min_y + + if total_text_height <= 0: + return 0.0 + + # Create a boolean array representing the Y-axis (1 pixel per unit) + gap_open_mask = np.ones(int(total_text_height) + 1, dtype=bool) + + zone_left = sep_x - (gutter_width / 2) + zone_right = sep_x + (gutter_width / 2) + offset_y = int(min_y) + + for _, x1, y1, x2, y2 in word_data: + # Check if this word horizontally interferes with the separator + if x2 > zone_left and x1 < zone_right: + y_start_idx = max(0, int(y1) - offset_y) + y_end_idx = min(len(gap_open_mask), int(y2) - offset_y) + if y_end_idx > y_start_idx: + gap_open_mask[y_start_idx:y_end_idx] = False + + open_pixels = np.sum(gap_open_mask) + coverage_ratio = open_pixels / len(gap_open_mask) + + return coverage_ratio + + +def calculate_x_gutters(word_data: list, params: Dict, page_height: float) -> List[int]: + """ + Calculates X-axis histogram and validates using BRIDGING DENSITY and Vertical Coverage. + """ + if not word_data: return [] + + x_points = [] + # Use only word_data elements 1 (x1) and 3 (x2) + for item in word_data: + x_points.extend([item[1], item[3]]) + + if not x_points: return [] + max_x = max(x_points) + + # 1. Determine total text height for ratio calculation + y_coords = [item[2] for item in word_data] + [item[4] for item in word_data] + min_y, max_y = min(y_coords), max(y_coords) + total_text_height = max_y - min_y + if total_text_height <= 0: return [] + + # Histogram Setup + bin_size = params.get('cluster_bin_size', 5) + smoothing = params.get('cluster_smoothing', 1) + min_width = params.get('cluster_min_width', 20) + threshold_percentile = params.get('cluster_threshold_percentile', 85) + + num_bins = int(np.ceil(max_x / bin_size)) + hist, bin_edges = np.histogram(x_points, bins=num_bins, range=(0, max_x)) + smoothed_hist = gaussian_filter1d(hist.astype(float), sigma=smoothing) + inverted_signal = np.max(smoothed_hist) - smoothed_hist + + peaks, properties = find_peaks( + inverted_signal, + height=np.max(inverted_signal) - np.percentile(smoothed_hist, threshold_percentile), + distance=min_width / bin_size + ) + + if not peaks.size: return [] + separator_x_coords = [int(bin_edges[p]) for p in peaks] + final_separators = [] + + for x_coord in separator_x_coords: + # --- CHECK 1: BRIDGING DENSITY (The "Cut Through" Check) --- + # Calculate the total vertical height of words that physically cross this line. + bridging_height = 0 + bridging_count = 0 + + for item in word_data: + wx1, wy1, wx2, wy2 = item[1], item[2], item[3], item[4] + + # Check if this word physically sits on top of the separator line + if wx1 < x_coord and wx2 > x_coord: + word_h = wy2 - wy1 + bridging_height += word_h + bridging_count += 1 + + # Calculate Ratio: How much of the page's text height is blocked by these crossing words? + bridging_ratio = bridging_height / total_text_height + + # THRESHOLD: If bridging blocks > 8% of page height, REJECT. + # This allows for page numbers or headers (usually < 5%) to cross, but NOT paragraphs. + if bridging_ratio > 0.08: + print( + f" ❌ Separator X={x_coord} REJECTED: Bridging Ratio {bridging_ratio:.1%} (>15%) cuts through text.") + continue + + # --- CHECK 2: VERTICAL GAP COVERAGE (The "Clean Split" Check) --- + # The gap must exist cleanly for > 65% of the text height. + coverage = calculate_vertical_gap_coverage(word_data, x_coord, page_height, gutter_width=min_width) + + if coverage >= 0.80: + final_separators.append(x_coord) + print(f" -> Separator X={x_coord} ACCEPTED (Coverage: {coverage:.1%}, Bridging: {bridging_ratio:.1%})") + else: + print(f" ❌ Separator X={x_coord} REJECTED (Coverage: {coverage:.1%}, Bridging: {bridging_ratio:.1%})") + + return sorted(final_separators) + + +def get_word_data_for_detection(page: fitz.Page, pdf_path: str, page_num: int, + top_margin_percent=0.10, bottom_margin_percent=0.10) -> list: + """Extract word data with OCR caching to avoid redundant Tesseract runs.""" + word_data = page.get_text("words") + + if len(word_data) > 0: + word_data = [(w[4], w[0], w[1], w[2], w[3]) for w in word_data] + else: + if _ocr_cache.has_ocr(pdf_path, page_num): + word_data = _ocr_cache.get_ocr(pdf_path, page_num) + else: + try: + # --- OPTIMIZATION START --- + # 1. Render at Higher Resolution (Zoom 4.0 = ~300 DPI) + zoom_level = 4.0 + pix = page.get_pixmap(matrix=fitz.Matrix(zoom_level, zoom_level)) + + # 2. Convert directly to OpenCV format (Faster than PIL) + img_np = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n) + if pix.n == 3: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif pix.n == 4: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2BGR) + + # 3. Apply Preprocessing (Thresholding) + processed_img = preprocess_image_for_ocr(img_np) + + # 4. Optimized Tesseract Config + # --psm 6: Assume a single uniform block of text (Great for columns/questions) + # --oem 3: Default engine (LSTM) + custom_config = r'--oem 3 --psm 6' + + data = pytesseract.image_to_data(processed_img, output_type=pytesseract.Output.DICT, + config=custom_config) + + full_word_data = [] + for i in range(len(data['level'])): + text = data['text'][i].strip() + if text: + # Scale coordinates back to PDF points + x1 = data['left'][i] / zoom_level + y1 = data['top'][i] / zoom_level + x2 = (data['left'][i] + data['width'][i]) / zoom_level + y2 = (data['top'][i] + data['height'][i]) / zoom_level + full_word_data.append((text, x1, y1, x2, y2)) + + word_data = full_word_data + _ocr_cache.set_ocr(pdf_path, page_num, word_data) + # --- OPTIMIZATION END --- + except Exception as e: + print(f" ❌ OCR Error in detection phase: {e}") + return [] + + # Apply margin filtering + page_height = page.rect.height + y_min = page_height * top_margin_percent + y_max = page_height * (1 - bottom_margin_percent) + return [d for d in word_data if d[2] >= y_min and d[4] <= y_max] + + +def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray: + img_data = pix.samples + img = np.frombuffer(img_data, dtype=np.uint8).reshape(pix.height, pix.width, pix.n) + if pix.n == 4: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR) + elif pix.n == 3: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + +# def extract_native_words_and_convert(fitz_page, scale_factor: float = 2.0) -> list: +# raw_word_data = fitz_page.get_text("words") +# converted_ocr_output = [] +# DEFAULT_CONFIDENCE = 99.0 + +# for x1, y1, x2, y2, word, *rest in raw_word_data: +# # --- FIX: SANITIZE TEXT HERE --- +# # cleaned_word = sanitize_text(word) +# # if not cleaned_word.strip(): continue + +# x1_pix = int(x1 * scale_factor) +# y1_pix = int(y1 * scale_factor) +# x2_pix = int(x2 * scale_factor) +# y2_pix = int(y2 * scale_factor) +# converted_ocr_output.append({ +# 'type': 'text', +# 'word': cleaned_word, # Use the sanitized word +# 'confidence': DEFAULT_CONFIDENCE, +# 'bbox': [x1_pix, y1_pix, x2_pix, y2_pix], +# 'y0': y1_pix, 'x0': x1_pix +# }) +# return converted_ocr_output + + +# def extract_native_words_and_convert(fitz_page, scale_factor: float = 2.0) -> list: +# raw_word_data = fitz_page.get_text("words") + +# # ============================================================================== +# # --- DEBUGGING BLOCK: CHECK FIRST 50 NATIVE WORDS --- +# # ============================================================================== +# print(f"\n[DEBUG] Native Extraction (Page {fitz_page.number + 1}): Checking first 50 words...") +# debug_count = 0 +# for item in raw_word_data: +# if debug_count >= 50: break +# # item format: (x0, y0, x1, y1, word, block_no, line_no, word_no) +# word_text = item[4] + +# # Generate unicode hex codes for every character in the word +# unicode_points = [f"\\u{ord(c):04x}" for c in word_text] +# print(f" Word {debug_count}: '{word_text}' -> Codes: {unicode_points}") +# debug_count += 1 +# print("----------------------------------------------------------------------\n") +# # ============================================================================== + +# converted_ocr_output = [] +# DEFAULT_CONFIDENCE = 99.0 + +# for x1, y1, x2, y2, word, *rest in raw_word_data: +# # --- FIX: SANITIZE TEXT HERE --- +# cleaned_word = sanitize_text(word) +# if not cleaned_word.strip(): continue + +# x1_pix = int(x1 * scale_factor) +# y1_pix = int(y1 * scale_factor) +# x2_pix = int(x2 * scale_factor) +# y2_pix = int(y2 * scale_factor) +# converted_ocr_output.append({ +# 'type': 'text', +# 'word': cleaned_word, # Use the sanitized word +# 'confidence': DEFAULT_CONFIDENCE, +# 'bbox': [x1_pix, y1_pix, x2_pix, y2_pix], +# 'y0': y1_pix, 'x0': x1_pix +# }) +# return converted_ocr_output + + +def extract_native_words_and_convert(fitz_page, scale_factor: float = 2.0) -> list: + # 1. Get raw data + try: + raw_word_data = fitz_page.get_text("words") + except Exception as e: + print(f" ❌ PyMuPDF extraction failed completely: {e}") + return [] + + # ============================================================================== + # --- DEBUGGING BLOCK: CHECK FIRST 50 NATIVE WORDS (SAFE PRINT) --- + # ============================================================================== + print(f"\n[DEBUG] Native Extraction (Page {fitz_page.number + 1}): Checking first 50 words...") + + debug_count = 0 + for item in raw_word_data: + if debug_count >= 50: break + + word_text = item[4] + + # --- SAFE PRINTING LOGIC --- + # We encode/decode to ignore surrogates just for the print statement + # This prevents the "UnicodeEncodeError" that was crashing your script + safe_text = word_text.encode('utf-8', 'ignore').decode('utf-8') + + # Get hex codes (handling potential errors in 'ord') + try: + unicode_points = [f"\\u{ord(c):04x}" for c in word_text] + except: + unicode_points = ["ERROR"] + + print(f" Word {debug_count}: '{safe_text}' -> Codes: {unicode_points}") + debug_count += 1 + print("----------------------------------------------------------------------\n") + # ============================================================================== + + converted_ocr_output = [] + DEFAULT_CONFIDENCE = 99.0 + + for x1, y1, x2, y2, word, *rest in raw_word_data: + # --- FIX: ROBUST SANITIZATION --- + # 1. Encode to UTF-8 ignoring errors (strips surrogates) + # 2. Decode back to string + cleaned_word_bytes = word.encode('utf-8', 'ignore') + cleaned_word = cleaned_word_bytes.decode('utf-8') + cleaned_word = word.encode('utf-8', 'ignore').decode('utf-8').strip() + + # cleaned_word = cleaned_word.strip() + if not cleaned_word: continue + + x1_pix = int(x1 * scale_factor) + y1_pix = int(y1 * scale_factor) + x2_pix = int(x2 * scale_factor) + y2_pix = int(y2 * scale_factor) + + converted_ocr_output.append({ + 'type': 'text', + 'word': cleaned_word, + 'confidence': DEFAULT_CONFIDENCE, + 'bbox': [x1_pix, y1_pix, x2_pix, y2_pix], + 'y0': y1_pix, 'x0': x1_pix + }) + + return converted_ocr_output + + +# =================================================================================================== +# =================================================================================================== +# =================================================================================================== + + +import pandas as pd +import pickle +import os +import time +import json +from sklearn.feature_extraction.text import TfidfVectorizer +import numpy as np +from collections import defaultdict + +# --- Model File Paths (Required for the Classifier to load) --- +VECTORIZER_FILE = 'tfidf_vectorizer_conditional.pkl' +SUBJECT_MODEL_FILE = 'subject_classifier_model_conditional.pkl' +CONDITIONAL_CONCEPT_MODELS_FILE = 'conditional_concept_models.pkl' + + +# --- Hierarchical Classifier Class (Dependency for the helper function) --- + +class HierarchicalClassifier: + """ + A two-stage classification system based on conditional training. + Loads the vectorizer, subject classifier, and conditional concept models. + """ + + def __init__(self): + self.vectorizer = None + self.subject_model = None + self.conditional_concept_models = {} + self.is_ready = False + + def load_models(self): + """Loads the vectorizer, subject model, and conditional concept models.""" + try: + start_time = time.time() + # 1. Load the TF-IDF Vectorizer + with open(VECTORIZER_FILE, 'rb') as f: + self.vectorizer = pickle.load(f) + + # 2. Load the Level 1 (Subject) Classifier + with open(SUBJECT_MODEL_FILE, 'rb') as f: + self.subject_model = pickle.load(f) + + # 3. Load the dictionary of conditional Level 2 (Concept) Models + with open(CONDITIONAL_CONCEPT_MODELS_FILE, 'rb') as f: + conditional_data = pickle.load(f) + + # Extract just the models for easy access + for subject, data in conditional_data.items(): + self.conditional_concept_models[subject] = data['model'] + + print(f"Loaded models successfully in {time.time() - start_time:.2f} seconds.") + self.is_ready = True + + except FileNotFoundError as e: + print(f"Error: Required model file not found: {e.filename}.") + self.is_ready = False + except Exception as e: + print(f"An error occurred while loading models: {e}") + self.is_ready = False + + return self.is_ready + + def predict_subject(self, text_chunk): + """Predicts the top Subject (Level 1).""" + if not self.is_ready: + return "Unknown", 0.0 + + # Vectorize the input + text_vector = self.vectorizer.transform([text_chunk]).astype(np.float64) + + if hasattr(self.subject_model, 'predict_proba'): + probabilities = self.subject_model.predict_proba(text_vector)[0] + classes = self.subject_model.classes_ + + top_index = np.argmax(probabilities) + return classes[top_index], probabilities[top_index] + else: + return self.subject_model.predict(text_vector)[0], 1.0 + + def predict_concept_hierarchical(self, text_chunk, predicted_subject): + """ + Predicts the top Concept (Level 2) using the specialized conditional model. + """ + if not self.is_ready: + return "Unknown", 0.0 + + concept_model = self.conditional_concept_models.get(predicted_subject) + + if concept_model is None or len(getattr(concept_model, 'classes_', [])) <= 1: + return "No_Conditional_Model_Found", 0.0 + + # Vectorize the input + text_vector = self.vectorizer.transform([text_chunk]).astype(np.float64) + + if hasattr(concept_model, 'predict_proba'): + probabilities = concept_model.predict_proba(text_vector)[0] + classes = concept_model.classes_ + + top_index = np.argmax(probabilities) + return classes[top_index], probabilities[top_index] + else: + return concept_model.predict(text_vector)[0], 1.0 + + +# -------------------------------------------------------------------------------------- +# --- The Requested Helper Function --- + +def post_process_json_with_inference(json_data, classifier): + """ + Takes JSON data, runs hierarchical inference on all question/option text, + and adds 'predicted_subject' and 'predicted_concept' tags to each entry. + + Args: + json_data (list): The list of dictionaries containing question entries. + classifier (HierarchicalClassifier): An initialized and loaded classifier object. + + Returns: + list: The modified list of dictionaries with classification tags added. + """ + if not classifier.is_ready: + print("Classifier not ready. Skipping inference.") + return json_data + + # This print statement can be removed for silent pipeline integration + print("\n--- Starting Subject/Concept Detection ---") + + for entry in json_data: + # Only process entries that have a 'question' field + if 'question' not in entry: + continue + + # 1. Combine Question Text and Option Text for robust feature extraction + full_text = entry.get('question', '') + + options = entry.get('options', {}) + for option_key, option_value in options.items(): + # Use the text component of the option if available + option_text = option_value if isinstance(option_value, str) else option_key + full_text += " " + option_text.replace('\n', ' ') + + # Clean up text (remove multiple spaces and surrounding whitespace) + full_text = ' '.join(full_text.split()).strip() + + # Handle empty text + if not full_text: + entry['predicted_subject'] = {'label': 'Empty_Text', 'confidence': 0.0} + entry['predicted_concept'] = {'label': 'Empty_Text', 'confidence': 0.0} + continue + + # 2. STAGE 1: Predict Subject + subj_label, subj_conf = classifier.predict_subject(full_text) + + # 3. STAGE 2: Predict Concept (Conditional on predicted subject) + conc_label, conc_conf = classifier.predict_concept_hierarchical(full_text, subj_label) + + # 4. Add results to the JSON entry + entry['predicted_subject'] = { + 'label': subj_label, + 'confidence': round(subj_conf, 4) + } + entry['predicted_concept'] = { + 'label': conc_label, + 'confidence': round(conc_conf, 4) + } + + # This print statement can be removed for silent pipeline integration + # print("--- JSON Post-Processing Complete ---") + + return json_data + + +# =================================================================================================== +# =================================================================================================== +# =================================================================================================== + + +def preprocess_and_ocr_page(original_img: np.ndarray, model, pdf_path: str, + page_num: int, fitz_page: fitz.Page, + pdf_name: str) -> Tuple[List[Dict[str, Any]], Optional[int]]: + """ + OPTIMIZED FLOW: + 1. Run YOLO to find Equations/Tables. + 2. Mask raw text with YOLO boxes. + 3. Run Column Detection on the MASKED data. + 4. Proceed with OCR (Native or High-Res Tesseract Fallback) and Output. + """ + global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT + + start_time_total = time.time() + + if original_img is None: + print(f" ❌ Invalid image for page {page_num}.") + return None, None + + # ==================================================================== + # --- STEP 1: YOLO DETECTION --- + # ==================================================================== + start_time_yolo = time.time() + results = model.predict(source=original_img, conf=CONF_THRESHOLD, imgsz=640, verbose=False) + + relevant_detections = [] + if results and results[0].boxes: + for box in results[0].boxes: + class_id = int(box.cls[0]) + class_name = model.names[class_id] + if class_name in TARGET_CLASSES: + x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int) + relevant_detections.append( + {'coords': (x1, y1, x2, y2), 'y1': y1, 'class': class_name, 'conf': float(box.conf[0])} + ) + + merged_detections = merge_overlapping_boxes(relevant_detections, IOU_MERGE_THRESHOLD) + print(f" [LOG] YOLO found {len(merged_detections)} objects in {time.time() - start_time_yolo:.3f}s.") + + # ==================================================================== + # --- STEP 2: PREPARE DATA FOR COLUMN DETECTION (MASKING) --- + # ==================================================================== + # Note: This uses the updated 'get_word_data_for_detection' which has its own optimizations + raw_words_for_layout = get_word_data_for_detection( + fitz_page, pdf_path, page_num, + top_margin_percent=0.10, bottom_margin_percent=0.10 + ) + + masked_word_data = merge_yolo_into_word_data(raw_words_for_layout, merged_detections, scale_factor=2.0) + + # ==================================================================== + # --- STEP 3: COLUMN DETECTION --- + # ==================================================================== + page_width_pdf = fitz_page.rect.width + page_height_pdf = fitz_page.rect.height + + column_detection_params = { + 'cluster_bin_size': 2, 'cluster_smoothing': 2, + 'cluster_min_width': 10, 'cluster_threshold_percentile': 85, + } + + separators = calculate_x_gutters(masked_word_data, column_detection_params, page_height_pdf) + + page_separator_x = None + if separators: + central_min = page_width_pdf * 0.35 + central_max = page_width_pdf * 0.65 + central_separators = [s for s in separators if central_min <= s <= central_max] + + if central_separators: + center_x = page_width_pdf / 2 + page_separator_x = min(central_separators, key=lambda x: abs(x - center_x)) + print(f" ✅ Column Split Confirmed at X={page_separator_x:.1f}") + else: + print(" ⚠️ Gutter found off-center. Ignoring.") + else: + print(" -> Single Column Layout Confirmed.") + + # ==================================================================== + # --- STEP 4: COMPONENT EXTRACTION (Save Images) --- + # ==================================================================== + start_time_components = time.time() + component_metadata = [] + fig_count_page = 0 + eq_count_page = 0 + + for detection in merged_detections: + x1, y1, x2, y2 = detection['coords'] + class_name = detection['class'] + + if class_name == 'figure': + GLOBAL_FIGURE_COUNT += 1 + counter = GLOBAL_FIGURE_COUNT + component_word = f"FIGURE{counter}" + fig_count_page += 1 + elif class_name == 'equation': + GLOBAL_EQUATION_COUNT += 1 + counter = GLOBAL_EQUATION_COUNT + component_word = f"EQUATION{counter}" + eq_count_page += 1 + else: + continue + + component_crop = original_img[y1:y2, x1:x2] + component_filename = f"{pdf_name}_page{page_num}_{class_name}{counter}.png" + cv2.imwrite(os.path.join(FIGURE_EXTRACTION_DIR, component_filename), component_crop) + + y_midpoint = (y1 + y2) // 2 + component_metadata.append({ + 'type': class_name, 'word': component_word, + 'bbox': [int(x1), int(y1), int(x2), int(y2)], + 'y0': int(y_midpoint), 'x0': int(x1) + }) + + # ==================================================================== + # --- STEP 5: HYBRID OCR (Native Text + Cached Tesseract Fallback) --- + # ==================================================================== + raw_ocr_output = [] + scale_factor = 2.0 # Pipeline standard scale + + try: + # Try getting native text first + # NOTE: extract_native_words_and_convert MUST ALSO BE UPDATED TO USE sanitize_text + raw_ocr_output = extract_native_words_and_convert(fitz_page, scale_factor=scale_factor) + except Exception as e: + print(f" ❌ Native text extraction failed: {e}") + + # If native text is missing, fall back to OCR + if not raw_ocr_output: + if _ocr_cache.has_ocr(pdf_path, page_num): + print(f" ⚡ Using cached Tesseract OCR for page {page_num}") + cached_word_data = _ocr_cache.get_ocr(pdf_path, page_num) + for word_tuple in cached_word_data: + word_text, x1, y1, x2, y2 = word_tuple + + # Scale from PDF points to Pipeline Pixels (2.0) + x1_pix = int(x1 * scale_factor) + y1_pix = int(y1 * scale_factor) + x2_pix = int(x2 * scale_factor) + y2_pix = int(y2 * scale_factor) + + raw_ocr_output.append({ + 'type': 'text', 'word': word_text, 'confidence': 95.0, + 'bbox': [x1_pix, y1_pix, x2_pix, y2_pix], + 'y0': y1_pix, 'x0': x1_pix + }) + else: + # === START OF OPTIMIZED OCR BLOCK === + try: + # 1. Re-render Page at High Resolution (Zoom 4.0 = ~300 DPI) + ocr_zoom = 4.0 + pix_ocr = fitz_page.get_pixmap(matrix=fitz.Matrix(ocr_zoom, ocr_zoom)) + + # Convert PyMuPDF Pixmap to OpenCV format + img_ocr_np = np.frombuffer(pix_ocr.samples, dtype=np.uint8).reshape(pix_ocr.height, pix_ocr.width, + pix_ocr.n) + if pix_ocr.n == 3: + img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGB2BGR) + elif pix_ocr.n == 4: + img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGBA2BGR) + + # 2. Preprocess (Binarization) + processed_img = preprocess_image_for_ocr(img_ocr_np) + + # 3. Run Tesseract with Optimized Configuration + custom_config = r'--oem 3 --psm 6' + + hocr_data = pytesseract.image_to_data( + processed_img, + output_type=pytesseract.Output.DICT, + config=custom_config + ) + + for i in range(len(hocr_data['level'])): + text = hocr_data['text'][i] # Retrieve raw Tesseract text + + # --- FIX: SANITIZE TEXT AND THEN STRIP --- + cleaned_text = sanitize_text(text).strip() + + if cleaned_text and hocr_data['conf'][i] > -1: + # 4. Coordinate Mapping + scale_adjustment = scale_factor / ocr_zoom + + x1 = int(hocr_data['left'][i] * scale_adjustment) + y1 = int(hocr_data['top'][i] * scale_adjustment) + w = int(hocr_data['width'][i] * scale_adjustment) + h = int(hocr_data['height'][i] * scale_adjustment) + x2 = x1 + w + y2 = y1 + h + + raw_ocr_output.append({ + 'type': 'text', + 'word': cleaned_text, # Use the sanitized word + 'confidence': float(hocr_data['conf'][i]), + 'bbox': [x1, y1, x2, y2], + 'y0': y1, + 'x0': x1 + }) + except Exception as e: + print(f" ❌ Tesseract OCR Error: {e}") + # === END OF OPTIMIZED OCR BLOCK === + + # ==================================================================== + # --- STEP 6: OCR CLEANING AND MERGING --- + # ==================================================================== + items_to_sort = [] + + for ocr_word in raw_ocr_output: + is_suppressed = False + for component in component_metadata: + # Do not include words that are inside figure/equation boxes + ioa = calculate_ioa(ocr_word['bbox'], component['bbox']) + if ioa > IOA_SUPPRESSION_THRESHOLD: + is_suppressed = True + break + if not is_suppressed: + items_to_sort.append(ocr_word) + + # Add figures/equations back into the flow as "words" + items_to_sort.extend(component_metadata) + + # ==================================================================== + # --- STEP 7: LINE-BASED SORTING --- + # ==================================================================== + items_to_sort.sort(key=lambda x: (x['y0'], x['x0'])) + lines = [] + + for item in items_to_sort: + placed = False + for line in lines: + y_ref = min(it['y0'] for it in line) + if abs(y_ref - item['y0']) < LINE_TOLERANCE: + line.append(item) + placed = True + break + if not placed and item['type'] in ['equation', 'figure']: + for line in lines: + y_ref = min(it['y0'] for it in line) + if abs(y_ref - item['y0']) < 20: + line.append(item) + placed = True + break + if not placed: + lines.append([item]) + + for line in lines: + line.sort(key=lambda x: x['x0']) + + final_output = [] + for line in lines: + for item in line: + data_item = {"word": item["word"], "bbox": item["bbox"], "type": item["type"]} + if 'tag' in item: data_item['tag'] = item['tag'] + final_output.append(data_item) + + return final_output, page_separator_x + + +# def preprocess_and_ocr_page(original_img: np.ndarray, model, pdf_path: str, +# page_num: int, fitz_page: fitz.Page, +# pdf_name: str) -> Tuple[List[Dict[str, Any]], Optional[int]]: +# """ +# OPTIMIZED FLOW: +# 1. Run YOLO to find Equations/Tables. +# 2. Mask raw text with YOLO boxes. +# 3. Run Column Detection on the MASKED data. +# 4. Proceed with OCR (Native or High-Res Tesseract Fallback) and Output. +# """ +# global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT + +# start_time_total = time.time() + +# if original_img is None: +# print(f" ❌ Invalid image for page {page_num}.") +# return None, None + +# # ==================================================================== +# # --- STEP 1: YOLO DETECTION --- +# # ==================================================================== +# start_time_yolo = time.time() +# results = model.predict(source=original_img, conf=CONF_THRESHOLD, imgsz=640, verbose=False) + +# relevant_detections = [] +# if results and results[0].boxes: +# for box in results[0].boxes: +# class_id = int(box.cls[0]) +# class_name = model.names[class_id] +# if class_name in TARGET_CLASSES: +# x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int) +# relevant_detections.append( +# {'coords': (x1, y1, x2, y2), 'y1': y1, 'class': class_name, 'conf': float(box.conf[0])} +# ) + +# merged_detections = merge_overlapping_boxes(relevant_detections, IOU_MERGE_THRESHOLD) +# print(f" [LOG] YOLO found {len(merged_detections)} objects in {time.time() - start_time_yolo:.3f}s.") + +# # ==================================================================== +# # --- STEP 2: PREPARE DATA FOR COLUMN DETECTION (MASKING) --- +# # ==================================================================== +# # Note: This uses the updated 'get_word_data_for_detection' which has its own optimizations +# raw_words_for_layout = get_word_data_for_detection( +# fitz_page, pdf_path, page_num, +# top_margin_percent=0.10, bottom_margin_percent=0.10 +# ) + +# masked_word_data = merge_yolo_into_word_data(raw_words_for_layout, merged_detections, scale_factor=2.0) + +# # ==================================================================== +# # --- STEP 3: COLUMN DETECTION --- +# # ==================================================================== +# page_width_pdf = fitz_page.rect.width +# page_height_pdf = fitz_page.rect.height + +# column_detection_params = { +# 'cluster_bin_size': 2, 'cluster_smoothing': 2, +# 'cluster_min_width': 10, 'cluster_threshold_percentile': 85, +# } + +# separators = calculate_x_gutters(masked_word_data, column_detection_params, page_height_pdf) + +# page_separator_x = None +# if separators: +# central_min = page_width_pdf * 0.35 +# central_max = page_width_pdf * 0.65 +# central_separators = [s for s in separators if central_min <= s <= central_max] + +# if central_separators: +# center_x = page_width_pdf / 2 +# page_separator_x = min(central_separators, key=lambda x: abs(x - center_x)) +# print(f" ✅ Column Split Confirmed at X={page_separator_x:.1f}") +# else: +# print(" ⚠️ Gutter found off-center. Ignoring.") +# else: +# print(" -> Single Column Layout Confirmed.") + +# # ==================================================================== +# # --- STEP 4: COMPONENT EXTRACTION (Save Images) --- +# # ==================================================================== +# start_time_components = time.time() +# component_metadata = [] +# fig_count_page = 0 +# eq_count_page = 0 + +# for detection in merged_detections: +# x1, y1, x2, y2 = detection['coords'] +# class_name = detection['class'] + +# if class_name == 'figure': +# GLOBAL_FIGURE_COUNT += 1 +# counter = GLOBAL_FIGURE_COUNT +# component_word = f"FIGURE{counter}" +# fig_count_page += 1 +# elif class_name == 'equation': +# GLOBAL_EQUATION_COUNT += 1 +# counter = GLOBAL_EQUATION_COUNT +# component_word = f"EQUATION{counter}" +# eq_count_page += 1 +# else: +# continue + +# component_crop = original_img[y1:y2, x1:x2] +# component_filename = f"{pdf_name}_page{page_num}_{class_name}{counter}.png" +# cv2.imwrite(os.path.join(FIGURE_EXTRACTION_DIR, component_filename), component_crop) + +# y_midpoint = (y1 + y2) // 2 +# component_metadata.append({ +# 'type': class_name, 'word': component_word, +# 'bbox': [int(x1), int(y1), int(x2), int(y2)], +# 'y0': int(y_midpoint), 'x0': int(x1) +# }) + +# # ==================================================================== +# # --- STEP 5: HYBRID OCR (Native Text + Cached Tesseract Fallback) --- +# # ==================================================================== +# raw_ocr_output = [] +# scale_factor = 2.0 # Pipeline standard scale + +# try: +# # Try getting native text first +# # NOTE: extract_native_words_and_convert MUST ALSO BE UPDATED TO USE sanitize_text +# raw_ocr_output = extract_native_words_and_convert(fitz_page, scale_factor=scale_factor) +# except Exception as e: +# print(f" ❌ Native text extraction failed: {e}") + +# # If native text is missing, fall back to OCR +# if not raw_ocr_output: +# if _ocr_cache.has_ocr(pdf_path, page_num): +# print(f" ⚡ Using cached Tesseract OCR for page {page_num}") +# cached_word_data = _ocr_cache.get_ocr(pdf_path, page_num) +# for word_tuple in cached_word_data: +# word_text, x1, y1, x2, y2 = word_tuple + +# # Scale from PDF points to Pipeline Pixels (2.0) +# x1_pix = int(x1 * scale_factor) +# y1_pix = int(y1 * scale_factor) +# x2_pix = int(x2 * scale_factor) +# y2_pix = int(y2 * scale_factor) + +# raw_ocr_output.append({ +# 'type': 'text', 'word': word_text, 'confidence': 95.0, +# 'bbox': [x1_pix, y1_pix, x2_pix, y2_pix], +# 'y0': y1_pix, 'x0': x1_pix +# }) +# else: +# # === START OF OPTIMIZED OCR BLOCK === +# try: +# # 1. Re-render Page at High Resolution (Zoom 4.0 = ~300 DPI) +# ocr_zoom = 4.0 +# pix_ocr = fitz_page.get_pixmap(matrix=fitz.Matrix(ocr_zoom, ocr_zoom)) + +# # Convert PyMuPDF Pixmap to OpenCV format +# img_ocr_np = np.frombuffer(pix_ocr.samples, dtype=np.uint8).reshape(pix_ocr.height, pix_ocr.width, +# pix_ocr.n) +# if pix_ocr.n == 3: +# img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGB2BGR) +# elif pix_ocr.n == 4: +# img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGBA2BGR) + +# # 2. Preprocess (Binarization) +# processed_img = preprocess_image_for_ocr(img_ocr_np) + +# # 3. Run Tesseract with Optimized Configuration +# custom_config = r'--oem 3 --psm 6' + +# hocr_data = pytesseract.image_to_data( +# processed_img, +# output_type=pytesseract.Output.DICT, +# config=custom_config +# ) + +# # ============================================================================== +# # --- DEBUGGING BLOCK: CHECK FIRST 50 OCR WORDS --- +# # ============================================================================== +# print(f"\n[DEBUG] Tesseract OCR Fallback (Page {page_num}): Checking first 50 words...") +# debug_count = 0 +# for i in range(len(hocr_data['level'])): +# text = hocr_data['text'][i].strip() +# if text: +# unicode_points = [f"\\u{ord(c):04x}" for c in text] +# print(f" OCR Word {debug_count}: '{text}' -> Codes: {unicode_points}") +# debug_count += 1 +# if debug_count >= 50: break +# print("----------------------------------------------------------------------\n") +# # ============================================================================== + +# for i in range(len(hocr_data['level'])): +# text = hocr_data['text'][i] # Retrieve raw Tesseract text + +# # --- FIX: SANITIZE TEXT AND THEN STRIP --- +# cleaned_text = sanitize_text(text).strip() + +# if cleaned_text and hocr_data['conf'][i] > -1: +# # 4. Coordinate Mapping +# scale_adjustment = scale_factor / ocr_zoom + +# x1 = int(hocr_data['left'][i] * scale_adjustment) +# y1 = int(hocr_data['top'][i] * scale_adjustment) +# w = int(hocr_data['width'][i] * scale_adjustment) +# h = int(hocr_data['height'][i] * scale_adjustment) +# x2 = x1 + w +# y2 = y1 + h + +# raw_ocr_output.append({ +# 'type': 'text', +# 'word': cleaned_text, # Use the sanitized word +# 'confidence': float(hocr_data['conf'][i]), +# 'bbox': [x1, y1, x2, y2], +# 'y0': y1, +# 'x0': x1 +# }) +# except Exception as e: +# print(f" ❌ Tesseract OCR Error: {e}") +# # === END OF OPTIMIZED OCR BLOCK === + +# # ==================================================================== +# # --- STEP 6: OCR CLEANING AND MERGING --- +# # ==================================================================== +# items_to_sort = [] + +# for ocr_word in raw_ocr_output: +# is_suppressed = False +# for component in component_metadata: +# # Do not include words that are inside figure/equation boxes +# ioa = calculate_ioa(ocr_word['bbox'], component['bbox']) +# if ioa > IOA_SUPPRESSION_THRESHOLD: +# is_suppressed = True +# break +# if not is_suppressed: +# items_to_sort.append(ocr_word) + +# # Add figures/equations back into the flow as "words" +# items_to_sort.extend(component_metadata) + +# # ==================================================================== +# # --- STEP 7: LINE-BASED SORTING --- +# # ==================================================================== +# items_to_sort.sort(key=lambda x: (x['y0'], x['x0'])) +# lines = [] + +# for item in items_to_sort: +# placed = False +# for line in lines: +# y_ref = min(it['y0'] for it in line) +# if abs(y_ref - item['y0']) < LINE_TOLERANCE: +# line.append(item) +# placed = True +# break +# if not placed and item['type'] in ['equation', 'figure']: +# for line in lines: +# y_ref = min(it['y0'] for it in line) +# if abs(y_ref - item['y0']) < 20: +# line.append(item) +# placed = True +# break +# if not placed: +# lines.append([item]) + +# for line in lines: +# line.sort(key=lambda x: x['x0']) + +# final_output = [] +# for line in lines: +# for item in line: +# data_item = {"word": item["word"], "bbox": item["bbox"], "type": item["type"]} +# if 'tag' in item: data_item['tag'] = item['tag'] +# final_output.append(data_item) + +# return final_output, page_separator_x + + +def run_single_pdf_preprocessing(pdf_path: str, preprocessed_json_path: str) -> Optional[str]: + global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT + + GLOBAL_FIGURE_COUNT = 0 + GLOBAL_EQUATION_COUNT = 0 + _ocr_cache.clear() + + print("\n" + "=" * 80) + print("--- 1. STARTING OPTIMIZED YOLO/OCR PREPROCESSING PIPELINE ---") + print("=" * 80) + + if not os.path.exists(pdf_path): + print(f"❌ FATAL ERROR: Input PDF not found at {pdf_path}.") + return None + + os.makedirs(os.path.dirname(preprocessed_json_path), exist_ok=True) + os.makedirs(FIGURE_EXTRACTION_DIR, exist_ok=True) + + model = YOLO(WEIGHTS_PATH) + pdf_name = os.path.splitext(os.path.basename(pdf_path))[0] + + try: + doc = fitz.open(pdf_path) + print(f"✅ Opened PDF: {pdf_name} ({doc.page_count} pages)") + except Exception as e: + print(f"❌ ERROR loading PDF file: {e}") + return None + + all_pages_data = [] + total_pages_processed = 0 + mat = fitz.Matrix(2.0, 2.0) + + print("\n[STEP 1.2: ITERATING PAGES - IN-MEMORY PROCESSING]") + + for page_num_0_based in range(doc.page_count): + page_num = page_num_0_based + 1 + print(f" -> Processing Page {page_num}/{doc.page_count}...") + + fitz_page = doc.load_page(page_num_0_based) + + try: + pix = fitz_page.get_pixmap(matrix=mat) + original_img = pixmap_to_numpy(pix) + except Exception as e: + print(f" ❌ Error converting page {page_num} to image: {e}") + continue + + final_output, page_separator_x = preprocess_and_ocr_page( + original_img, + model, + pdf_path, + page_num, + fitz_page, + pdf_name + ) + + if final_output is not None: + page_data = { + "page_number": page_num, + "data": final_output, + "column_separator_x": page_separator_x + } + all_pages_data.append(page_data) + total_pages_processed += 1 + else: + print(f" ❌ Skipped page {page_num} due to processing error.") + + doc.close() + + if all_pages_data: + try: + with open(preprocessed_json_path, 'w') as f: + json.dump(all_pages_data, f, indent=4) + print(f"\n ✅ Combined structured OCR JSON saved to: {os.path.basename(preprocessed_json_path)}") + except Exception as e: + print(f"❌ ERROR saving combined JSON output: {e}") + return None + else: + print("❌ WARNING: No page data generated. Halting pipeline.") + return None + + print("\n" + "=" * 80) + print(f"--- YOLO/OCR PREPROCESSING COMPLETE ({total_pages_processed} pages processed) ---") + print("=" * 80) + + return preprocessed_json_path + + +# ============================================================================ +# --- PHASE 2: LAYOUTLMV3 INFERENCE FUNCTIONS --- +# ============================================================================ + +class LayoutLMv3ForTokenClassification(nn.Module): + def __init__(self, num_labels: int = NUM_LABELS): + super().__init__() + self.num_labels = num_labels + config = LayoutLMv3Config.from_pretrained("microsoft/layoutlmv3-base", num_labels=num_labels) + self.layoutlmv3 = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base", config=config) + self.classifier = nn.Linear(config.hidden_size, num_labels) + self.crf = CRF(num_labels) + self.init_weights() + + def init_weights(self): + nn.init.xavier_uniform_(self.classifier.weight) + if self.classifier.bias is not None: nn.init.zeros_(self.classifier.bias) + + def forward(self, input_ids: torch.Tensor, bbox: torch.Tensor, attention_mask: torch.Tensor, + labels: Optional[torch.Tensor] = None): + outputs = self.layoutlmv3(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, return_dict=True) + sequence_output = outputs.last_hidden_state + emissions = self.classifier(sequence_output) + mask = attention_mask.bool() + if labels is not None: + loss = -self.crf(emissions, labels, mask=mask).mean() + return loss + else: + return self.crf.viterbi_decode(emissions, mask=mask) + + +def _merge_integrity(all_token_data: List[Dict[str, Any]], + column_separator_x: Optional[int]) -> List[List[Dict[str, Any]]]: + """Splits the token data objects into column chunks based on a separator.""" + if column_separator_x is None: + print(" -> No column separator. Treating as one chunk.") + return [all_token_data] + + left_column_tokens, right_column_tokens = [], [] + for token_data in all_token_data: + bbox_raw = token_data['bbox_raw_pdf_space'] + center_x = (bbox_raw[0] + bbox_raw[2]) / 2 + if center_x < column_separator_x: + left_column_tokens.append(token_data) + else: + right_column_tokens.append(token_data) + + chunks = [c for c in [left_column_tokens, right_column_tokens] if c] + print(f" -> Data split into {len(chunks)} column chunk(s) using separator X={column_separator_x}.") + return chunks + + +def run_inference_and_get_raw_words(pdf_path: str, model_path: str, + preprocessed_json_path: str, + column_detection_params: Optional[Dict] = None) -> List[Dict[str, Any]]: + print("\n" + "=" * 80) + print("--- 2. STARTING LAYOUTLMV3 INFERENCE PIPELINE (Raw Word Output) ---") + print("=" * 80) + + tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f" -> Using device: {device}") + + try: + model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS) + checkpoint = torch.load(model_path, map_location=device) + model_state = checkpoint.get('model_state_dict', checkpoint) + # Apply patch for layoutlmv3 compatibility with saved state_dict + fixed_state_dict = {key.replace('layoutlm.', 'layoutlmv3.'): value for key, value in model_state.items()} + model.load_state_dict(fixed_state_dict) + model.to(device) + model.eval() + print(f"✅ LayoutLMv3 Model loaded successfully from {os.path.basename(model_path)}.") + except Exception as e: + print(f"❌ FATAL ERROR during LayoutLMv3 model loading: {e}") + return [] + + try: + with open(preprocessed_json_path, 'r', encoding='utf-8') as f: + preprocessed_data = json.load(f) + print(f"✅ Loaded preprocessed data with {len(preprocessed_data)} pages.") + except Exception: + print("❌ Error loading preprocessed JSON.") + return [] + + try: + doc = fitz.open(pdf_path) + except Exception: + print("❌ Error loading PDF.") + return [] + + final_page_predictions = [] + CHUNK_SIZE = 500 + + for page_data in preprocessed_data: + page_num_1_based = page_data['page_number'] + page_num_0_based = page_num_1_based - 1 + page_raw_predictions = [] + print(f"\n *** Processing Page {page_num_1_based} ({len(page_data['data'])} raw tokens) ***") + + fitz_page = doc.load_page(page_num_0_based) + page_width, page_height = fitz_page.rect.width, fitz_page.rect.height + print(f" -> Page dimensions: {page_width:.0f}x{page_height:.0f} (PDF points).") + + all_token_data = [] + scale_factor = 2.0 + + for item in page_data['data']: + raw_yolo_bbox = item['bbox'] + bbox_pdf = [ + int(raw_yolo_bbox[0] / scale_factor), int(raw_yolo_bbox[1] / scale_factor), + int(raw_yolo_bbox[2] / scale_factor), int(raw_yolo_bbox[3] / scale_factor) + ] + normalized_bbox = [ + max(0, min(1000, int(1000 * bbox_pdf[0] / page_width))), + max(0, min(1000, int(1000 * bbox_pdf[1] / page_height))), + max(0, min(1000, int(1000 * bbox_pdf[2] / page_width))), + max(0, min(1000, int(1000 * bbox_pdf[3] / page_height))) + ] + all_token_data.append({ + "word": item['word'], + "bbox_raw_pdf_space": bbox_pdf, + "bbox_normalized": normalized_bbox, + "item_original_data": item + }) + + if not all_token_data: + continue + + column_separator_x = page_data.get('column_separator_x', None) + if column_separator_x is not None: + print(f" -> Using SAVED column separator: X={column_separator_x}") + else: + print(" -> No column separator found. Assuming single chunk.") + + token_chunks = _merge_integrity(all_token_data, column_separator_x) + total_chunks = len(token_chunks) + + for chunk_idx, chunk_tokens in enumerate(token_chunks): + if not chunk_tokens: continue + + # 1. Sanitize: Convert everything to strings and aggressively clean Unicode errors. + chunk_words = [ + str(t['word']).encode('utf-8', errors='ignore').decode('utf-8') + for t in chunk_tokens + ] + chunk_normalized_bboxes = [t['bbox_normalized'] for t in chunk_tokens] + + total_sub_chunks = (len(chunk_words) + CHUNK_SIZE - 1) // CHUNK_SIZE + for i in range(0, len(chunk_words), CHUNK_SIZE): + sub_chunk_idx = i // CHUNK_SIZE + 1 + sub_words = chunk_words[i:i + CHUNK_SIZE] + sub_bboxes = chunk_normalized_bboxes[i:i + CHUNK_SIZE] + sub_tokens_data = chunk_tokens[i:i + CHUNK_SIZE] + + print( + f" -> Chunk {chunk_idx + 1}/{total_chunks}, Sub-chunk {sub_chunk_idx}/{total_sub_chunks}: {len(sub_words)} words. Running Inference...") + + # 2. Manual generation of word_ids + manual_word_ids = [] + for current_word_idx, word in enumerate(sub_words): + sub_tokens = tokenizer.tokenize(word) + for _ in sub_tokens: + manual_word_ids.append(current_word_idx) + + encoded_input = tokenizer( + sub_words, + boxes=sub_bboxes, + truncation=True, + padding="max_length", + max_length=512, + is_split_into_words=True, + return_tensors="pt" + ) + + # Check for empty sequence + if encoded_input['input_ids'].shape[0] == 0: + print(f" -> Warning: Sub-chunk {sub_chunk_idx} encoded to an empty sequence. Skipping.") + continue + + # 3. Finalize word_ids based on encoded output length + sequence_length = int(torch.sum(encoded_input['attention_mask']).item()) + content_token_length = max(0, sequence_length - 2) + + manual_word_ids = manual_word_ids[:content_token_length] + + final_word_ids = [None] # CLS token (index 0) + final_word_ids.extend(manual_word_ids) + + if sequence_length > 1: + final_word_ids.append(None) # SEP token + + final_word_ids.extend([None] * (512 - len(final_word_ids))) + word_ids = final_word_ids[:512] # Final array for mapping + + # Inputs are already batched by the tokenizer as [1, 512] + input_ids = encoded_input['input_ids'].to(device) + bbox = encoded_input['bbox'].to(device) + attention_mask = encoded_input['attention_mask'].to(device) + + with torch.no_grad(): + model_outputs = model(input_ids, bbox, attention_mask) + + # --- Robust extraction: support several forward return types --- + # We'll try (in order): + # 1) model_outputs is (emissions_tensor, viterbi_list) -> use emissions for logits, keep decoded + # 2) model_outputs has .logits attribute (HF ModelOutput) + # 3) model_outputs is tuple/list containing a logits tensor + # 4) model_outputs is a tensor (assume logits) + # 5) model_outputs is a list-of-lists of ints (viterbi decoded) -> use that directly (no logits) + logits_tensor = None + decoded_labels_list = None + + # case 1: tuple/list with (emissions, viterbi) + if isinstance(model_outputs, (tuple, list)) and len(model_outputs) == 2: + a, b = model_outputs + # a might be tensor (emissions), b might be viterbi list + if isinstance(a, torch.Tensor): + logits_tensor = a + if isinstance(b, list): + decoded_labels_list = b + + # case 2: HF ModelOutput with .logits + if logits_tensor is None and hasattr(model_outputs, 'logits') and isinstance(model_outputs.logits, + torch.Tensor): + logits_tensor = model_outputs.logits + + # case 3: tuple/list - search for a 3D tensor (B, L, C) + if logits_tensor is None and isinstance(model_outputs, (tuple, list)): + found_tensor = None + for item in model_outputs: + if isinstance(item, torch.Tensor): + # prefer 3D (batch, seq, labels) + if item.dim() == 3: + logits_tensor = item + break + if found_tensor is None: + found_tensor = item + if logits_tensor is None and found_tensor is not None: + # found_tensor may be (batch, seq, hidden) or (seq, hidden); we avoid guessing. + # Keep found_tensor only if it matches num_labels dimension + if found_tensor.dim() == 3 and found_tensor.shape[-1] == NUM_LABELS: + logits_tensor = found_tensor + elif found_tensor.dim() == 2 and found_tensor.shape[-1] == NUM_LABELS: + logits_tensor = found_tensor.unsqueeze(0) + + # case 4: model_outputs directly a tensor + if logits_tensor is None and isinstance(model_outputs, torch.Tensor): + logits_tensor = model_outputs + + # case 5: model_outputs is a decoded viterbi list (common for CRF-only forward) + if decoded_labels_list is None and isinstance(model_outputs, list) and model_outputs and isinstance( + model_outputs[0], list): + # assume model_outputs is already viterbi decoded: List[List[int]] with batch dim first + decoded_labels_list = model_outputs + + # If neither logits nor decoded exist, that's fatal + if logits_tensor is None and decoded_labels_list is None: + # helpful debug info + try: + elem_shapes = [(type(x), getattr(x, 'shape', None)) for x in model_outputs] if isinstance( + model_outputs, (list, tuple)) else [ + (type(model_outputs), getattr(model_outputs, 'shape', None))] + except Exception: + elem_shapes = str(type(model_outputs)) + raise RuntimeError( + f"Model output of type {type(model_outputs)} did not contain a valid logits tensor or decoded viterbi. Contents: {elem_shapes}") + + # If we have logits_tensor, normalize shape to [seq_len, num_labels] + if logits_tensor is not None: + # If shape is [B, L, C] with B==1, squeeze batch + if logits_tensor.dim() == 3 and logits_tensor.shape[0] == 1: + preds_tensor = logits_tensor.squeeze(0) # [L, C] + else: + preds_tensor = logits_tensor # possibly [L, C] already + + # Safety: ensure we have at least seq_len x channels + if preds_tensor.dim() != 2: + # try to reshape or error + raise RuntimeError(f"Unexpected logits tensor shape: {tuple(preds_tensor.shape)}") + # We'll use preds_tensor[token_idx] to argmax + else: + preds_tensor = None # no logits available + + # If decoded labels provided, make a token-level list-of-ints aligned to tokenizer tokens + decoded_token_labels = None + if decoded_labels_list is not None: + # decoded_labels_list is batch-first; we used batch size 1 + # if multiple sequences returned, take first + decoded_token_labels = decoded_labels_list[0] if isinstance(decoded_labels_list[0], + list) else decoded_labels_list + + # Now map token-level predictions -> word-level predictions using word_ids + word_idx_to_pred_id = {} + + if preds_tensor is not None: + # We have logits. Use argmax of logits for each token id up to sequence_length + for token_idx, word_idx in enumerate(word_ids): + if token_idx >= sequence_length: + break + if word_idx is not None and word_idx < len(sub_words): + if word_idx not in word_idx_to_pred_id: + pred_id = torch.argmax(preds_tensor[token_idx]).item() + word_idx_to_pred_id[word_idx] = pred_id + else: + # No logits, but we have decoded_token_labels from CRF (one label per token) + # We'll align decoded_token_labels to token positions. + if decoded_token_labels is None: + # should not happen due to earlier checks + raise RuntimeError("No logits and no decoded labels available for mapping.") + # decoded_token_labels length may be equal to content_token_length (no special tokens) + # or equal to sequence_length; try to align intelligently: + # Prefer using decoded_token_labels aligned to the tokenizer tokens (starting at token 1 for CLS) + # If decoded length == content_token_length, then manual_word_ids maps sub-token -> word idx for content tokens only. + # We'll iterate tokens and pick label accordingly. + # Build token_idx -> decoded_label mapping: + # We'll assume decoded_token_labels correspond to content tokens (no CLS/SEP). If decoded length == sequence_length, then shift by 0. + decoded_len = len(decoded_token_labels) + # Heuristic: if decoded_len == content_token_length -> alignment starts at token_idx 1 (skip CLS) + if decoded_len == content_token_length: + decoded_start = 1 + elif decoded_len == sequence_length: + decoded_start = 0 + else: + # fallback: prefer decoded_start=1 (most common) + decoded_start = 1 + + for tok_idx_in_decoded, label_id in enumerate(decoded_token_labels): + tok_idx = decoded_start + tok_idx_in_decoded + if tok_idx >= 512: + break + if tok_idx >= sequence_length: + break + # map this token to a word index if present + word_idx = word_ids[tok_idx] if tok_idx < len(word_ids) else None + if word_idx is not None and word_idx < len(sub_words): + if word_idx not in word_idx_to_pred_id: + # label_id may already be an int + word_idx_to_pred_id[word_idx] = int(label_id) + + # Finally convert mapped word preds -> page_raw_predictions entries + for current_word_idx in range(len(sub_words)): + pred_id = word_idx_to_pred_id.get(current_word_idx, 0) # default to 0 + predicted_label = ID_TO_LABEL[pred_id] + original_token = sub_tokens_data[current_word_idx] + page_raw_predictions.append({ + "word": original_token['word'], + "bbox": original_token['bbox_raw_pdf_space'], + "predicted_label": predicted_label, + "page_number": page_num_1_based + }) + + if page_raw_predictions: + final_page_predictions.append({ + "page_number": page_num_1_based, + "data": page_raw_predictions + }) + print(f" *** Page {page_num_1_based} Finalized: {len(page_raw_predictions)} labeled words. ***") + + doc.close() + print("\n" + "=" * 80) + print("--- LAYOUTLMV3 INFERENCE COMPLETE ---") + print("=" * 80) + return final_page_predictions + + +# def run_inference_and_get_raw_words(pdf_path: str, model_path: str, +# preprocessed_json_path: str, +# column_detection_params: Optional[Dict] = None) -> List[Dict[str, Any]]: +# print("\n" + "=" * 80) +# print("--- 2. STARTING LAYOUTLMV3 INFERENCE PIPELINE (Raw Word Output) ---") +# print("=" * 80) + +# tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base") +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# print(f" -> Using device: {device}") + +# try: +# model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS) +# checkpoint = torch.load(model_path, map_location=device) +# model_state = checkpoint.get('model_state_dict', checkpoint) +# # Apply patch for layoutlmv3 compatibility with saved state_dict +# fixed_state_dict = {key.replace('layoutlm.', 'layoutlmv3.'): value for key, value in model_state.items()} +# model.load_state_dict(fixed_state_dict) +# model.to(device) +# model.eval() +# print(f"✅ LayoutLMv3 Model loaded successfully from {os.path.basename(model_path)}.") +# except Exception as e: +# print(f"❌ FATAL ERROR during LayoutLMv3 model loading: {e}") +# return [] + +# try: +# with open(preprocessed_json_path, 'r', encoding='utf-8') as f: +# preprocessed_data = json.load(f) +# print(f"✅ Loaded preprocessed data with {len(preprocessed_data)} pages.") +# except Exception: +# print("❌ Error loading preprocessed JSON.") +# return [] + +# try: +# doc = fitz.open(pdf_path) +# except Exception: +# print("❌ Error loading PDF.") +# return [] + +# final_page_predictions = [] +# CHUNK_SIZE = 500 + +# for page_data in preprocessed_data: +# page_num_1_based = page_data['page_number'] +# page_num_0_based = page_num_1_based - 1 +# page_raw_predictions = [] +# print(f"\n *** Processing Page {page_num_1_based} ({len(page_data['data'])} raw tokens) ***") + +# fitz_page = doc.load_page(page_num_0_based) +# page_width, page_height = fitz_page.rect.width, fitz_page.rect.height +# print(f" -> Page dimensions: {page_width:.0f}x{page_height:.0f} (PDF points).") + +# all_token_data = [] +# scale_factor = 2.0 + +# for item in page_data['data']: +# raw_yolo_bbox = item['bbox'] +# bbox_pdf = [ +# int(raw_yolo_bbox[0] / scale_factor), int(raw_yolo_bbox[1] / scale_factor), +# int(raw_yolo_bbox[2] / scale_factor), int(raw_yolo_bbox[3] / scale_factor) +# ] +# normalized_bbox = [ +# max(0, min(1000, int(1000 * bbox_pdf[0] / page_width))), +# max(0, min(1000, int(1000 * bbox_pdf[1] / page_height))), +# max(0, min(1000, int(1000 * bbox_pdf[2] / page_width))), +# max(0, min(1000, int(1000 * bbox_pdf[3] / page_height))) +# ] +# all_token_data.append({ +# "word": item['word'], +# "bbox_raw_pdf_space": bbox_pdf, +# "bbox_normalized": normalized_bbox, +# "item_original_data": item +# }) + +# # ============================================================================== +# # --- DEBUGGING BLOCK: CHECK FIRST 50 TOKENS BEFORE INFERENCE --- +# # ============================================================================== +# print(f"\n[DEBUG] LayoutLMv3 Input (Page {page_num_1_based}): Checking first 50 tokens...") +# debug_count = 0 +# for t in all_token_data: +# if debug_count >= 50: break +# w = t['word'] +# unicode_points = [f"\\u{ord(c):04x}" for c in w] +# print(f" Token {debug_count}: '{w}' -> Codes: {unicode_points}") +# debug_count += 1 +# print("----------------------------------------------------------------------\n") +# # ============================================================================== + +# if not all_token_data: +# continue + +# column_separator_x = page_data.get('column_separator_x', None) +# if column_separator_x is not None: +# print(f" -> Using SAVED column separator: X={column_separator_x}") +# else: +# print(" -> No column separator found. Assuming single chunk.") + +# token_chunks = _merge_integrity(all_token_data, column_separator_x) +# total_chunks = len(token_chunks) + +# for chunk_idx, chunk_tokens in enumerate(token_chunks): +# if not chunk_tokens: continue + +# # 1. Sanitize: Convert everything to strings and aggressively clean Unicode errors. +# chunk_words = [ +# str(t['word']).encode('utf-8', errors='ignore').decode('utf-8') +# for t in chunk_tokens +# ] +# chunk_normalized_bboxes = [t['bbox_normalized'] for t in chunk_tokens] + +# total_sub_chunks = (len(chunk_words) + CHUNK_SIZE - 1) // CHUNK_SIZE +# for i in range(0, len(chunk_words), CHUNK_SIZE): +# sub_chunk_idx = i // CHUNK_SIZE + 1 +# sub_words = chunk_words[i:i + CHUNK_SIZE] +# sub_bboxes = chunk_normalized_bboxes[i:i + CHUNK_SIZE] +# sub_tokens_data = chunk_tokens[i:i + CHUNK_SIZE] + +# print(f" -> Chunk {chunk_idx + 1}/{total_chunks}, Sub-chunk {sub_chunk_idx}/{total_sub_chunks}: {len(sub_words)} words. Running Inference...") + +# # 2. Manual generation of word_ids +# manual_word_ids = [] +# for current_word_idx, word in enumerate(sub_words): +# sub_tokens = tokenizer.tokenize(word) +# for _ in sub_tokens: +# manual_word_ids.append(current_word_idx) + +# encoded_input = tokenizer( +# sub_words, +# boxes=sub_bboxes, +# truncation=True, +# padding="max_length", +# max_length=512, +# is_split_into_words=True, +# return_tensors="pt" +# ) + +# # Check for empty sequence +# if encoded_input['input_ids'].shape[0] == 0: +# print(f" -> Warning: Sub-chunk {sub_chunk_idx} encoded to an empty sequence. Skipping.") +# continue + +# # 3. Finalize word_ids based on encoded output length +# sequence_length = int(torch.sum(encoded_input['attention_mask']).item()) +# content_token_length = max(0, sequence_length - 2) + +# manual_word_ids = manual_word_ids[:content_token_length] + +# final_word_ids = [None] # CLS token (index 0) +# final_word_ids.extend(manual_word_ids) + +# if sequence_length > 1: +# final_word_ids.append(None) # SEP token + +# final_word_ids.extend([None] * (512 - len(final_word_ids))) +# word_ids = final_word_ids[:512] # Final array for mapping + +# # Inputs are already batched by the tokenizer as [1, 512] +# input_ids = encoded_input['input_ids'].to(device) +# bbox = encoded_input['bbox'].to(device) +# attention_mask = encoded_input['attention_mask'].to(device) + +# with torch.no_grad(): +# model_outputs = model(input_ids, bbox, attention_mask) + +# # --- Robust extraction: support several forward return types --- +# logits_tensor = None +# decoded_labels_list = None + +# # case 1: tuple/list with (emissions, viterbi) +# if isinstance(model_outputs, (tuple, list)) and len(model_outputs) == 2: +# a, b = model_outputs +# if isinstance(a, torch.Tensor): +# logits_tensor = a +# if isinstance(b, list): +# decoded_labels_list = b + +# # case 2: HF ModelOutput with .logits +# if logits_tensor is None and hasattr(model_outputs, 'logits') and isinstance(model_outputs.logits, torch.Tensor): +# logits_tensor = model_outputs.logits + +# # case 3: tuple/list - search for a 3D tensor (B, L, C) +# if logits_tensor is None and isinstance(model_outputs, (tuple, list)): +# found_tensor = None +# for item in model_outputs: +# if isinstance(item, torch.Tensor): +# if item.dim() == 3: +# logits_tensor = item +# break +# if found_tensor is None: +# found_tensor = item +# if logits_tensor is None and found_tensor is not None: +# if found_tensor.dim() == 3 and found_tensor.shape[-1] == NUM_LABELS: +# logits_tensor = found_tensor +# elif found_tensor.dim() == 2 and found_tensor.shape[-1] == NUM_LABELS: +# logits_tensor = found_tensor.unsqueeze(0) + +# # case 4: model_outputs directly a tensor +# if logits_tensor is None and isinstance(model_outputs, torch.Tensor): +# logits_tensor = model_outputs + +# # case 5: model_outputs is a decoded viterbi list (common for CRF-only forward) +# if decoded_labels_list is None and isinstance(model_outputs, list) and model_outputs and isinstance(model_outputs[0], list): +# decoded_labels_list = model_outputs + +# # If neither logits nor decoded exist, that's fatal +# if logits_tensor is None and decoded_labels_list is None: +# try: +# elem_shapes = [ (type(x), getattr(x, 'shape', None)) for x in model_outputs ] if isinstance(model_outputs, (list, tuple)) else [(type(model_outputs), getattr(model_outputs, 'shape', None))] +# except Exception: +# elem_shapes = str(type(model_outputs)) +# raise RuntimeError(f"Model output of type {type(model_outputs)} did not contain a valid logits tensor or decoded viterbi. Contents: {elem_shapes}") + +# # If we have logits_tensor, normalize shape to [seq_len, num_labels] +# if logits_tensor is not None: +# if logits_tensor.dim() == 3 and logits_tensor.shape[0] == 1: +# preds_tensor = logits_tensor.squeeze(0) # [L, C] +# else: +# preds_tensor = logits_tensor # possibly [L, C] already + +# if preds_tensor.dim() != 2: +# raise RuntimeError(f"Unexpected logits tensor shape: {tuple(preds_tensor.shape)}") +# else: +# preds_tensor = None # no logits available + +# # If decoded labels provided, make a token-level list-of-ints aligned to tokenizer tokens +# decoded_token_labels = None +# if decoded_labels_list is not None: +# decoded_token_labels = decoded_labels_list[0] if isinstance(decoded_labels_list[0], list) else decoded_labels_list + +# # Now map token-level predictions -> word-level predictions using word_ids +# word_idx_to_pred_id = {} + +# if preds_tensor is not None: +# for token_idx, word_idx in enumerate(word_ids): +# if token_idx >= sequence_length: +# break +# if word_idx is not None and word_idx < len(sub_words): +# if word_idx not in word_idx_to_pred_id: +# pred_id = torch.argmax(preds_tensor[token_idx]).item() +# word_idx_to_pred_id[word_idx] = pred_id +# else: +# if decoded_token_labels is None: +# raise RuntimeError("No logits and no decoded labels available for mapping.") +# decoded_len = len(decoded_token_labels) +# if decoded_len == content_token_length: +# decoded_start = 1 +# elif decoded_len == sequence_length: +# decoded_start = 0 +# else: +# decoded_start = 1 + +# for tok_idx_in_decoded, label_id in enumerate(decoded_token_labels): +# tok_idx = decoded_start + tok_idx_in_decoded +# if tok_idx >= 512: +# break +# if tok_idx >= sequence_length: +# break +# word_idx = word_ids[tok_idx] if tok_idx < len(word_ids) else None +# if word_idx is not None and word_idx < len(sub_words): +# if word_idx not in word_idx_to_pred_id: +# word_idx_to_pred_id[word_idx] = int(label_id) + +# # Finally convert mapped word preds -> page_raw_predictions entries +# for current_word_idx in range(len(sub_words)): +# pred_id = word_idx_to_pred_id.get(current_word_idx, 0) # default to 0 +# predicted_label = ID_TO_LABEL[pred_id] +# original_token = sub_tokens_data[current_word_idx] +# page_raw_predictions.append({ +# "word": original_token['word'], +# "bbox": original_token['bbox_raw_pdf_space'], +# "predicted_label": predicted_label, +# "page_number": page_num_1_based +# }) + +# if page_raw_predictions: +# final_page_predictions.append({ +# "page_number": page_num_1_based, +# "data": page_raw_predictions +# }) +# print(f" *** Page {page_num_1_based} Finalized: {len(page_raw_predictions)} labeled words. ***") + +# doc.close() +# print("\n" + "=" * 80) +# print("--- LAYOUTLMV3 INFERENCE COMPLETE ---") +# print("=" * 80) +# return final_page_predictions + + +# ============================================================================ +# --- PHASE 3: BIO TO STRUCTURED JSON DECODER --- +# ============================================================================ + + +def convert_bio_to_structured_json_relaxed(input_path: str, output_path: str) -> Optional[List[Dict[str, Any]]]: + print("\n" + "=" * 80) + print("--- 3. STARTING BIO TO STRUCTURED JSON DECODING ---") + print("=" * 80) + try: + with open(input_path, 'r', encoding='utf-8') as f: + predictions_by_page = json.load(f) + except Exception as e: + print(f"❌ Error loading raw prediction file: {e}") + return None + + predictions = [] + for page_item in predictions_by_page: + if isinstance(page_item, dict) and 'data' in page_item: + predictions.extend(page_item['data']) + + structured_data = [] + current_item = None + current_option_key = None + current_passage_buffer = [] + current_text_buffer = [] + first_question_started = False + last_entity_type = None + just_finished_i_option = False + is_in_new_passage = False + + def finalize_passage_to_item(item, passage_buffer): + if passage_buffer: + passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip() + if item.get('passage'): + item['passage'] += ' ' + passage_text + else: + item['passage'] = passage_text + passage_buffer.clear() + + for item in predictions: + word = item['word'] + label = item['predicted_label'] + entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None + current_text_buffer.append(word) + previous_entity_type = last_entity_type + is_passage_label = (entity_type == 'PASSAGE') + + if not first_question_started: + if label != 'B-QUESTION' and not is_passage_label: + just_finished_i_option = False + is_in_new_passage = False + continue + if is_passage_label: + current_passage_buffer.append(word) + last_entity_type = 'PASSAGE' + just_finished_i_option = False + is_in_new_passage = False + continue + + if label == 'B-QUESTION': + if not first_question_started: + header_text = ' '.join(current_text_buffer[:-1]).strip() + if header_text or current_passage_buffer: + metadata_item = {'type': 'METADATA', 'passage': ''} + finalize_passage_to_item(metadata_item, current_passage_buffer) + if header_text: metadata_item['text'] = header_text + structured_data.append(metadata_item) + first_question_started = True + current_text_buffer = [word] + + if current_item is not None: + finalize_passage_to_item(current_item, current_passage_buffer) + current_item['text'] = ' '.join(current_text_buffer[:-1]).strip() + structured_data.append(current_item) + current_text_buffer = [word] + + current_item = { + 'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': '' + } + current_option_key = None + last_entity_type = 'QUESTION' + just_finished_i_option = False + is_in_new_passage = False + continue + + if current_item is not None: + if is_in_new_passage: + # 🔑 Robust Initialization and Appending for 'new_passage' + if 'new_passage' not in current_item: + current_item['new_passage'] = word + else: + current_item['new_passage'] += f' {word}' + + if label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE'): + is_in_new_passage = False + if label.startswith(('B-', 'I-')): last_entity_type = entity_type + continue + is_in_new_passage = False + + if label.startswith('B-'): + if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']: + finalize_passage_to_item(current_item, current_passage_buffer) + current_passage_buffer = [] + last_entity_type = entity_type + if entity_type == 'PASSAGE': + if previous_entity_type == 'OPTION' and just_finished_i_option: + current_item['new_passage'] = word # Initialize the new passage start + is_in_new_passage = True + else: + current_passage_buffer.append(word) + elif entity_type == 'OPTION': + current_option_key = word + current_item['options'][current_option_key] = word + just_finished_i_option = False + elif entity_type == 'ANSWER': + current_item['answer'] = word + current_option_key = None + just_finished_i_option = False + elif entity_type == 'QUESTION': + current_item['question'] += f' {word}' + just_finished_i_option = False + + elif label.startswith('I-'): + if entity_type == 'QUESTION': + current_item['question'] += f' {word}' + elif entity_type == 'PASSAGE': + if previous_entity_type == 'OPTION' and just_finished_i_option: + current_item['new_passage'] = word # Initialize the new passage start + is_in_new_passage = True + else: + if not current_passage_buffer: last_entity_type = 'PASSAGE' + current_passage_buffer.append(word) + elif entity_type == 'OPTION' and current_option_key is not None: + current_item['options'][current_option_key] += f' {word}' + just_finished_i_option = True + elif entity_type == 'ANSWER': + current_item['answer'] += f' {word}' + just_finished_i_option = (entity_type == 'OPTION') + + elif label == 'O': + if last_entity_type == 'QUESTION': + current_item['question'] += f' {word}' + just_finished_i_option = False + + if current_item is not None: + finalize_passage_to_item(current_item, current_passage_buffer) + current_item['text'] = ' '.join(current_text_buffer).strip() + structured_data.append(current_item) + + for item in structured_data: + item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip() + if 'new_passage' in item: + item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip() + + try: + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(structured_data, f, indent=2, ensure_ascii=False) + except Exception: + pass + + return structured_data + + +def create_query_text(entry: Dict[str, Any]) -> str: + """Combines question and options into a single string for similarity matching.""" + query_parts = [] + if entry.get("question"): + query_parts.append(entry["question"]) + + for key in ["options", "options_text"]: + options = entry.get(key) + if options and isinstance(options, dict): + for value in options.values(): + if value and isinstance(value, str): + query_parts.append(value) + return " ".join(query_parts) + + +def calculate_similarity(doc1: str, doc2: str) -> float: + """Calculates Cosine Similarity between two text strings.""" + if not doc1 or not doc2: + return 0.0 + + def clean_text(text): + return re.sub(r'^\s*[\(\d\w]+\.?\s*', '', text, flags=re.MULTILINE) + + clean_doc1 = clean_text(doc1) + clean_doc2 = clean_text(doc2) + corpus = [clean_doc1, clean_doc2] + + try: + vectorizer = CountVectorizer(stop_words='english', lowercase=True, token_pattern=r'(?u)\b\w\w+\b') + tfidf_matrix = vectorizer.fit_transform(corpus) + if tfidf_matrix.shape[1] == 0: + return 0.0 + vectors = tfidf_matrix.toarray() + # Handle cases where vectors might be empty or too short + if len(vectors) < 2: + return 0.0 + score = cosine_similarity(vectors[0:1], vectors[1:2])[0][0] + return score + except Exception: + return 0.0 + + +def process_context_linking(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Links questions to passages based on 'passage' flow vs 'new_passage' priority. + Includes 'Decay Logic': If 2 consecutive questions fail to match the active passage, + the passage context is dropped to prevent false positives downstream. + """ + print("\n" + "=" * 80) + print("--- STARTING CONTEXT LINKING (WITH DECAY LOGIC) ---") + print("=" * 80) + + if not data: return [] + + # --- PHASE 1: IDENTIFY PASSAGE DEFINERS --- + passage_definer_indices = [] + for i, entry in enumerate(data): + if entry.get("passage") and entry["passage"].strip(): + passage_definer_indices.append(i) + if entry.get("new_passage") and entry["new_passage"].strip(): + if i not in passage_definer_indices: + passage_definer_indices.append(i) + + # --- PHASE 2: CONTEXT TRANSFER & LINKING --- + current_passage_text = None + current_new_passage_text = None + + # NEW: Counter to track consecutive linking failures + consecutive_failures = 0 + MAX_CONSECUTIVE_FAILURES = 2 + + for i, entry in enumerate(data): + item_type = entry.get("type", "Question") + + # A. UNCONDITIONALLY UPDATE CONTEXTS (And Reset Decay Counter) + if entry.get("passage") and entry["passage"].strip(): + current_passage_text = entry["passage"] + consecutive_failures = 0 # Reset because we have fresh explicit context + # print(f" [Flow] Updated Standard Context from Item {i}") + + if entry.get("new_passage") and entry["new_passage"].strip(): + current_new_passage_text = entry["new_passage"] + # We don't necessarily reset standard failures here as this is a local override + + # B. QUESTION LINKING + if entry.get("question") and item_type != "METADATA": + combined_query = create_query_text(entry) + + # Skip if query is too short (noise) + if len(combined_query.strip()) < 5: + continue + + # Calculate scores + score_old = calculate_similarity(current_passage_text, combined_query) if current_passage_text else 0.0 + score_new = calculate_similarity(current_new_passage_text, + combined_query) if current_new_passage_text else 0.0 + + # ------------------------------------------------------------------ + # 🛑 CRITICAL FIX APPLIED HERE 🛑 + # The original line: q_preview = entry['question'][:30] + '...' + + # 1. Capture the raw preview string (which might contain the bad surrogate) + q_preview_raw = entry['question'][:30] + '...' + + # 2. Safely clean the string by encoding to UTF-8 and ignoring errors, + # then decoding back. This removes the invalid surrogate character. + q_preview = q_preview_raw.encode('utf-8', errors='ignore').decode('utf-8') + # ------------------------------------------------------------------ + + # RESOLUTION LOGIC + linked = False + + # 1. Prefer New Passage if significantly better + if current_new_passage_text and (score_new > score_old + RESOLUTION_MARGIN) and ( + score_new >= SIMILARITY_THRESHOLD): + entry["passage"] = current_new_passage_text + print(f" [Linker] 🚀 Q{i} ('{q_preview}') -> NEW PASSAGE (Score: {score_new:.3f})") + linked = True + # Note: We do not reset 'consecutive_failures' for the standard passage here, + # because we matched the *new* passage, not the standard one. + + # 2. Otherwise use Standard Passage if it meets threshold + elif current_passage_text and (score_old >= SIMILARITY_THRESHOLD): + entry["passage"] = current_passage_text + print(f" [Linker] ✅ Q{i} ('{q_preview}') -> STANDARD PASSAGE (Score: {score_old:.3f})") + linked = True + consecutive_failures = 0 # Success! Reset the kill switch. + + if not linked: + # 3. DECAY LOGIC + if current_passage_text: + consecutive_failures += 1 + # This is the line that was failing (or similar logging lines) + print( + f" [Linker] ⚠️ Q{i} NOT LINKED. (Failures: {consecutive_failures}/{MAX_CONSECUTIVE_FAILURES})") + + if consecutive_failures >= MAX_CONSECUTIVE_FAILURES: + print(f" [Linker] 🗑️ Context dropped due to {consecutive_failures} consecutive misses.") + current_passage_text = None + consecutive_failures = 0 + else: + print(f" [Linker] ⚠️ Q{i} NOT LINKED (No active context).") + + # --- PHASE 3: CLEANUP AND INTERPOLATION --- + print(" [Linker] Running Cleanup & Interpolation...") + + # 3A. Self-Correction (Remove weak links) + for i in passage_definer_indices: + entry = data[i] + if entry.get("question") and entry.get("type") != "METADATA": + passage_to_check = entry.get("passage") or entry.get("new_passage") + if passage_to_check: + self_sim = calculate_similarity(passage_to_check, create_query_text(entry)) + if self_sim < SIMILARITY_THRESHOLD: + entry["passage"] = "" + if "new_passage" in entry: entry["new_passage"] = "" + print(f" [Cleanup] Removed weak link for Q{i}") + + # 3B. Interpolation (Fill gaps) + # We only interpolate if the gap is strictly 1 question wide to avoid undoing the decay logic + for i in range(1, len(data) - 1): + current_entry = data[i] + is_gap = current_entry.get("question") and not current_entry.get("passage") + if is_gap: + prev_p = data[i - 1].get("passage") + next_p = data[i + 1].get("passage") + if prev_p and next_p and (prev_p == next_p) and prev_p.strip(): + current_entry["passage"] = prev_p + print(f" [Linker] 🥪 Q{i} Interpolated from neighbors.") + + return data + + +def correct_misaligned_options(structured_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + print("\n" + "=" * 80) + print("--- 5. STARTING POST-PROCESSING: OPTION ALIGNMENT CORRECTION ---") + print("=" * 80) + tag_pattern = re.compile(r'(EQUATION\d+|FIGURE\d+)') + corrected_count = 0 + for item in structured_data: + if item.get('type') in ['METADATA']: continue + options = item.get('options') + if not options or len(options) < 2: continue + option_keys = list(options.keys()) + for i in range(len(option_keys) - 1): + current_key = option_keys[i] + next_key = option_keys[i + 1] + current_value = options[current_key].strip() + next_value = options[next_key].strip() + is_current_empty = current_value == current_key + content_in_next = next_value.replace(next_key, '', 1).strip() + tags_in_next = tag_pattern.findall(content_in_next) + has_two_tags = len(tags_in_next) == 2 + if is_current_empty and has_two_tags: + tag_to_move = tags_in_next[0] + options[current_key] = f"{current_key} {tag_to_move}".strip() + options[next_key] = f"{next_key} {tags_in_next[1]}".strip() + corrected_count += 1 + print(f"✅ Option alignment correction finished. Total corrections: {corrected_count}.") + return structured_data + + +def get_base64_for_file(filepath: str) -> Optional[str]: + """Reads a file and returns its Base64 encoded string without the data URI prefix.""" + try: + with open(filepath, "rb") as image_file: + # Return raw base64 string + return base64.b64encode(image_file.read()).decode('utf-8') + except Exception as e: + print(f"Error reading and encoding file {filepath}: {e}") + return None + + +def embed_images_as_base64_in_memory(structured_data: List[Dict[str, Any]], figure_extraction_dir: str) -> List[ + Dict[str, Any]]: + print("\n" + "=" * 80) + print("--- 4. STARTING IMAGE EMBEDDING (Base64) / EQUATION TO LATEX CONVERSION ---") + print("=" * 80) + if not structured_data: + return [] + + image_files = glob.glob(os.path.join(figure_extraction_dir, "*.png")) + image_lookup = {} + tag_regex = re.compile(r'(figure|equation)(\d+)', re.IGNORECASE) + + for filepath in image_files: + filename = os.path.basename(filepath) + match = re.search(r'_(figure|equation)(\d+)\.png$', filename, re.IGNORECASE) + if match: + key = f"{match.group(1).upper()}{match.group(2)}" + image_lookup[key] = filepath + + print(f" -> Found {len(image_lookup)} image components.") + + final_structured_data = [] + + for item in structured_data: + text_fields = [item.get('question', ''), item.get('passage', '')] + if 'options' in item: + for opt_val in item['options'].values(): + text_fields.append(opt_val) + if 'new_passage' in item: + text_fields.append(item['new_passage']) + + unique_tags_to_embed = set() + for text in text_fields: + if not text: continue + for match in tag_regex.finditer(text): + tag = match.group(0).upper() + if tag in image_lookup: + unique_tags_to_embed.add(tag) + + # List of tags that were successfully converted to LaTeX + tags_converted_to_latex = set() + + for tag in sorted(list(unique_tags_to_embed)): + filepath = image_lookup[tag] + base_key = tag.replace(' ', '').lower() # e.g., figure1 or equation1 + + if 'EQUATION' in tag: + # Equation to LaTeX conversion + base64_code = get_base64_for_file(filepath) # This reads the file for conversion + if base64_code: + latex_output = get_latex_from_base64(base64_code) + if not latex_output.startswith('[P2T_ERROR') and not latex_output.startswith('[P2T_WARNING'): + # *** CORE CHANGE: Store the clean LaTeX output directly *** + item[base_key] = latex_output + tags_converted_to_latex.add(tag) + print(f" ✅ Embedded Clean LaTeX for {tag}") + else: + # On failure, embed the error message + item[base_key] = latex_output + print(f" ⚠️ Failed to convert {tag} to LaTeX. Embedding error message.") + else: + item[base_key] = "[FILE_ERROR: Could not read image file]" + print(f" ❌ File read error for {tag}.") + + elif 'FIGURE' in tag: + # Figure to Base64 conversion + base64_code = get_base64_for_file(filepath) + item[base_key] = base64_code + print(f" ✅ Embedded Base64 for {tag}") + + final_structured_data.append(item) + + print(f"✅ Image embedding complete.") + return final_structured_data + + +# ============================================================================ +# --- MAIN FUNCTION --- +# ============================================================================ + + +# def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str) -> Optional[ +# List[Dict[str, Any]]]: +def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, + structured_intermediate_output_path: Optional[str] = None) -> Optional[List[Dict[str, Any]]]: + if not os.path.exists(input_pdf_path): return None + + print("\n" + "#" * 80) + print("### STARTING OPTIMIZED FULL DOCUMENT ANALYSIS PIPELINE ###") + print("#" * 80) + + pdf_name = os.path.splitext(os.path.basename(input_pdf_path))[0] + temp_pipeline_dir = os.path.join(tempfile.gettempdir(), f"pipeline_run_{pdf_name}_{os.getpid()}") + os.makedirs(temp_pipeline_dir, exist_ok=True) + + preprocessed_json_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_preprocessed.json") + raw_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_raw_predictions.json") + structured_intermediate_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_structured_intermediate.json") + + final_result = None + try: + # Phase 1: Preprocessing with YOLO First + Masking + preprocessed_json_path_out = run_single_pdf_preprocessing(input_pdf_path, preprocessed_json_path) + if not preprocessed_json_path_out: return None + + # Phase 2: Inference + page_raw_predictions_list = run_inference_and_get_raw_words( + input_pdf_path, layoutlmv3_model_path, preprocessed_json_path_out + ) + if not page_raw_predictions_list: return None + + # --- DEBUG STEP: SAVE RAW PREDICTIONS --- + # Save raw predictions to the temporary file + with open(raw_output_path, 'w', encoding='utf-8') as f: + json.dump(page_raw_predictions_list, f, indent=4) + + # Explicitly copy/save the raw predictions to the user-specified debug path + # if raw_predictions_output_path: + # shutil.copy(raw_output_path, raw_predictions_output_path) + # print(f"\n✅ DEBUG: Raw predictions saved to: {raw_predictions_output_path}") + # ---------------------------------------- + + # Phase 3: Decoding + structured_data_list = convert_bio_to_structured_json_relaxed( + raw_output_path, structured_intermediate_output_path + ) + if not structured_data_list: return None + structured_data_list = correct_misaligned_options(structured_data_list) + structured_data_list = process_context_linking(structured_data_list) + + # Phase 4: Embedding / Equation to LaTeX Conversion + final_result = embed_images_as_base64_in_memory(structured_data_list, FIGURE_EXTRACTION_DIR) + + # ================================================================================ + # --- NEW FINAL STEP: HIERARCHICAL CLASSIFICATION TAGGING --- + # ================================================================================ + + print("\n" + "=" * 80) + print("--- FINAL STEP: HIERARCHICAL SUBJECT/CONCEPT TAGGING ---") + print("=" * 80) + + # 1. Initialize and Load the Classifier + classifier = HierarchicalClassifier() + if classifier.load_models(): + # 2. Run Classification on the *Final* Result + # The function modifies the list in place and returns it + final_result = post_process_json_with_inference( + final_result, classifier + ) + print("✅ Classification complete. Tags added to final output.") + else: + print("❌ Classification model loading failed. Outputting un-tagged data.") + + # ==================================================================== + + + except Exception as e: + print(f"❌ FATAL ERROR: {e}") + import traceback + traceback.print_exc() + return None + + finally: + try: + for f in glob.glob(os.path.join(temp_pipeline_dir, '*')): + os.remove(f) + os.rmdir(temp_pipeline_dir) + except Exception: + pass + + print("\n" + "#" * 80) + print("### OPTIMIZED PIPELINE EXECUTION COMPLETE ###") + print("#" * 80) + return final_result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Complete Pipeline") + parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF") + parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path") + + # --- ADDED ARGUMENT FOR DEBUGGING --- + parser.add_argument("--raw_preds_path", type=str, default='BIO_debug.json', + help="Debug path for raw BIO tag predictions (JSON).") + # ------------------------------------ + args = parser.parse_args() + + pdf_name = os.path.splitext(os.path.basename(args.input_pdf))[0] + final_output_path = os.path.abspath(f"{pdf_name}_final_output_embedded.json") + + # --- CALCULATE RAW PREDICTIONS OUTPUT PATH (Kept commented as per original script) --- + # raw_predictions_output_path = os.path.abspath( + # args.raw_preds_path if args.raw_preds_path else f"{pdf_name}_raw_predictions_debug.json") + # --------------------------------------------- + + # --- UPDATED FUNCTION CALL --- + final_json_data = run_document_pipeline( + args.input_pdf, + args.layoutlmv3_model_path) + # ----------------------------- + + # 🛑 CRITICAL FINAL FIX: AGGRESSIVE CUSTOM JSON SAVING 🛑 + if final_json_data: + # 1. Dump the Python object to a standard JSON string. + # This converts the in-memory double backslash ('\\') into a quadruple backslash ('\\\\') + # in the raw json_str string content. + json_str = json.dumps(final_json_data, indent=2, ensure_ascii=False) + + # 2. **AGGRESSIVE UNDO ESCAPING:** We assume we have quadruple backslashes and + # replace them with the double backslashes needed for the LaTeX command to work. + # This operation essentially replaces four literal backslashes with two literal backslashes. + # final_output_content = json_str.replace('\\\\\\\\', '\\\\') + + # 3. Write the corrected string content to the file. + with open(final_output_path, 'w', encoding='utf-8') as f: + f.write(json_str) + + print(f"\n✅ Final Data Saved: {final_output_path}") + else: + print("\n❌ Pipeline Failed.") + sys.exit(1) \ No newline at end of file