import fitz  # PyMuPDF
import numpy as np
import cv2
import torch
import torch  # NOTE(review): duplicate of the line above; harmless
import torch.serialization

# SECURITY: the override below replaces torch.load for the WHOLE process and
# unconditionally forces weights_only=False, even when a caller (including a
# third-party library) explicitly asked for weights_only=True. With
# weights_only=False torch.load unpickles arbitrary objects, so only ever load
# checkpoint files (e.g. WEIGHTS_PATH / DEFAULT_LAYOUTLMV3_MODEL_PATH) from
# trusted sources.
_original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` forced to False.

    All positional and keyword arguments are forwarded unchanged, except that
    any ``weights_only`` value supplied by the caller is overwritten.
    """
    # FORCE classic (full-unpickling) behavior
    kwargs["weights_only"] = False
    return _original_torch_load(*args, **kwargs)


torch.load = patched_torch_load

#==================================================================================
#RAPID OCR
#==================================================================================
from rapidocr import RapidOCR, OCRVersion

# Module-level RapidOCR engine: PP-OCRv5 for detection/recognition and
# PP-OCRv4 for the angle classifier.
# Word-level boxes are NOT configured here; callers request them per call by
# passing return_word_box=True to ocr_engine(...).
ocr_engine = RapidOCR(params={
    "Det.ocr_version": OCRVersion.PPOCRV5,
    "Rec.ocr_version": OCRVersion.PPOCRV5,
    "Cls.ocr_version": OCRVersion.PPOCRV4,
})
#==================================================================================
#RAPID OCR
#==================================================================================
import json
import argparse
import os
import re
import torch.nn as nn
from TorchCRF import CRF
# from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model, LayoutLMv3Config
from transformers import LayoutLMv3Tokenizer, LayoutLMv3Model, LayoutLMv3Config
from typing import List, Dict, Any, Optional, Union, Tuple
from ultralytics import YOLO
import glob
import pytesseract
from PIL import Image
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
import sys
import io
import base64
import tempfile
import time
import shutil
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import logging
from transformers import TrOCRProcessor
from optimum.onnxruntime import ORTModelForVision2Seq

# ============================================================================
# --- TR-OCR/ORT MODEL INITIALIZATION ---
# ============================================================================
logging.basicConfig(level=logging.WARNING)

# Both stay None when initialization fails; get_latex_from_base64 checks this.
processor = None
ort_model = None

try:
    MODEL_NAME = 'breezedeus/pix2text-mfr-1.5'
    processor = TrOCRProcessor.from_pretrained(MODEL_NAME)

    # Initialize the model for ONNX Runtime
    # NOTE: Set use_cache=False to avoid caching warnings/issues if reloading
    ort_model = ORTModelForVision2Seq.from_pretrained(MODEL_NAME, use_cache=False)
    print("✅ ORTModelForVision2Seq and TrOCRProcessor initialized successfully for equation conversion.")
except Exception as e:
    # Best-effort: the pipeline keeps running, equations are just not converted.
    print(f"❌ Error initializing TrOCR/ORT model. Equations will not be converted: {e}")
    processor = None
    ort_model = None

from typing import Optional

# Surrogate code points (U+D800-U+DFFF) and the non-characters U+FFFE/U+FFFF
# cannot be encoded as UTF-8. Compiled once instead of on every call.
_SURROGATES_AND_NONCHARS = re.compile(r'[\ud800-\udfff\ufffe\uffff]')


def sanitize_text(text: Optional[str]) -> str:
    """Replace code points that break UTF-8 encoding with a single space.

    Args:
        text: Input string. Any non-``str`` value (including None) is accepted.

    Returns:
        ``""`` for non-string input; otherwise ``text`` with every surrogate
        (U+D800-U+DFFF) and U+FFFE/U+FFFF replaced by ``' '``. No stripping is
        done here; callers strip if needed.
    """
    if not isinstance(text, str):
        return ""
    return _SURROGATES_AND_NONCHARS.sub(' ', text)


def get_latex_from_base64(base64_string: str) -> str:
    """
    Recognize a formula image with the module-level TrOCR processor / ORT model.

    The Base64 payload is decoded, converted to RGB and passed through
    ``processor`` and ``ort_model.generate``. The first decoded sequence is
    returned with carriage returns and line feeds removed. Spaces and
    backslashes in the model output are left untouched.

    Args:
        base64_string: Base64-encoded image bytes (any format PIL can open).

    Returns:
        The recognized LaTeX string, or a bracketed marker string:
        ``"[MODEL_ERROR: ...]"`` if the model failed to initialize,
        ``"[OCR_WARNING: ...]"`` if decoding produced no sequence, or
        ``"[TR_OCR_ERROR: ...]"`` if any step raised.
    """
    if ort_model is None or processor is None:
        return "[MODEL_ERROR: Model not initialized]"

    try:
        # 1. Decode Base64 to an image; the model expects 3-channel RGB input.
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data)).convert('RGB')

        # 2. Preprocess the image into model input tensors.
        pixel_values = processor(images=image, return_tensors="pt").pixel_values

        # 3. Text generation (OCR).
        generated_ids = ort_model.generate(pixel_values)
        raw_generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        if not raw_generated_text:
            return "[OCR_WARNING: No formula found]"

        # 4. Only line breaks are removed. Spaces and backslashes are kept
        #    as produced, because stripping them can corrupt LaTeX commands.
        return re.sub(r'[\r\n]+', '', raw_generated_text[0])
    except Exception as e:
        # Best-effort boundary: callers receive a marker string, never an exception.
        print(f" ❌ TR-OCR Recognition failed: {e}")
        return f"[TR_OCR_ERROR: Recognition failed: {e}]"


# ============================================================================
# --- CONFIGURATION AND CONSTANTS ---
# ============================================================================
# NOTE: Update these paths to match your environment before running!
WEIGHTS_PATH = 'best.pt'
DEFAULT_LAYOUTLMV3_MODEL_PATH = "98.pth"

# DIRECTORY CONFIGURATION
OCR_JSON_OUTPUT_DIR = './ocr_json_output_final'
FIGURE_EXTRACTION_DIR = './figure_extraction'
TEMP_IMAGE_DIR = './temp_pdf_images'

# Detection parameters
# CONF_THRESHOLD = 0.2
TARGET_CLASSES = ['figure', 'equation']
IOU_MERGE_THRESHOLD = 0.4
IOA_SUPPRESSION_THRESHOLD = 0.7
LINE_TOLERANCE = 15

# Similarity
SIMILARITY_THRESHOLD = 0.10
RESOLUTION_MARGIN = 0.05

# Global counters for sequential numbering across the entire PDF
GLOBAL_FIGURE_COUNT = 0
GLOBAL_EQUATION_COUNT = 0

# LayoutLMv3 Labels (BIO tagging scheme)
ID_TO_LABEL = {
    0: "O",
    1: "B-QUESTION", 2: "I-QUESTION",
    3: "B-OPTION", 4: "I-OPTION",
    5: "B-ANSWER", 6: "I-ANSWER",
    7: "B-SECTION_HEADING", 8: "I-SECTION_HEADING",
    9: "B-PASSAGE", 10: "I-PASSAGE",
}
NUM_LABELS = len(ID_TO_LABEL)


# ============================================================================
# --- PERFORMANCE OPTIMIZATION: OCR CACHE ---
# ============================================================================
class OCRCache:
    """In-memory cache of per-page OCR word data, keyed by ``"<pdf_path>:<page_num>"``.

    Entries are never evicted automatically; call ``clear()`` between documents
    if memory matters.
    """

    def __init__(self):
        self.cache = {}

    def get_key(self, pdf_path: str, page_num: int) -> str:
        """Build the cache key for one page of one PDF."""
        return f"{pdf_path}:{page_num}"

    def has_ocr(self, pdf_path: str, page_num: int) -> bool:
        """Return True if OCR data is stored for this page."""
        return self.get_key(pdf_path, page_num) in self.cache

    def get_ocr(self, pdf_path: str, page_num: int) -> Optional[list]:
        """Return the stored OCR data for this page, or None if absent."""
        return self.cache.get(self.get_key(pdf_path, page_num))

    def set_ocr(self, pdf_path: str, page_num: int, ocr_data: list):
        """Store (or overwrite) the OCR data for this page."""
        self.cache[self.get_key(pdf_path, page_num)] = ocr_data

    def clear(self):
        """Drop every cached page."""
        self.cache.clear()


# Global OCR cache instance
_ocr_cache = OCRCache()


# ============================================================================
# --- PHASE 1: YOLO/OCR PREPROCESSING FUNCTIONS ---
# ============================================================================
def _box_area(box):
    """Return the area of an ``(x1, y1, x2, y2)`` box (no clamping of inverted boxes)."""
    x1, y1, x2, y2 = box
    return (x2 - x1) * (y2 - y1)


def _intersection_area(box1, box2):
    """Return the overlap area of two ``(x1, y1, x2, y2)`` boxes, 0 when they do not overlap."""
    x1_a, y1_a, x2_a, y2_a = box1
    x1_b, y1_b, x2_b, y2_b = box2
    overlap_w = min(x2_a, x2_b) - max(x1_a, x1_b)
    overlap_h = min(y2_a, y2_b) - max(y1_a, y1_b)
    return max(0, overlap_w) * max(0, overlap_h)


def calculate_iou(box1, box2):
    """Return intersection-over-union of two ``(x1, y1, x2, y2)`` boxes (0 if the union is empty)."""
    intersection_area = _intersection_area(box1, box2)
    union_area = float(_box_area(box1) + _box_area(box2) - intersection_area)
    return intersection_area / union_area if union_area > 0 else 0


def calculate_ioa(box1, box2):
    """Return the fraction of ``box1``'s area covered by ``box2`` (0 if ``box1`` has no area)."""
    box_a_area = _box_area(box1)
    return _intersection_area(box1, box2) / box_a_area if box_a_area > 0 else 0


def filter_nested_boxes(detections, ioa_threshold=0.80):
    """
    Remove boxes that sit inside larger boxes (containment check).

    Detections are visited from largest to smallest area. Each kept box
    suppresses every smaller box whose area is covered by more than
    ``ioa_threshold`` (intersection / smaller box area), regardless of class.

    Args:
        detections: List of dicts with a ``'coords'`` ``(x1, y1, x2, y2)`` entry.
            An ``'area'`` key is added to each dict. The list itself is not reordered.
        ioa_threshold: Containment ratio above which the smaller box is dropped.

    Returns:
        The surviving detection dicts, ordered by descending area.
    """
    if not detections:
        return []

    for det in detections:
        det['area'] = _box_area(det['coords'])

    # Largest first, so every 'container' is processed before its contents.
    # sorted() leaves the caller's list order untouched.
    by_area = sorted(detections, key=lambda d: d['area'], reverse=True)

    kept = []
    is_suppressed = [False] * len(by_area)
    for i, container in enumerate(by_area):
        if is_suppressed[i]:
            continue
        kept.append(container)

        for j in range(i + 1, len(by_area)):
            if is_suppressed[j]:
                continue
            smaller = by_area[j]
            if smaller['area'] > 0:
                ioa_small = _intersection_area(container['coords'], smaller['coords']) / smaller['area']
                if ioa_small > ioa_threshold:
                    is_suppressed[j] = True

    return kept


def merge_overlapping_boxes(detections, iou_threshold):
    """
    Greedily merge same-class detections whose IoU exceeds ``iou_threshold``.

    Detections are visited in descending confidence. Each unmerged detection
    seeds a group. Every later unmerged detection of the same class whose IoU
    with the seed's ORIGINAL box exceeds the threshold is absorbed. The group's
    output box is the union of all members.

    Args:
        detections: List of dicts with ``'coords'``, ``'class'`` and ``'conf'``.
            The caller's list is not reordered.
        iou_threshold: Minimum IoU (exclusive) for two boxes to be merged.

    Returns:
        New dicts ``{'coords', 'y1', 'class', 'conf'}``, one per group, carrying
        the seed's class and confidence.
    """
    if not detections:
        return []

    by_conf = sorted(detections, key=lambda d: d['conf'], reverse=True)

    merged_detections = []
    is_merged = [False] * len(by_conf)
    for i, seed in enumerate(by_conf):
        if is_merged[i]:
            continue

        seed_box = seed['coords']
        seed_class = seed['class']
        merged_x1, merged_y1, merged_x2, merged_y2 = seed_box

        for j in range(i + 1, len(by_conf)):
            other = by_conf[j]
            if is_merged[j] or other['class'] != seed_class:
                continue
            other_box = other['coords']
            if calculate_iou(seed_box, other_box) > iou_threshold:
                merged_x1 = min(merged_x1, other_box[0])
                merged_y1 = min(merged_y1, other_box[1])
                merged_x2 = max(merged_x2, other_box[2])
                merged_y2 = max(merged_y2, other_box[3])
                is_merged[j] = True

        merged_detections.append({
            'coords': (merged_x1, merged_y1, merged_x2, merged_y2),
            'y1': merged_y1,
            'class': seed_class,
            'conf': seed['conf'],
        })

    return merged_detections


def merge_yolo_into_word_data(raw_word_data: list, yolo_detections: list, scale_factor: float) -> list:
    """
    Replace words covered by YOLO boxes with one solid placeholder per box.

    Args:
        raw_word_data: ``(text, x1, y1, x2, y2)`` tuples in PDF points.
        yolo_detections: Dicts whose ``'coords'`` are in rendered-image pixels.
        scale_factor: Render scale; pixel coords are divided by it to get PDF points.

    Returns:
        Words whose center lies outside every YOLO box, followed by one
        ``("BLOCK_<i>", x1, y1, x2, y2)`` entry per YOLO box. The input list is
        returned unchanged when there are no detections.
    """
    if not yolo_detections:
        return raw_word_data

    # 1. Convert YOLO boxes (pixels) to PDF coordinates (points).
    pdf_space_boxes = []
    for det in yolo_detections:
        x1, y1, x2, y2 = det['coords']
        pdf_space_boxes.append((x1 / scale_factor, y1 / scale_factor,
                                x2 / scale_factor, y2 / scale_factor))

    def _center_in_any_box(word_tuple) -> bool:
        # A word belongs to a YOLO region when its center is inside it (inclusive bounds).
        center_x = (word_tuple[1] + word_tuple[3]) / 2
        center_y = (word_tuple[2] + word_tuple[4]) / 2
        return any(px1 <= center_x <= px2 and py1 <= center_y <= py2
                   for px1, py1, px2, py2 in pdf_space_boxes)

    # 2. Keep only words outside every YOLO box.
    cleaned_word_data = [w for w in raw_word_data if not _center_in_any_box(w)]

    # 3. Add the YOLO boxes themselves as "solid words" for the column detector.
    cleaned_word_data.extend((f"BLOCK_{i}", *box) for i, box in enumerate(pdf_space_boxes))
    return cleaned_word_data


# ============================================================================
# --- IMAGE PREPROCESSING HELPER ---
# ============================================================================
def preprocess_image_for_ocr(img_np):
    """
    Convert an image to grayscale (if it has 3 dims) and binarize it with Otsu's threshold.

    Args:
        img_np: Image array; 3-dimensional input is treated as BGR.

    Returns:
        Single-channel uint8 image containing only 0 and 255.
    """
    if len(img_np.shape) == 3:
        gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
    else:
        gray = img_np

    # Otsu picks the binarization threshold automatically (the 0 passed here is ignored).
    _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return thresh


def calculate_vertical_gap_coverage(word_data: list, sep_x: int, page_height: float, gutter_width: int = 10) -> float:
    """
    Return the fraction of the text's vertical span left free by a candidate separator.

    A word "blocks" the separator when it horizontally overlaps the band
    ``[sep_x - gutter_width/2, sep_x + gutter_width/2]``. The vertical span
    between the topmost and bottommost word is sampled at 1-unit resolution.

    Args:
        word_data: ``(text, x1, y1, x2, y2)`` tuples.
        sep_x: Candidate separator X coordinate.
        page_height: Unused; kept for interface compatibility.
        gutter_width: Width of the band around ``sep_x`` that must stay empty.

    Returns:
        Value in [0, 1]; 0.0 for empty input or zero text height. Callers
        compare it against a threshold (``calculate_x_gutters`` defaults to 0.80).
    """
    if not word_data:
        return 0.0

    y_coords = [w[2] for w in word_data] + [w[4] for w in word_data]
    min_y, max_y = min(y_coords), max(y_coords)
    total_text_height = max_y - min_y
    if total_text_height <= 0:
        return 0.0

    # One boolean per vertical unit; True means "no word blocks the gutter here".
    gap_open_mask = np.ones(int(total_text_height) + 1, dtype=bool)
    zone_left = sep_x - (gutter_width / 2)
    zone_right = sep_x + (gutter_width / 2)
    offset_y = int(min_y)

    for _, x1, y1, x2, y2 in word_data:
        if x2 > zone_left and x1 < zone_right:
            y_start_idx = max(0, int(y1) - offset_y)
            y_end_idx = min(len(gap_open_mask), int(y2) - offset_y)
            if y_end_idx > y_start_idx:
                gap_open_mask[y_start_idx:y_end_idx] = False

    return np.sum(gap_open_mask) / len(gap_open_mask)


def calculate_x_gutters(word_data: list, params: Dict, page_height: float) -> List[int]:
    """
    Find vertical column separators (gutters) from word positions.

    An X-axis histogram of word left/right edges is smoothed and inverted so
    that sparsely populated X ranges become peaks. Each peak is then validated:

    1. Bridging check: the summed height of words that straddle the candidate
       X must not exceed ``max_bridging_ratio`` of the text's vertical span.
    2. Coverage check: ``calculate_vertical_gap_coverage`` (with
       ``gutter_width = cluster_min_width``) must be at least ``min_gap_coverage``.

    Args:
        word_data: ``(text, x1, y1, x2, y2)`` tuples.
        params: Optional tuning keys (defaults in parentheses):
            ``cluster_bin_size`` (5), ``cluster_smoothing`` (1),
            ``cluster_min_width`` (20), ``cluster_threshold_percentile`` (85),
            ``max_bridging_ratio`` (0.08), ``min_gap_coverage`` (0.80).
        page_height: Forwarded to ``calculate_vertical_gap_coverage``.

    Returns:
        Sorted accepted separator X coordinates (ints); empty when none qualify.
    """
    if not word_data:
        return []

    x_points = [coord for item in word_data for coord in (item[1], item[3])]
    max_x = max(x_points)
    if max_x <= 0:
        # No positive X extent: np.histogram would be asked for 0 bins and raise.
        return []

    y_coords = [item[2] for item in word_data] + [item[4] for item in word_data]
    total_text_height = max(y_coords) - min(y_coords)
    if total_text_height <= 0:
        return []

    bin_size = params.get('cluster_bin_size', 5)
    smoothing = params.get('cluster_smoothing', 1)
    min_width = params.get('cluster_min_width', 20)
    threshold_percentile = params.get('cluster_threshold_percentile', 85)
    max_bridging_ratio = params.get('max_bridging_ratio', 0.08)
    min_gap_coverage = params.get('min_gap_coverage', 0.80)

    num_bins = int(np.ceil(max_x / bin_size))
    hist, bin_edges = np.histogram(x_points, bins=num_bins, range=(0, max_x))
    smoothed_hist = gaussian_filter1d(hist.astype(float), sigma=smoothing)
    inverted_signal = np.max(smoothed_hist) - smoothed_hist

    peaks, _ = find_peaks(
        inverted_signal,
        height=np.max(inverted_signal) - np.percentile(smoothed_hist, threshold_percentile),
        # find_peaks rejects distance < 1, which min_width < bin_size would otherwise produce.
        distance=max(1, min_width / bin_size),
    )

    final_separators = []
    for x_coord in (int(bin_edges[p]) for p in peaks):
        # --- CHECK 1: BRIDGING DENSITY ---
        # Total height of words that physically straddle this X coordinate.
        bridging_height = sum(item[4] - item[2] for item in word_data
                              if item[1] < x_coord < item[3])
        bridging_ratio = bridging_height / total_text_height

        # Small ratios tolerate page numbers / headers crossing the gutter, not paragraphs.
        if bridging_ratio > max_bridging_ratio:
            print(
                f" ❌ Separator X={x_coord} REJECTED: Bridging Ratio {bridging_ratio:.1%} "
                f"(>{max_bridging_ratio:.0%}) cuts through text.")
            continue

        # --- CHECK 2: VERTICAL GAP COVERAGE ---
        coverage = calculate_vertical_gap_coverage(word_data, x_coord, page_height, gutter_width=min_width)
        if coverage >= min_gap_coverage:
            final_separators.append(x_coord)
            print(f" -> Separator X={x_coord} ACCEPTED (Coverage: {coverage:.1%}, Bridging: {bridging_ratio:.1%})")
        else:
            print(f" ❌ Separator X={x_coord} REJECTED (Coverage: {coverage:.1%}, Bridging: {bridging_ratio:.1%})")

    return sorted(final_separators)

#======================================================================================================================================
# def get_word_data_for_detection(page: fitz.Page, pdf_path: str, page_num: int,
#                                 top_margin_percent=0.10, bottom_margin_percent=0.10) -> list:
#     """Extract word data with OCR caching to avoid redundant Tesseract runs."""
#     word_data = page.get_text("words")
#     if len(word_data) > 0:
#         word_data = [(w[4], w[0], w[1], w[2], w[3]) for w in word_data]
#     else:
#         if _ocr_cache.has_ocr(pdf_path, page_num):
#             word_data = _ocr_cache.get_ocr(pdf_path, page_num)
#         else:
#             try:
#                 # --- OPTIMIZATION START ---
#                 # 1. Render at Higher Resolution (Zoom 4.0 = ~300 DPI)
#                 zoom_level = 4.0
#                 pix = page.get_pixmap(matrix=fitz.Matrix(zoom_level, zoom_level))
#                 # 2. Convert directly to OpenCV format (Faster than PIL)
#                 img_np = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
#                 if pix.n == 3:
#                     img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
#                 elif pix.n == 4:
#                     img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2BGR)
#                 # 3. Apply Preprocessing (Thresholding)
#                 processed_img = preprocess_image_for_ocr(img_np)
#                 # 4.
Optimized Tesseract Config # # --psm 6: Assume a single uniform block of text (Great for columns/questions) # # --oem 3: Default engine (LSTM) # custom_config = r'--oem 3 --psm 6' # data = pytesseract.image_to_data(processed_img, output_type=pytesseract.Output.DICT, # config=custom_config) # full_word_data = [] # for i in range(len(data['level'])): # text = data['text'][i].strip() # if text: # # Scale coordinates back to PDF points # x1 = data['left'][i] / zoom_level # y1 = data['top'][i] / zoom_level # x2 = (data['left'][i] + data['width'][i]) / zoom_level # y2 = (data['top'][i] + data['height'][i]) / zoom_level # full_word_data.append((text, x1, y1, x2, y2)) # word_data = full_word_data # _ocr_cache.set_ocr(pdf_path, page_num, word_data) # # --- OPTIMIZATION END --- # except Exception as e: # print(f" ❌ OCR Error in detection phase: {e}") # return [] # # Apply margin filtering # page_height = page.rect.height # y_min = page_height * top_margin_percent # y_max = page_height * (1 - bottom_margin_percent) # return [d for d in word_data if d[2] >= y_min and d[4] <= y_max] #============================================================================================================ # def get_word_data_for_detection(page: fitz.Page, pdf_path: str, page_num: int, # top_margin_percent=0.10, bottom_margin_percent=0.10) -> list: # word_data = page.get_text("words") # if len(word_data) > 0: # word_data = [(w[4], w[0], w[1], w[2], w[3]) for w in word_data] # else: # if _ocr_cache.has_ocr(pdf_path, page_num): # word_data = _ocr_cache.get_ocr(pdf_path, page_num) # else: # try: # # 1. Render at Higher Resolution # zoom_level = 4.0 # pix = page.get_pixmap(matrix=fitz.Matrix(zoom_level, zoom_level)) # img_np = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n) # # Convert to BGR for RapidOCR # if pix.n == 3: img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) # elif pix.n == 4: img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2BGR) # # 2. 
Run RapidOCR # # RapidOCR returns: [[box, text, score], ...] # # where box is [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] # results, _ = ocr_engine(img_np) # full_word_data = [] # if results: # for box, text, score in results: # text = text.strip() # if text: # # 3. Convert Polygon to BBox and Scale back to PDF points # xs = [p[0] for p in box] # ys = [p[1] for p in box] # x1 = min(xs) / zoom_level # y1 = min(ys) / zoom_level # x2 = max(xs) / zoom_level # y2 = max(ys) / zoom_level # full_word_data.append((text, x1, y1, x2, y2)) # word_data = full_word_data # _ocr_cache.set_ocr(pdf_path, page_num, word_data) # except Exception as e: # print(f" ❌ RapidOCR Error in detection phase: {e}") # return [] # # Apply margin filtering # page_height = page.rect.height # y_min = page_height * top_margin_percent # y_max = page_height * (1 - bottom_margin_percent) # return [d for d in word_data if d[2] >= y_min and d[4] <= y_max] # def get_word_data_for_detection(page: fitz.Page, pdf_path: str, page_num: int, # top_margin_percent=0.10, bottom_margin_percent=0.10) -> list: # word_data = page.get_text("words") # if len(word_data) > 0: # # Reformat standard PyMuPDF output to (text, x1, y1, x2, y2) # word_data = [(w[4], w[0], w[1], w[2], w[3]) for w in word_data] # else: # if _ocr_cache.has_ocr(pdf_path, page_num): # word_data = _ocr_cache.get_ocr(pdf_path, page_num) # else: # try: # # 1. Render at Higher Resolution # zoom_level = 4.0 # pix = page.get_pixmap(matrix=fitz.Matrix(zoom_level, zoom_level)) # img_np = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n) # # Convert to BGR for RapidOCR # if pix.n == 3: # img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) # elif pix.n == 4: # img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2BGR) # # 2. 
Run RapidOCR # ocr_out = ocr_engine(img_np) # full_word_data = [] # # CRITICAL FIX: Use 'is not None' to avoid NumPy truthiness ambiguity # if ocr_out is not None and ocr_out.boxes is not None: # # Use zip to iterate through boxes, text, and scores simultaneously # for box, text, score in zip(ocr_out.boxes, ocr_out.txts, ocr_out.scores): # text = str(text).strip() # if text: # # 3. Convert Polygon to BBox and Scale back to PDF points # xs = [p[0] for p in box] # ys = [p[1] for p in box] # x1 = min(xs) / zoom_level # y1 = min(ys) / zoom_level # x2 = max(xs) / zoom_level # y2 = max(ys) / zoom_level # full_word_data.append((text, x1, y1, x2, y2)) # word_data = full_word_data # _ocr_cache.set_ocr(pdf_path, page_num, word_data) # except Exception as e: # print(f" ❌ RapidOCR Error in detection phase: {e}") # return [] # # Apply margin filtering # page_height = page.rect.height # y_min = page_height * top_margin_percent # y_max = page_height * (1 - bottom_margin_percent) # # Return filtered data where y-coordinates fall within the margins # return [d for d in word_data if d[2] >= y_min and d[4] <= y_max] def get_word_data_for_detection(page: fitz.Page, pdf_path: str, page_num: int, top_margin_percent=0.10, bottom_margin_percent=0.10) -> list: """ Retrieves word data using PyMuPDF native extraction or RapidOCR fallback. """ # 1. Attempt Native Extraction word_data = page.get_text("words") if len(word_data) > 5: word_data = [(w[4], w[0], w[1], w[2], w[3]) for w in word_data] else: # 2. Check Cache if _ocr_cache.has_ocr(pdf_path, page_num): cached_data = _ocr_cache.get_ocr(pdf_path, page_num) if cached_data and len(cached_data) > 0: return cached_data # 3. 
OCR Fallback (RapidOCR) try: zoom_level = 2.0 pix = page.get_pixmap(matrix=fitz.Matrix(zoom_level, zoom_level)) img_np = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n) if pix.n == 3: img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) elif pix.n == 4: img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2BGR) # CRITICAL FIX: Use return_word_box=True and access word_results ocr_result = ocr_engine(img_np, return_word_box=True) full_word_data = [] # Check if we got valid results if ocr_result and ocr_result.word_results: scale_adjustment = 1.0 / zoom_level # Flatten the per-line word results flat_results = sum(ocr_result.word_results, ()) for text, score, bbox in flat_results: text = str(text).strip() if text: # Convert Polygon to BBox xs = [p[0] for p in bbox] ys = [p[1] for p in bbox] x1 = min(xs) * scale_adjustment y1 = min(ys) * scale_adjustment x2 = max(xs) * scale_adjustment y2 = max(ys) * scale_adjustment full_word_data.append((text, x1, y1, x2, y2)) word_data = full_word_data if len(word_data) > 0: _ocr_cache.set_ocr(pdf_path, page_num, word_data) except Exception as e: print(f" ❌ RapidOCR Error in detection phase: {e}") import traceback traceback.print_exc() return [] # 4. 
Apply Margin Filtering page_height = page.rect.height y_min = page_height * top_margin_percent y_max = page_height * (1 - bottom_margin_percent) return [d for d in word_data if d[2] >= y_min and d[4] <= y_max] #========================================================================================================================================= #============================================================================================================================================= def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray: img_data = pix.samples img = np.frombuffer(img_data, dtype=np.uint8).reshape(pix.height, pix.width, pix.n) if pix.n == 4: img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR) elif pix.n == 3: img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) return img def extract_native_words_and_convert(fitz_page, scale_factor: float = 2.0) -> list: # 1. Get raw data try: raw_word_data = fitz_page.get_text("words") except Exception as e: print(f" ❌ PyMuPDF extraction failed completely: {e}") return [] # ============================================================================== # --- DEBUGGING BLOCK: CHECK FIRST 50 NATIVE WORDS (SAFE PRINT) --- # ============================================================================== print(f"\n[DEBUG] Native Extraction (Page {fitz_page.number + 1}): Checking first 50 words...") debug_count = 0 for item in raw_word_data: if debug_count >= 50: break word_text = item[4] # --- SAFE PRINTING LOGIC --- # We encode/decode to ignore surrogates just for the print statement # This prevents the "UnicodeEncodeError" that was crashing your script safe_text = word_text.encode('utf-8', 'ignore').decode('utf-8') # Get hex codes (handling potential errors in 'ord') try: unicode_points = [f"\\u{ord(c):04x}" for c in word_text] except: unicode_points = ["ERROR"] print(f" Word {debug_count}: '{safe_text}' -> Codes: {unicode_points}") debug_count += 1 print("----------------------------------------------------------------------\n") # 
============================================================================== converted_ocr_output = [] DEFAULT_CONFIDENCE = 99.0 for x1, y1, x2, y2, word, *rest in raw_word_data: # --- FIX: ROBUST SANITIZATION --- # 1. Encode to UTF-8 ignoring errors (strips surrogates) # 2. Decode back to string cleaned_word_bytes = word.encode('utf-8', 'ignore') cleaned_word = cleaned_word_bytes.decode('utf-8') cleaned_word = word.encode('utf-8', 'ignore').decode('utf-8').strip() # cleaned_word = cleaned_word.strip() if not cleaned_word: continue x1_pix = int(x1 * scale_factor) y1_pix = int(y1 * scale_factor) x2_pix = int(x2 * scale_factor) y2_pix = int(y2 * scale_factor) converted_ocr_output.append({ 'type': 'text', 'word': cleaned_word, 'confidence': DEFAULT_CONFIDENCE, 'bbox': [x1_pix, y1_pix, x2_pix, y2_pix], 'y0': y1_pix, 'x0': x1_pix }) return converted_ocr_output #=================================================================================================== #=================================================================================================== #=================================================================================================== import pandas as pd import pickle import os import time import json from sklearn.feature_extraction.text import TfidfVectorizer import numpy as np from collections import defaultdict # --- Model File Paths (Required for the Classifier to load) --- VECTORIZER_FILE = 'tfidf_vectorizer_conditional.pkl' SUBJECT_MODEL_FILE = 'subject_classifier_model_conditional.pkl' CONDITIONAL_CONCEPT_MODELS_FILE = 'conditional_concept_models.pkl' # --- Hierarchical Classifier Class (Dependency for the helper function) --- class HierarchicalClassifier: """ A two-stage classification system based on conditional training. Loads the vectorizer, subject classifier, and conditional concept models. 
""" def __init__(self): self.vectorizer = None self.subject_model = None self.conditional_concept_models = {} self.is_ready = False def load_models(self): """Loads the vectorizer, subject model, and conditional concept models.""" try: start_time = time.time() # 1. Load the TF-IDF Vectorizer with open(VECTORIZER_FILE, 'rb') as f: self.vectorizer = pickle.load(f) # 2. Load the Level 1 (Subject) Classifier with open(SUBJECT_MODEL_FILE, 'rb') as f: self.subject_model = pickle.load(f) # 3. Load the dictionary of conditional Level 2 (Concept) Models with open(CONDITIONAL_CONCEPT_MODELS_FILE, 'rb') as f: conditional_data = pickle.load(f) # Extract just the models for easy access for subject, data in conditional_data.items(): self.conditional_concept_models[subject] = data['model'] print(f"Loaded models successfully in {time.time() - start_time:.2f} seconds.") self.is_ready = True except FileNotFoundError as e: print(f"Error: Required model file not found: {e.filename}.") self.is_ready = False except Exception as e: print(f"An error occurred while loading models: {e}") self.is_ready = False return self.is_ready def predict_subject(self, text_chunk): """Predicts the top Subject (Level 1).""" if not self.is_ready: return "Unknown", 0.0 # Vectorize the input text_vector = self.vectorizer.transform([text_chunk]).astype(np.float64) if hasattr(self.subject_model, 'predict_proba'): probabilities = self.subject_model.predict_proba(text_vector)[0] classes = self.subject_model.classes_ top_index = np.argmax(probabilities) return classes[top_index], probabilities[top_index] else: return self.subject_model.predict(text_vector)[0], 1.0 def predict_concept_hierarchical(self, text_chunk, predicted_subject): """ Predicts the top Concept (Level 2) using the specialized conditional model. 
""" if not self.is_ready: return "Unknown", 0.0 concept_model = self.conditional_concept_models.get(predicted_subject) if concept_model is None or len(getattr(concept_model, 'classes_', [])) <= 1: return "No_Conditional_Model_Found", 0.0 # Vectorize the input text_vector = self.vectorizer.transform([text_chunk]).astype(np.float64) if hasattr(concept_model, 'predict_proba'): probabilities = concept_model.predict_proba(text_vector)[0] classes = concept_model.classes_ top_index = np.argmax(probabilities) return classes[top_index], probabilities[top_index] else: return concept_model.predict(text_vector)[0], 1.0 # -------------------------------------------------------------------------------------- # --- The Requested Helper Function --- def post_process_json_with_inference(json_data, classifier): """ Takes JSON data, runs hierarchical inference on all question/option text, and adds 'predicted_subject' and 'predicted_concept' tags to each entry. Args: json_data (list): The list of dictionaries containing question entries. classifier (HierarchicalClassifier): An initialized and loaded classifier object. Returns: list: The modified list of dictionaries with classification tags added. """ if not classifier.is_ready: print("Classifier not ready. Skipping inference.") return json_data # This print statement can be removed for silent pipeline integration print("\n--- Starting Subject/Concept Detection ---") for entry in json_data: # Only process entries that have a 'question' field if 'question' not in entry: continue # 1. 
Combine Question Text and Option Text for robust feature extraction full_text = entry.get('question', '') options = entry.get('options', {}) for option_key, option_value in options.items(): # Use the text component of the option if available option_text = option_value if isinstance(option_value, str) else option_key full_text += " " + option_text.replace('\n', ' ') # Clean up text (remove multiple spaces and surrounding whitespace) full_text = ' '.join(full_text.split()).strip() # Handle empty text if not full_text: entry['predicted_subject'] = {'label': 'Empty_Text', 'confidence': 0.0} entry['predicted_concept'] = {'label': 'Empty_Text', 'confidence': 0.0} continue # 2. STAGE 1: Predict Subject subj_label, subj_conf = classifier.predict_subject(full_text) # 3. STAGE 2: Predict Concept (Conditional on predicted subject) conc_label, conc_conf = classifier.predict_concept_hierarchical(full_text, subj_label) # 4. Add results to the JSON entry entry['predicted_subject'] = { 'label': subj_label, 'confidence': round(subj_conf, 4) } entry['predicted_concept'] = { 'label': conc_label, 'confidence': round(conc_conf, 4) } # This print statement can be removed for silent pipeline integration # print("--- JSON Post-Processing Complete ---") return json_data #=================================================================================================== #=================================================================================================== #=================================================================================================== def preprocess_and_ocr_page(original_img: np.ndarray, model, pdf_path: str, page_num: int, fitz_page: fitz.Page, pdf_name: str) -> Tuple[List[Dict[str, Any]], Optional[int]]: """ OPTIMIZED FLOW: 1. Run YOLO to find Equations/Tables. 2. Mask raw text with YOLO boxes. 3. Run Column Detection on the MASKED data. 4. Proceed with OCR (Native or High-Res Tesseract Fallback) and Output. 
""" global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT start_time_total = time.time() if original_img is None: print(f" ❌ Invalid image for page {page_num}.") return None, None # ==================================================================== # --- STEP 1: YOLO DETECTION --- # ==================================================================== start_time_yolo = time.time() results = model.predict(source=original_img, conf=0.2, imgsz=640, verbose=False) relevant_detections = [] THRESHOLDS = { 'figure': 0.75, 'equation': 0.20 } if results and results[0].boxes: for box in results[0].boxes: class_id = int(box.cls[0]) class_name = model.names[class_id] conf = float(box.conf[0]) # Logic: Check if class is in our list AND meets its specific threshold if class_name in THRESHOLDS: if conf >= THRESHOLDS[class_name]: x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int) relevant_detections.append({ 'coords': (x1, y1, x2, y2), 'y1': y1, 'class': class_name, 'conf': conf }) merged_detections = merge_overlapping_boxes(relevant_detections, IOU_MERGE_THRESHOLD) print(f" [LOG] YOLO found {len(merged_detections)} objects in {time.time() - start_time_yolo:.3f}s.") # if results and results[0].boxes: # for box in results[0].boxes: # class_id = int(box.cls[0]) # class_name = model.names[class_id] # if class_name in TARGET_CLASSES: # x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int) # relevant_detections.append( # {'coords': (x1, y1, x2, y2), 'y1': y1, 'class': class_name, 'conf': float(box.conf[0])} # ) # merged_detections = merge_overlapping_boxes(relevant_detections, IOU_MERGE_THRESHOLD) # print(f" [LOG] YOLO found {len(merged_detections)} objects in {time.time() - start_time_yolo:.3f}s.") # ==================================================================== # --- STEP 2: PREPARE DATA FOR COLUMN DETECTION (MASKING) --- # ==================================================================== # Note: This uses the updated 'get_word_data_for_detection' which has its own 
optimizations raw_words_for_layout = get_word_data_for_detection( fitz_page, pdf_path, page_num, top_margin_percent=0.10, bottom_margin_percent=0.10 ) masked_word_data = merge_yolo_into_word_data(raw_words_for_layout, merged_detections, scale_factor=2.0) # ==================================================================== # --- STEP 3: COLUMN DETECTION --- # ==================================================================== page_width_pdf = fitz_page.rect.width page_height_pdf = fitz_page.rect.height column_detection_params = { 'cluster_bin_size': 2, 'cluster_smoothing': 2, 'cluster_min_width': 10, 'cluster_threshold_percentile': 85, } separators = calculate_x_gutters(masked_word_data, column_detection_params, page_height_pdf) page_separator_x = None if separators: central_min = page_width_pdf * 0.35 central_max = page_width_pdf * 0.65 central_separators = [s for s in separators if central_min <= s <= central_max] if central_separators: center_x = page_width_pdf / 2 page_separator_x = min(central_separators, key=lambda x: abs(x - center_x)) print(f" ✅ Column Split Confirmed at X={page_separator_x:.1f}") else: print(" ⚠️ Gutter found off-center. 
Ignoring.") else: print(" -> Single Column Layout Confirmed.") # ==================================================================== # --- STEP 4: COMPONENT EXTRACTION (Save Images) --- # ==================================================================== start_time_components = time.time() component_metadata = [] fig_count_page = 0 eq_count_page = 0 for detection in merged_detections: x1, y1, x2, y2 = detection['coords'] class_name = detection['class'] if class_name == 'figure': GLOBAL_FIGURE_COUNT += 1 counter = GLOBAL_FIGURE_COUNT component_word = f"FIGURE{counter}" fig_count_page += 1 elif class_name == 'equation': GLOBAL_EQUATION_COUNT += 1 counter = GLOBAL_EQUATION_COUNT component_word = f"EQUATION{counter}" eq_count_page += 1 else: continue component_crop = original_img[y1:y2, x1:x2] component_filename = f"{pdf_name}_page{page_num}_{class_name}{counter}.png" cv2.imwrite(os.path.join(FIGURE_EXTRACTION_DIR, component_filename), component_crop) y_midpoint = (y1 + y2) // 2 component_metadata.append({ 'type': class_name, 'word': component_word, 'bbox': [int(x1), int(y1), int(x2), int(y2)], 'y0': int(y_midpoint), 'x0': int(x1) }) # ==================================================================== # --- STEP 5: HYBRID OCR (Native Text + Cached Tesseract Fallback) --- # ==================================================================== raw_ocr_output = [] scale_factor = 2.0 # Pipeline standard scale try: # Try getting native text first # NOTE: extract_native_words_and_convert MUST ALSO BE UPDATED TO USE sanitize_text raw_ocr_output = extract_native_words_and_convert(fitz_page, scale_factor=scale_factor) except Exception as e: print(f" ❌ Native text extraction failed: {e}") # If native text is missing, fall back to OCR if not raw_ocr_output: if _ocr_cache.has_ocr(pdf_path, page_num): print(f" ⚡ Using cached Tesseract OCR for page {page_num}") cached_word_data = _ocr_cache.get_ocr(pdf_path, page_num) for word_tuple in cached_word_data: word_text, x1, y1, x2, y2 
= word_tuple # Scale from PDF points to Pipeline Pixels (2.0) x1_pix = int(x1 * scale_factor) y1_pix = int(y1 * scale_factor) x2_pix = int(x2 * scale_factor) y2_pix = int(y2 * scale_factor) raw_ocr_output.append({ 'type': 'text', 'word': word_text, 'confidence': 95.0, 'bbox': [x1_pix, y1_pix, x2_pix, y2_pix], 'y0': y1_pix, 'x0': x1_pix }) else: # === START OF OPTIMIZED OCR BLOCK === # try: # # 1. Re-render Page at High Resolution (Zoom 4.0 = ~300 DPI) # ocr_zoom = 4.0 # pix_ocr = fitz_page.get_pixmap(matrix=fitz.Matrix(ocr_zoom, ocr_zoom)) # # Convert PyMuPDF Pixmap to OpenCV format # img_ocr_np = np.frombuffer(pix_ocr.samples, dtype=np.uint8).reshape(pix_ocr.height, pix_ocr.width, # pix_ocr.n) # if pix_ocr.n == 3: # img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGB2BGR) # elif pix_ocr.n == 4: # img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGBA2BGR) # # 2. Preprocess (Binarization) # processed_img = preprocess_image_for_ocr(img_ocr_np) # # 3. Run Tesseract with Optimized Configuration # custom_config = r'--oem 3 --psm 6' # hocr_data = pytesseract.image_to_data( # processed_img, # output_type=pytesseract.Output.DICT, # config=custom_config # ) # for i in range(len(hocr_data['level'])): # text = hocr_data['text'][i] # Retrieve raw Tesseract text # # --- FIX: SANITIZE TEXT AND THEN STRIP --- # cleaned_text = sanitize_text(text).strip() # if cleaned_text and hocr_data['conf'][i] > -1: # # 4. 
Coordinate Mapping # scale_adjustment = scale_factor / ocr_zoom # x1 = int(hocr_data['left'][i] * scale_adjustment) # y1 = int(hocr_data['top'][i] * scale_adjustment) # w = int(hocr_data['width'][i] * scale_adjustment) # h = int(hocr_data['height'][i] * scale_adjustment) # x2 = x1 + w # y2 = y1 + h # raw_ocr_output.append({ # 'type': 'text', # 'word': cleaned_text, # Use the sanitized word # 'confidence': float(hocr_data['conf'][i]), # 'bbox': [x1, y1, x2, y2], # 'y0': y1, # 'x0': x1 # }) # except Exception as e: # print(f" ❌ Tesseract OCR Error: {e}") #============================================================================================================================================================= #============================================================================================================================================================= # else: # # === START OF RAPIDOCR BLOCK === # try: # # 1. Re-render Page at High Resolution (Standardizing to Zoom 4.0) # ocr_zoom = 4.0 # pix_ocr = fitz_page.get_pixmap(matrix=fitz.Matrix(ocr_zoom, ocr_zoom)) # # Convert PyMuPDF Pixmap to OpenCV format (BGR) # img_ocr_np = np.frombuffer(pix_ocr.samples, dtype=np.uint8).reshape( # pix_ocr.height, pix_ocr.width, pix_ocr.n # ) # if pix_ocr.n == 3: # img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGB2BGR) # elif pix_ocr.n == 4: # img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGBA2BGR) # # 2. Run RapidOCR (Models handle preprocessing internally) # results, _ = ocr_engine(img_ocr_np) # if results: # # Calculate scaling from OCR image (4.0) to your pipeline standard (scale_factor=2.0) # scale_adjustment = scale_factor / ocr_zoom # for box, text, score in results: # # Sanitize and clean text # cleaned_text = sanitize_text(text).strip() # if cleaned_text: # # 3. 
Coordinate Mapping (Convert 4-point polygon to x1, y1, x2, y2) # xs = [p[0] for p in box] # ys = [p[1] for p in box] # x1 = int(min(xs) * scale_adjustment) # y1 = int(min(ys) * scale_adjustment) # x2 = int(max(xs) * scale_adjustment) # y2 = int(max(ys) * scale_adjustment) # raw_ocr_output.append({ # 'type': 'text', # 'word': cleaned_text, # 'confidence': float(score) * 100, # Converting 0-1.0 to 0-100 scale # 'bbox': [x1, y1, x2, y2], # 'y0': y1, # 'x0': x1 # }) # except Exception as e: # print(f" ❌ RapidOCR Fallback Error: {e}") try: # 1. Re-render Page at High Resolution (Standardizing to Zoom 4.0) ocr_zoom = 4.0 pix_ocr = fitz_page.get_pixmap(matrix=fitz.Matrix(ocr_zoom, ocr_zoom)) # Convert PyMuPDF Pixmap to OpenCV format (BGR) img_ocr_np = np.frombuffer(pix_ocr.samples, dtype=np.uint8).reshape( pix_ocr.height, pix_ocr.width, pix_ocr.n ) if pix_ocr.n == 3: img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGB2BGR) elif pix_ocr.n == 4: img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGBA2BGR) # 2. Run RapidOCR # FIX 1: Capture the object (Removes the "cannot unpack" error) ocr_out = ocr_engine(img_ocr_np) # FIX 2: Use 'is not None' (Removes the "ambiguous truth value" error) if ocr_out is not None and ocr_out.boxes is not None: # Calculate scaling from OCR image (4.0) to your pipeline standard (scale_factor=2.0) scale_adjustment = scale_factor / ocr_zoom # FIX 3: Zip the attributes to restore your expected (box, text, score) format for box, text, score in zip(ocr_out.boxes, ocr_out.txts, ocr_out.scores): # Sanitize and clean text cleaned_text = sanitize_text(str(text)).strip() if cleaned_text: # 3. 
Coordinate Mapping (Convert 4-point polygon to x1, y1, x2, y2) xs = [p[0] for p in box] ys = [p[1] for p in box] x1 = int(min(xs) * scale_adjustment) y1 = int(min(ys) * scale_adjustment) x2 = int(max(xs) * scale_adjustment) y2 = int(max(ys) * scale_adjustment) raw_ocr_output.append({ 'type': 'text', 'word': cleaned_text, 'confidence': float(score) * 100, # Converting 0-1.0 to 0-100 scale 'bbox': [x1, y1, x2, y2], 'y0': y1, 'x0': x1 }) except Exception as e: print(f" ❌ RapidOCR Fallback Error: {e}") # === END OF RAPIDOCR BLOCK ========================== # === END OF RAPIDOCR BLOCK ==================================================================================================================================== #=========================================================================================================================================================================== # === END OF OPTIMIZED OCR BLOCK === # ==================================================================== # --- STEP 6: OCR CLEANING AND MERGING --- # ==================================================================== items_to_sort = [] for ocr_word in raw_ocr_output: is_suppressed = False for component in component_metadata: # Do not include words that are inside figure/equation boxes ioa = calculate_ioa(ocr_word['bbox'], component['bbox']) if ioa > IOA_SUPPRESSION_THRESHOLD: is_suppressed = True break if not is_suppressed: items_to_sort.append(ocr_word) # Add figures/equations back into the flow as "words" items_to_sort.extend(component_metadata) # ==================================================================== # --- STEP 7: LINE-BASED SORTING --- # ==================================================================== items_to_sort.sort(key=lambda x: (x['y0'], x['x0'])) lines = [] for item in items_to_sort: placed = False for line in lines: y_ref = min(it['y0'] for it in line) if abs(y_ref - item['y0']) < LINE_TOLERANCE: line.append(item) placed = True break if not placed and 
item['type'] in ['equation', 'figure']: for line in lines: y_ref = min(it['y0'] for it in line) if abs(y_ref - item['y0']) < 20: line.append(item) placed = True break if not placed: lines.append([item]) for line in lines: line.sort(key=lambda x: x['x0']) final_output = [] for line in lines: for item in line: data_item = {"word": item["word"], "bbox": item["bbox"], "type": item["type"]} if 'tag' in item: data_item['tag'] = item['tag'] final_output.append(data_item) return final_output, page_separator_x def run_single_pdf_preprocessing(pdf_path: str, preprocessed_json_path: str) -> Optional[str]: global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT GLOBAL_FIGURE_COUNT = 0 GLOBAL_EQUATION_COUNT = 0 _ocr_cache.clear() print("\n" + "=" * 80) print("--- 1. STARTING OPTIMIZED YOLO/OCR PREPROCESSING PIPELINE ---") print("=" * 80) if not os.path.exists(pdf_path): print(f"❌ FATAL ERROR: Input PDF not found at {pdf_path}.") return None os.makedirs(os.path.dirname(preprocessed_json_path), exist_ok=True) os.makedirs(FIGURE_EXTRACTION_DIR, exist_ok=True) model = YOLO(WEIGHTS_PATH) pdf_name = os.path.splitext(os.path.basename(pdf_path))[0] try: doc = fitz.open(pdf_path) print(f"✅ Opened PDF: {pdf_name} ({doc.page_count} pages)") except Exception as e: print(f"❌ ERROR loading PDF file: {e}") return None all_pages_data = [] total_pages_processed = 0 mat = fitz.Matrix(2.0, 2.0) print("\n[STEP 1.2: ITERATING PAGES - IN-MEMORY PROCESSING]") for page_num_0_based in range(doc.page_count): page_num = page_num_0_based + 1 print(f" -> Processing Page {page_num}/{doc.page_count}...") fitz_page = doc.load_page(page_num_0_based) try: pix = fitz_page.get_pixmap(matrix=mat) original_img = pixmap_to_numpy(pix) except Exception as e: print(f" ❌ Error converting page {page_num} to image: {e}") continue final_output, page_separator_x = preprocess_and_ocr_page( original_img, model, pdf_path, page_num, fitz_page, pdf_name ) if final_output is not None: page_data = { "page_number": page_num, "data": 
final_output, "column_separator_x": page_separator_x } all_pages_data.append(page_data) total_pages_processed += 1 else: print(f" ❌ Skipped page {page_num} due to processing error.") doc.close() if all_pages_data: try: with open(preprocessed_json_path, 'w') as f: json.dump(all_pages_data, f, indent=4) print(f"\n ✅ Combined structured OCR JSON saved to: {os.path.basename(preprocessed_json_path)}") except Exception as e: print(f"❌ ERROR saving combined JSON output: {e}") return None else: print("❌ WARNING: No page data generated. Halting pipeline.") return None print("\n" + "=" * 80) print(f"--- YOLO/OCR PREPROCESSING COMPLETE ({total_pages_processed} pages processed) ---") print("=" * 80) return preprocessed_json_path # ============================================================================ # --- PHASE 2: LAYOUTLMV3 INFERENCE FUNCTIONS --- # ============================================================================ class LayoutLMv3ForTokenClassification(nn.Module): def __init__(self, num_labels: int = NUM_LABELS): super().__init__() self.num_labels = num_labels config = LayoutLMv3Config.from_pretrained("microsoft/layoutlmv3-base", num_labels=num_labels) self.layoutlmv3 = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base", config=config) self.classifier = nn.Linear(config.hidden_size, num_labels) self.crf = CRF(num_labels) self.init_weights() def init_weights(self): nn.init.xavier_uniform_(self.classifier.weight) if self.classifier.bias is not None: nn.init.zeros_(self.classifier.bias) def forward(self, input_ids: torch.Tensor, bbox: torch.Tensor, attention_mask: torch.Tensor, labels: Optional[torch.Tensor] = None): outputs = self.layoutlmv3(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, return_dict=True) sequence_output = outputs.last_hidden_state emissions = self.classifier(sequence_output) mask = attention_mask.bool() if labels is not None: loss = -self.crf(emissions, labels, mask=mask).mean() return loss else: return 
self.crf.viterbi_decode(emissions, mask=mask) def _merge_integrity(all_token_data: List[Dict[str, Any]], column_separator_x: Optional[int]) -> List[List[Dict[str, Any]]]: """Splits the token data objects into column chunks based on a separator.""" if column_separator_x is None: print(" -> No column separator. Treating as one chunk.") return [all_token_data] left_column_tokens, right_column_tokens = [], [] for token_data in all_token_data: bbox_raw = token_data['bbox_raw_pdf_space'] center_x = (bbox_raw[0] + bbox_raw[2]) / 2 if center_x < column_separator_x: left_column_tokens.append(token_data) else: right_column_tokens.append(token_data) chunks = [c for c in [left_column_tokens, right_column_tokens] if c] print(f" -> Data split into {len(chunks)} column chunk(s) using separator X={column_separator_x}.") return chunks def run_inference_and_get_raw_words(pdf_path: str, model_path: str, preprocessed_json_path: str, column_detection_params: Optional[Dict] = None) -> List[Dict[str, Any]]: print("\n" + "=" * 80) print("--- 2. 
STARTING LAYOUTLMV3 INFERENCE PIPELINE (Raw Word Output) ---") print("=" * 80) tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f" -> Using device: {device}") try: model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS) checkpoint = torch.load(model_path, map_location=device) model_state = checkpoint.get('model_state_dict', checkpoint) # Apply patch for layoutlmv3 compatibility with saved state_dict fixed_state_dict = {key.replace('layoutlm.', 'layoutlmv3.'): value for key, value in model_state.items()} model.load_state_dict(fixed_state_dict) model.to(device) model.eval() print(f"✅ LayoutLMv3 Model loaded successfully from {os.path.basename(model_path)}.") except Exception as e: print(f"❌ FATAL ERROR during LayoutLMv3 model loading: {e}") return [] try: with open(preprocessed_json_path, 'r', encoding='utf-8') as f: preprocessed_data = json.load(f) print(f"✅ Loaded preprocessed data with {len(preprocessed_data)} pages.") except Exception: print("❌ Error loading preprocessed JSON.") return [] try: doc = fitz.open(pdf_path) except Exception: print("❌ Error loading PDF.") return [] final_page_predictions = [] CHUNK_SIZE = 500 for page_data in preprocessed_data: page_num_1_based = page_data['page_number'] page_num_0_based = page_num_1_based - 1 page_raw_predictions = [] print(f"\n *** Processing Page {page_num_1_based} ({len(page_data['data'])} raw tokens) ***") fitz_page = doc.load_page(page_num_0_based) page_width, page_height = fitz_page.rect.width, fitz_page.rect.height print(f" -> Page dimensions: {page_width:.0f}x{page_height:.0f} (PDF points).") all_token_data = [] scale_factor = 2.0 for item in page_data['data']: raw_yolo_bbox = item['bbox'] bbox_pdf = [ int(raw_yolo_bbox[0] / scale_factor), int(raw_yolo_bbox[1] / scale_factor), int(raw_yolo_bbox[2] / scale_factor), int(raw_yolo_bbox[3] / scale_factor) ] normalized_bbox = [ max(0, min(1000, int(1000 * 
bbox_pdf[0] / page_width))), max(0, min(1000, int(1000 * bbox_pdf[1] / page_height))), max(0, min(1000, int(1000 * bbox_pdf[2] / page_width))), max(0, min(1000, int(1000 * bbox_pdf[3] / page_height))) ] all_token_data.append({ "word": item['word'], "bbox_raw_pdf_space": bbox_pdf, "bbox_normalized": normalized_bbox, "item_original_data": item }) if not all_token_data: continue column_separator_x = page_data.get('column_separator_x', None) if column_separator_x is not None: print(f" -> Using SAVED column separator: X={column_separator_x}") else: print(" -> No column separator found. Assuming single chunk.") token_chunks = _merge_integrity(all_token_data, column_separator_x) total_chunks = len(token_chunks) for chunk_idx, chunk_tokens in enumerate(token_chunks): if not chunk_tokens: continue # 1. Sanitize: Convert everything to strings and aggressively clean Unicode errors. chunk_words = [ str(t['word']).encode('utf-8', errors='ignore').decode('utf-8') for t in chunk_tokens ] chunk_normalized_bboxes = [t['bbox_normalized'] for t in chunk_tokens] total_sub_chunks = (len(chunk_words) + CHUNK_SIZE - 1) // CHUNK_SIZE for i in range(0, len(chunk_words), CHUNK_SIZE): sub_chunk_idx = i // CHUNK_SIZE + 1 sub_words = chunk_words[i:i + CHUNK_SIZE] sub_bboxes = chunk_normalized_bboxes[i:i + CHUNK_SIZE] sub_tokens_data = chunk_tokens[i:i + CHUNK_SIZE] print(f" -> Chunk {chunk_idx + 1}/{total_chunks}, Sub-chunk {sub_chunk_idx}/{total_sub_chunks}: {len(sub_words)} words. Running Inference...") # 2. 
Manual generation of word_ids manual_word_ids = [] for current_word_idx, word in enumerate(sub_words): sub_tokens = tokenizer.tokenize(word) for _ in sub_tokens: manual_word_ids.append(current_word_idx) encoded_input = tokenizer( sub_words, boxes=sub_bboxes, truncation=True, padding="max_length", max_length=512, is_split_into_words=True, return_tensors="pt" ) # Check for empty sequence if encoded_input['input_ids'].shape[0] == 0: print(f" -> Warning: Sub-chunk {sub_chunk_idx} encoded to an empty sequence. Skipping.") continue # 3. Finalize word_ids based on encoded output length sequence_length = int(torch.sum(encoded_input['attention_mask']).item()) content_token_length = max(0, sequence_length - 2) manual_word_ids = manual_word_ids[:content_token_length] final_word_ids = [None] # CLS token (index 0) final_word_ids.extend(manual_word_ids) if sequence_length > 1: final_word_ids.append(None) # SEP token final_word_ids.extend([None] * (512 - len(final_word_ids))) word_ids = final_word_ids[:512] # Final array for mapping # Inputs are already batched by the tokenizer as [1, 512] input_ids = encoded_input['input_ids'].to(device) bbox = encoded_input['bbox'].to(device) attention_mask = encoded_input['attention_mask'].to(device) with torch.no_grad(): model_outputs = model(input_ids, bbox, attention_mask) # --- Robust extraction: support several forward return types --- # We'll try (in order): # 1) model_outputs is (emissions_tensor, viterbi_list) -> use emissions for logits, keep decoded # 2) model_outputs has .logits attribute (HF ModelOutput) # 3) model_outputs is tuple/list containing a logits tensor # 4) model_outputs is a tensor (assume logits) # 5) model_outputs is a list-of-lists of ints (viterbi decoded) -> use that directly (no logits) logits_tensor = None decoded_labels_list = None # case 1: tuple/list with (emissions, viterbi) if isinstance(model_outputs, (tuple, list)) and len(model_outputs) == 2: a, b = model_outputs # a might be tensor (emissions), b might 
be viterbi list if isinstance(a, torch.Tensor): logits_tensor = a if isinstance(b, list): decoded_labels_list = b # case 2: HF ModelOutput with .logits if logits_tensor is None and hasattr(model_outputs, 'logits') and isinstance(model_outputs.logits, torch.Tensor): logits_tensor = model_outputs.logits # case 3: tuple/list - search for a 3D tensor (B, L, C) if logits_tensor is None and isinstance(model_outputs, (tuple, list)): found_tensor = None for item in model_outputs: if isinstance(item, torch.Tensor): # prefer 3D (batch, seq, labels) if item.dim() == 3: logits_tensor = item break if found_tensor is None: found_tensor = item if logits_tensor is None and found_tensor is not None: # found_tensor may be (batch, seq, hidden) or (seq, hidden); we avoid guessing. # Keep found_tensor only if it matches num_labels dimension if found_tensor.dim() == 3 and found_tensor.shape[-1] == NUM_LABELS: logits_tensor = found_tensor elif found_tensor.dim() == 2 and found_tensor.shape[-1] == NUM_LABELS: logits_tensor = found_tensor.unsqueeze(0) # case 4: model_outputs directly a tensor if logits_tensor is None and isinstance(model_outputs, torch.Tensor): logits_tensor = model_outputs # case 5: model_outputs is a decoded viterbi list (common for CRF-only forward) if decoded_labels_list is None and isinstance(model_outputs, list) and model_outputs and isinstance(model_outputs[0], list): # assume model_outputs is already viterbi decoded: List[List[int]] with batch dim first decoded_labels_list = model_outputs # If neither logits nor decoded exist, that's fatal if logits_tensor is None and decoded_labels_list is None: # helpful debug info try: elem_shapes = [ (type(x), getattr(x, 'shape', None)) for x in model_outputs ] if isinstance(model_outputs, (list, tuple)) else [(type(model_outputs), getattr(model_outputs, 'shape', None))] except Exception: elem_shapes = str(type(model_outputs)) raise RuntimeError(f"Model output of type {type(model_outputs)} did not contain a valid logits tensor 
or decoded viterbi. Contents: {elem_shapes}") # If we have logits_tensor, normalize shape to [seq_len, num_labels] if logits_tensor is not None: # If shape is [B, L, C] with B==1, squeeze batch if logits_tensor.dim() == 3 and logits_tensor.shape[0] == 1: preds_tensor = logits_tensor.squeeze(0) # [L, C] else: preds_tensor = logits_tensor # possibly [L, C] already # Safety: ensure we have at least seq_len x channels if preds_tensor.dim() != 2: # try to reshape or error raise RuntimeError(f"Unexpected logits tensor shape: {tuple(preds_tensor.shape)}") # We'll use preds_tensor[token_idx] to argmax else: preds_tensor = None # no logits available # If decoded labels provided, make a token-level list-of-ints aligned to tokenizer tokens decoded_token_labels = None if decoded_labels_list is not None: # decoded_labels_list is batch-first; we used batch size 1 # if multiple sequences returned, take first decoded_token_labels = decoded_labels_list[0] if isinstance(decoded_labels_list[0], list) else decoded_labels_list # Now map token-level predictions -> word-level predictions using word_ids word_idx_to_pred_id = {} if preds_tensor is not None: # We have logits. Use argmax of logits for each token id up to sequence_length for token_idx, word_idx in enumerate(word_ids): if token_idx >= sequence_length: break if word_idx is not None and word_idx < len(sub_words): if word_idx not in word_idx_to_pred_id: pred_id = torch.argmax(preds_tensor[token_idx]).item() word_idx_to_pred_id[word_idx] = pred_id else: # No logits, but we have decoded_token_labels from CRF (one label per token) # We'll align decoded_token_labels to token positions. 
if decoded_token_labels is None: # should not happen due to earlier checks raise RuntimeError("No logits and no decoded labels available for mapping.") # decoded_token_labels length may be equal to content_token_length (no special tokens) # or equal to sequence_length; try to align intelligently: # Prefer using decoded_token_labels aligned to the tokenizer tokens (starting at token 1 for CLS) # If decoded length == content_token_length, then manual_word_ids maps sub-token -> word idx for content tokens only. # We'll iterate tokens and pick label accordingly. # Build token_idx -> decoded_label mapping: # We'll assume decoded_token_labels correspond to content tokens (no CLS/SEP). If decoded length == sequence_length, then shift by 0. decoded_len = len(decoded_token_labels) # Heuristic: if decoded_len == content_token_length -> alignment starts at token_idx 1 (skip CLS) if decoded_len == content_token_length: decoded_start = 1 elif decoded_len == sequence_length: decoded_start = 0 else: # fallback: prefer decoded_start=1 (most common) decoded_start = 1 for tok_idx_in_decoded, label_id in enumerate(decoded_token_labels): tok_idx = decoded_start + tok_idx_in_decoded if tok_idx >= 512: break if tok_idx >= sequence_length: break # map this token to a word index if present word_idx = word_ids[tok_idx] if tok_idx < len(word_ids) else None if word_idx is not None and word_idx < len(sub_words): if word_idx not in word_idx_to_pred_id: # label_id may already be an int word_idx_to_pred_id[word_idx] = int(label_id) # Finally convert mapped word preds -> page_raw_predictions entries for current_word_idx in range(len(sub_words)): pred_id = word_idx_to_pred_id.get(current_word_idx, 0) # default to 0 predicted_label = ID_TO_LABEL[pred_id] original_token = sub_tokens_data[current_word_idx] page_raw_predictions.append({ "word": original_token['word'], "bbox": original_token['bbox_raw_pdf_space'], "predicted_label": predicted_label, "page_number": page_num_1_based }) if 
page_raw_predictions: final_page_predictions.append({ "page_number": page_num_1_based, "data": page_raw_predictions }) print(f" *** Page {page_num_1_based} Finalized: {len(page_raw_predictions)} labeled words. ***") doc.close() print("\n" + "=" * 80) print("--- LAYOUTLMV3 INFERENCE COMPLETE ---") print("=" * 80) return final_page_predictions # ============================================================================ # --- PHASE 3: BIO TO STRUCTURED JSON DECODER --- # ============================================================================ # def convert_bio_to_structured_json_relaxed(input_path: str, output_path: str) -> Optional[List[Dict[str, Any]]]: # print("\n" + "=" * 80) # print("--- 3. STARTING BIO TO STRUCTURED JSON DECODING ---") # print("=" * 80) # try: # with open(input_path, 'r', encoding='utf-8') as f: # predictions_by_page = json.load(f) # except Exception as e: # print(f"❌ Error loading raw prediction file: {e}") # return None # predictions = [] # for page_item in predictions_by_page: # if isinstance(page_item, dict) and 'data' in page_item: # predictions.extend(page_item['data']) # structured_data = [] # current_item = None # current_option_key = None # current_passage_buffer = [] # current_text_buffer = [] # first_question_started = False # last_entity_type = None # just_finished_i_option = False # is_in_new_passage = False # def finalize_passage_to_item(item, passage_buffer): # if passage_buffer: # passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip() # if item.get('passage'): # item['passage'] += ' ' + passage_text # else: # item['passage'] = passage_text # passage_buffer.clear() # for item in predictions: # word = item['word'] # label = item['predicted_label'] # entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None # current_text_buffer.append(word) # previous_entity_type = last_entity_type # is_passage_label = (entity_type == 'PASSAGE') # if not first_question_started: # if label != 'B-QUESTION' and 
def convert_bio_to_structured_json_relaxed(input_path: str, output_path: str) -> Optional[List[Dict[str, Any]]]:
    """
    Decode word-level BIO predictions into a list of structured question items.

    The input JSON is a list of page dicts; every page dict with a 'data' key
    contributes its list of ``{'word': ..., 'predicted_label': ...}`` records.

    Each emitted question item has the keys 'question', 'options' (label word ->
    option text), 'answer', 'passage' and 'text' (all words of the item, including
    'O' words), plus 'new_passage' when a passage starts directly after an option.
    Text appearing before the first B-QUESTION is emitted as a
    ``{'type': 'METADATA', ...}`` item.

    Args:
        input_path: Path of the raw-prediction JSON file.
        output_path: Where the decoded list is written (a write failure is only logged).

    Returns:
        The decoded list, or None when the input file cannot be loaded.
    """
    print("\n" + "=" * 80)
    print("--- 3. STARTING BIO TO STRUCTURED JSON DECODING ---")
    print(f"Source: {input_path}")
    print("=" * 80)

    start_time = time.time()

    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            predictions_by_page = json.load(f)
        print(f"✅ Successfully loaded raw predictions ({len(predictions_by_page)} pages found)")
    except Exception as e:
        print(f"❌ Error loading raw prediction file: {e}")
        return None

    predictions = []
    for page_item in predictions_by_page:
        if isinstance(page_item, dict) and 'data' in page_item:
            predictions.extend(page_item['data'])

    total_words = len(predictions)
    print(f"📋 Total words to process: {total_words}")

    structured_data = []
    current_item = None
    current_option_key = None
    current_passage_buffer = []
    current_text_buffer = []
    first_question_started = False
    last_entity_type = None
    just_finished_i_option = False
    is_in_new_passage = False

    def finalize_passage_to_item(item, passage_buffer):
        # Move buffered passage words into item['passage'] (appending to any existing
        # passage text) and empty the buffer in place.
        if passage_buffer:
            passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
            print(f" ↳ [Buffer] Finalizing passage ({len(passage_buffer)} words) into current item")
            if item.get('passage'):
                item['passage'] += ' ' + passage_text
            else:
                item['passage'] = passage_text
            passage_buffer.clear()

    for idx, item in enumerate(predictions):
        word = item['word']
        label = item['predicted_label']
        entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
        current_text_buffer.append(word)
        previous_entity_type = last_entity_type
        is_passage_label = (entity_type == 'PASSAGE')

        # Log every B- label to make the decode traceable.
        if label.startswith('B-'):
            print(f"[Word {idx}/{total_words}] Found Label: {label} | Word: '{word}'")

        if not first_question_started:
            # Before the first question only passage words are kept (for METADATA);
            # everything else survives only in current_text_buffer (header text).
            if label != 'B-QUESTION' and not is_passage_label:
                just_finished_i_option = False
                is_in_new_passage = False
                continue
            if is_passage_label:
                current_passage_buffer.append(word)
                last_entity_type = 'PASSAGE'
                just_finished_i_option = False
                is_in_new_passage = False
                continue

        if label == 'B-QUESTION':
            print(f"🔍 Detection: New Question Started at word {idx}")
            if not first_question_started:
                header_text = ' '.join(current_text_buffer[:-1]).strip()
                if header_text or current_passage_buffer:
                    print(" -> Creating METADATA item for text found before first question")
                    metadata_item = {'type': 'METADATA', 'passage': ''}
                    finalize_passage_to_item(metadata_item, current_passage_buffer)
                    if header_text:
                        metadata_item['text'] = header_text
                    structured_data.append(metadata_item)
                first_question_started = True
                current_text_buffer = [word]

            if current_item is not None:
                finalize_passage_to_item(current_item, current_passage_buffer)
                # Everything seen since the previous B-QUESTION, minus this word.
                current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
                structured_data.append(current_item)
                print(f" -> Saved Question. Total structured items so far: {len(structured_data)}")

            current_text_buffer = [word]
            current_item = {'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': ''}
            current_option_key = None
            last_entity_type = 'QUESTION'
            just_finished_i_option = False
            is_in_new_passage = False
            continue

        if current_item is None:
            continue

        if is_in_new_passage:
            leaves_new_passage = label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE')
            if not leaves_new_passage:
                # 'O' and I-PASSAGE words extend the passage that follows an option.
                if 'new_passage' not in current_item:
                    current_item['new_passage'] = word
                else:
                    current_item['new_passage'] += f' {word}'
                if label.startswith(('B-', 'I-')):
                    last_entity_type = entity_type
                continue
            print(f" ↳ [State] Exiting new_passage mode at label {label}")
            last_entity_type = entity_type
            # Fall through: the boundary word (e.g. a B-OPTION label) belongs to the
            # entity it starts, not to the passage, so it must be handled normally.

        is_in_new_passage = False

        if label.startswith('B-'):
            if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']:
                finalize_passage_to_item(current_item, current_passage_buffer)
            last_entity_type = entity_type

            if entity_type == 'PASSAGE':
                if previous_entity_type == 'OPTION' and just_finished_i_option:
                    print(" ↳ [State] Transitioning to new_passage (Option -> Passage boundary)")
                    current_item['new_passage'] = word
                    is_in_new_passage = True
                else:
                    current_passage_buffer.append(word)
            elif entity_type == 'OPTION':
                current_option_key = word
                current_item['options'][current_option_key] = word
                just_finished_i_option = False
            elif entity_type == 'ANSWER':
                current_item['answer'] = word
                current_option_key = None
                just_finished_i_option = False
            elif entity_type == 'QUESTION':
                current_item['question'] += f' {word}'
                just_finished_i_option = False

        elif label.startswith('I-'):
            if entity_type == 'QUESTION':
                current_item['question'] += f' {word}'
            elif entity_type == 'PASSAGE':
                if previous_entity_type == 'OPTION' and just_finished_i_option:
                    current_item['new_passage'] = word
                    is_in_new_passage = True
                else:
                    if not current_passage_buffer:
                        last_entity_type = 'PASSAGE'
                    current_passage_buffer.append(word)
            elif entity_type == 'OPTION' and current_option_key is not None:
                current_item['options'][current_option_key] += f' {word}'
            elif entity_type == 'ANSWER':
                current_item['answer'] += f' {word}'
            just_finished_i_option = (entity_type == 'OPTION')

        # 'O' words are intentionally not attached to any field; they only
        # appear in the item's 'text'.

    if current_item is not None:
        print("🏁 Finalizing the very last item...")
        finalize_passage_to_item(current_item, current_passage_buffer)
        current_item['text'] = ' '.join(current_text_buffer).strip()
        structured_data.append(current_item)

    for item in structured_data:
        # METADATA items only carry 'text' when header text was found.
        if 'text' in item:
            item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip()
        if 'new_passage' in item:
            item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip()

    print(f"💾 Saving {len(structured_data)} items to {output_path}")
    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(structured_data, f, indent=2, ensure_ascii=False)
        print(f"✅ Decoding Complete. Total time: {time.time() - start_time:.2f}s")
    except Exception as e:
        print(f"⚠️ Error saving final JSON: {e}")

    return structured_data


def create_query_text(entry: Dict[str, Any]) -> str:
    """Join an entry's question and string option values into one text for similarity matching."""
    query_parts = []
    if entry.get("question"):
        query_parts.append(entry["question"])
    for key in ("options", "options_text"):
        options = entry.get(key)
        if options and isinstance(options, dict):
            query_parts.extend(value for value in options.values() if value and isinstance(value, str))
    return " ".join(query_parts)


def calculate_similarity(doc1: str, doc2: str) -> float:
    """
    Return the cosine similarity of the bag-of-words count vectors of two texts.

    Returns 0.0 when either text is empty or no usable vocabulary remains.
    """
    if not doc1 or not doc2:
        return 0.0

    def clean_text(text):
        # Strips a leading run of '(' / word characters (plus an optional '.') from
        # every line - aimed at labels such as "1." or "(A", but it also removes the
        # first word of ordinary lines.
        return re.sub(r'^\s*[\(\d\w]+\.?\s*', '', text, flags=re.MULTILINE)

    corpus = [clean_text(doc1), clean_text(doc2)]
    try:
        vectorizer = CountVectorizer(stop_words='english', lowercase=True, token_pattern=r'(?u)\b\w\w+\b')
        count_matrix = vectorizer.fit_transform(corpus)
        if count_matrix.shape[1] == 0:
            return 0.0
        vectors = count_matrix.toarray()
        if len(vectors) < 2:
            return 0.0
        return float(cosine_similarity(vectors[0:1], vectors[1:2])[0][0])
    except Exception:
        # Best effort: CountVectorizer raises ValueError on an empty vocabulary
        # (e.g. only stop words); any failure means "no similarity".
        return 0.0


def process_context_linking(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Link questions to passages based on the running 'passage' flow vs 'new_passage' priority.

    Includes decay logic: after 2 consecutive questions fail to match the active
    passage, that passage context is dropped to prevent false positives downstream.
    Afterwards, weak self-links are removed and single-question gaps between two
    identical passages are filled.

    Relies on the module-level constants SIMILARITY_THRESHOLD and RESOLUTION_MARGIN
    (expected to be defined elsewhere in this file). Entries are mutated in place.
    """
    print("\n" + "=" * 80)
    print("--- STARTING CONTEXT LINKING (WITH DECAY LOGIC) ---")
    print("=" * 80)

    if not data:
        return []

    def has_text(value: Any) -> bool:
        return bool(value) and bool(value.strip())

    # --- PHASE 1: IDENTIFY PASSAGE DEFINERS (entries that carry their own passage) ---
    passage_definer_indices = [
        i for i, entry in enumerate(data)
        if has_text(entry.get("passage")) or has_text(entry.get("new_passage"))
    ]

    # --- PHASE 2: CONTEXT TRANSFER & LINKING ---
    current_passage_text = None
    current_new_passage_text = None
    consecutive_failures = 0
    MAX_CONSECUTIVE_FAILURES = 2

    for i, entry in enumerate(data):
        item_type = entry.get("type", "Question")

        # A. Unconditionally refresh contexts; fresh explicit context resets the decay counter.
        if has_text(entry.get("passage")):
            current_passage_text = entry["passage"]
            consecutive_failures = 0
        if has_text(entry.get("new_passage")):
            # Local override; it does not reset the standard-passage decay counter.
            current_new_passage_text = entry["new_passage"]

        # B. Question linking
        if not entry.get("question") or item_type == "METADATA":
            continue

        combined_query = create_query_text(entry)
        if len(combined_query.strip()) < 5:
            continue  # too short to be meaningful (noise)

        score_old = calculate_similarity(current_passage_text, combined_query) if current_passage_text else 0.0
        score_new = calculate_similarity(current_new_passage_text, combined_query) if current_new_passage_text else 0.0

        # Drop lone surrogates (seen in OCR output) so printing the preview cannot
        # raise UnicodeEncodeError.
        q_preview = (entry['question'][:30] + '...').encode('utf-8', errors='ignore').decode('utf-8')

        linked = False
        if current_new_passage_text and (score_new > score_old + RESOLUTION_MARGIN) and (
                score_new >= SIMILARITY_THRESHOLD):
            # 1. Prefer the new passage when it is clearly better. The standard-passage
            #    decay counter is not reset: the standard passage did not match.
            entry["passage"] = current_new_passage_text
            print(f" [Linker] 🚀 Q{i} ('{q_preview}') -> NEW PASSAGE (Score: {score_new:.3f})")
            linked = True
        elif current_passage_text and (score_old >= SIMILARITY_THRESHOLD):
            # 2. Otherwise fall back to the standard passage if it meets the threshold.
            entry["passage"] = current_passage_text
            print(f" [Linker] ✅ Q{i} ('{q_preview}') -> STANDARD PASSAGE (Score: {score_old:.3f})")
            linked = True
            consecutive_failures = 0

        if not linked:
            # 3. Decay logic
            if current_passage_text:
                consecutive_failures += 1
                print(f" [Linker] ⚠️ Q{i} NOT LINKED. (Failures: {consecutive_failures}/{MAX_CONSECUTIVE_FAILURES})")
                if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
                    print(f" [Linker] 🗑️ Context dropped due to {consecutive_failures} consecutive misses.")
                    current_passage_text = None
                    consecutive_failures = 0
            else:
                print(f" [Linker] ⚠️ Q{i} NOT LINKED (No active context).")

    # --- PHASE 3: CLEANUP AND INTERPOLATION ---
    print(" [Linker] Running Cleanup & Interpolation...")

    # 3A. Self-correction: passage definers whose own passage is dissimilar lose it.
    for i in passage_definer_indices:
        entry = data[i]
        if not entry.get("question") or entry.get("type") == "METADATA":
            continue
        passage_to_check = entry.get("passage") or entry.get("new_passage")
        if passage_to_check and calculate_similarity(passage_to_check, create_query_text(entry)) < SIMILARITY_THRESHOLD:
            entry["passage"] = ""
            if "new_passage" in entry:
                entry["new_passage"] = ""
            print(f" [Cleanup] Removed weak link for Q{i}")

    # 3B. Interpolation: only gaps exactly one question wide are filled, so the
    # decay logic is not undone.
    for i in range(1, len(data) - 1):
        current_entry = data[i]
        if current_entry.get("question") and not current_entry.get("passage"):
            prev_p = data[i - 1].get("passage")
            next_p = data[i + 1].get("passage")
            if prev_p and next_p and (prev_p == next_p) and prev_p.strip():
                current_entry["passage"] = prev_p
                print(f" [Linker] 🥪 Q{i} Interpolated from neighbors.")

    return data


def correct_misaligned_options(structured_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Repair option pairs where an empty option's tag spilled into the next option.

    If option K holds only its own label and the following option contains exactly
    two EQUATIONn/FIGUREn tags, the first tag is moved back to option K. Items are
    mutated in place and the same list is returned.
    """
    print("\n" + "=" * 80)
    print("--- 5. STARTING POST-PROCESSING: OPTION ALIGNMENT CORRECTION ---")
    print("=" * 80)

    tag_pattern = re.compile(r'(EQUATION\d+|FIGURE\d+)')
    corrected_count = 0

    for item in structured_data:
        if item.get('type') == 'METADATA':
            continue
        options = item.get('options')
        if not options or len(options) < 2:
            continue

        option_keys = list(options.keys())
        # Values are read live from `options`, so a correction is visible to the next pair.
        for current_key, next_key in zip(option_keys, option_keys[1:]):
            current_value = options[current_key].strip()
            next_value = options[next_key].strip()

            is_current_empty = current_value == current_key
            content_in_next = next_value.replace(next_key, '', 1).strip()
            tags_in_next = tag_pattern.findall(content_in_next)

            if is_current_empty and len(tags_in_next) == 2:
                options[current_key] = f"{current_key} {tags_in_next[0]}".strip()
                options[next_key] = f"{next_key} {tags_in_next[1]}".strip()
                corrected_count += 1

    print(f"✅ Option alignment correction finished. Total corrections: {corrected_count}.")
    return structured_data


def get_base64_for_file(filepath: str) -> Optional[str]:
    """Read a file and return its raw Base64 string (no data-URI prefix), or None on error."""
    try:
        with open(filepath, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    except Exception as e:
        print(f"Error reading and encoding file {filepath}: {e}")
        return None


def embed_images_as_base64_in_memory(structured_data: List[Dict[str, Any]], figure_extraction_dir: str) -> List[
    Dict[str, Any]]:
    """
    Attach figure images (Base64) and equation LaTeX to the items that reference them.

    PNG files named like ``..._figure3.png`` / ``..._equation7.png`` in
    ``figure_extraction_dir`` are matched against FIGUREn / EQUATIONn tags found in
    each item's question, passage, options and new_passage. For each referenced tag
    the item gets a lowercase key (e.g. 'figure3') holding the Base64 string, the
    recognized LaTeX, or an error marker. Items are mutated in place and returned in
    a new list.
    """
    print("\n" + "=" * 80)
    print("--- 4. STARTING IMAGE EMBEDDING (Base64) / EQUATION TO LATEX CONVERSION ---")
    print("=" * 80)

    if not structured_data:
        return []

    tag_regex = re.compile(r'(figure|equation)(\d+)', re.IGNORECASE)
    filename_regex = re.compile(r'_(figure|equation)(\d+)\.png$', re.IGNORECASE)

    image_lookup = {}
    for filepath in glob.glob(os.path.join(figure_extraction_dir, "*.png")):
        match = filename_regex.search(os.path.basename(filepath))
        if match:
            image_lookup[f"{match.group(1).upper()}{match.group(2)}"] = filepath

    print(f" -> Found {len(image_lookup)} image components.")

    # Prefixes get_latex_from_base64 uses for failure markers; '[MODEL_ERROR' is
    # returned when the TrOCR/ORT model failed to initialize.
    latex_error_prefixes = ('[P2T_ERROR', '[P2T_WARNING', '[MODEL_ERROR')

    final_structured_data = []
    for item in structured_data:
        text_fields = [item.get('question', ''), item.get('passage', '')]
        if 'options' in item:
            text_fields.extend(item['options'].values())
        if 'new_passage' in item:
            text_fields.append(item['new_passage'])

        unique_tags_to_embed = set()
        for text in text_fields:
            if not text:
                continue
            for match in tag_regex.finditer(text):
                tag = match.group(0).upper()
                if tag in image_lookup:
                    unique_tags_to_embed.add(tag)

        for tag in sorted(unique_tags_to_embed):
            filepath = image_lookup[tag]
            base_key = tag.replace(' ', '').lower()  # e.g. figure1 or equation1

            if 'EQUATION' in tag:
                base64_code = get_base64_for_file(filepath)
                if base64_code:
                    latex_output = get_latex_from_base64(base64_code)
                    # On failure the error marker itself is embedded.
                    item[base_key] = latex_output
                    if isinstance(latex_output, str) and not latex_output.startswith(latex_error_prefixes):
                        print(f" ✅ Embedded Clean LaTeX for {tag}")
                    else:
                        print(f" ⚠️ Failed to convert {tag} to LaTeX. Embedding error message.")
                else:
                    item[base_key] = "[FILE_ERROR: Could not read image file]"
                    print(f" ❌ File read error for {tag}.")

            elif 'FIGURE' in tag:
                base64_code = get_base64_for_file(filepath)
                item[base_key] = base64_code  # None when the file could not be read
                if base64_code:
                    print(f" ✅ Embedded Base64 for {tag}")
                else:
                    print(f" ❌ File read error for {tag}.")

        final_structured_data.append(item)

    print("✅ Image embedding complete.")
    return final_structured_data
def classify_question_type(item: Dict[str, Any]) -> str:
    """
    Classify a question entry as 'MCQ' or 'DESCRIPTIVE' from its options.

    An entry is 'MCQ' when at least one string option value carries text beyond its
    own leading label (e.g. key "A" with value "A Paris"). Otherwise it is
    'DESCRIPTIVE'. No separate 'INTEGER' label is produced: option-less questions
    (descriptive or numeric-answer) are all reported as 'DESCRIPTIVE'.

    Args:
        item: Question dict; its 'options' field, if present, maps label -> option text.

    Returns:
        'MCQ' or 'DESCRIPTIVE'.
    """
    options = item.get('options')
    if not options or not isinstance(options, dict):
        return 'DESCRIPTIVE'

    for key, value in options.items():
        if not value or not isinstance(value, str):
            continue
        # Remove only the first (label) occurrence, the same convention that
        # correct_misaligned_options uses. Replacing every occurrence would wrongly
        # empty options whose text repeats the label (e.g. key "A", value "A A").
        if value.replace(key, '', 1).strip():
            return 'MCQ'

    return 'DESCRIPTIVE'


def add_question_type_validation(structured_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Add a 'question_type' field to every entry of the structured data (in place).

    METADATA entries get 'METADATA'; every other entry gets the result of
    classify_question_type. Counts per category are printed.

    Args:
        structured_data: List of question / metadata dicts.

    Returns:
        The same list, with 'question_type' set on every entry.
    """
    print("\n" + "=" * 80)
    print("--- ADDING QUESTION TYPE VALIDATION ---")
    print("=" * 80)

    mcq_count = 0
    descriptive_count = 0
    metadata_count = 0

    for item in structured_data:
        if item.get('type', 'Question') == 'METADATA':
            metadata_count += 1
            item['question_type'] = 'METADATA'
            continue

        question_type = classify_question_type(item)
        item['question_type'] = question_type
        if question_type == 'MCQ':
            mcq_count += 1
        else:
            descriptive_count += 1

    print(" ✅ Classification Complete:")
    print(f" - MCQ Questions: {mcq_count}")
    print(f" - Descriptive/Integer Questions: {descriptive_count}")
    print(f" - Metadata Entries: {metadata_count}")
    print(f" - Total Entries: {len(structured_data)}")

    return structured_data


# Mid-file imports kept for compatibility: `traceback` is not imported at the top
# of the file and is used by run_document_pipeline below.
import time
import traceback
import glob
# ============================================================================
# --- MAIN FUNCTION ---
# ============================================================================
def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str,
                          structured_intermediate_output_path: Optional[str] = None) -> Optional[List[Dict[str, Any]]]:
    """
    Run the full PDF -> structured-question pipeline and return the final items.

    Stages: preprocessing (run_single_pdf_preprocessing), LayoutLMv3 inference,
    BIO decoding, option-alignment correction, context linking, figure/equation
    embedding, question-type tagging, hierarchical subject/concept tagging, and
    finally removal of METADATA entries.

    Args:
        input_pdf_path: PDF to analyse.
        layoutlmv3_model_path: Model path forwarded to the inference step.
        structured_intermediate_output_path: Where the decoded (pre-embedding) JSON is
            written; defaults to a file in the per-run temp directory.

    Returns:
        The final list of question dicts, or None when the input is missing, a stage
        returns nothing, or any exception is raised (the traceback is printed).

    NOTE: intermediate files live in ``<tempdir>/pipeline_run_<name>_<pid>``, which is
    not deleted afterwards (the cleanup step is currently disabled).
    """
    if not os.path.exists(input_pdf_path):
        print(f"❌ ERROR: File not found: {input_pdf_path}")
        return None

    banner = "#" * 80
    print("\n" + banner)
    print("### STARTING OPTIMIZED FULL DOCUMENT ANALYSIS PIPELINE ###")
    print(f"Input: {input_pdf_path}")
    print(banner)

    overall_start = time.time()

    def seconds_since(t0: float) -> float:
        return time.time() - t0

    pdf_name = os.path.splitext(os.path.basename(input_pdf_path))[0]
    work_dir = os.path.join(tempfile.gettempdir(), f"pipeline_run_{pdf_name}_{os.getpid()}")
    os.makedirs(work_dir, exist_ok=True)

    preprocessed_json_path = os.path.join(work_dir, f"{pdf_name}_preprocessed.json")
    raw_output_path = os.path.join(work_dir, f"{pdf_name}_raw_predictions.json")
    intermediate_path = (
        structured_intermediate_output_path
        if structured_intermediate_output_path is not None
        else os.path.join(work_dir, f"{pdf_name}_structured_intermediate.json")
    )

    final_result = None
    try:
        # --- Phase 1: Preprocessing ---
        print("\n[Step 1/5] Preprocessing (YOLO + Masking)...")
        step_start = time.time()
        preprocessed_out = run_single_pdf_preprocessing(input_pdf_path, preprocessed_json_path)
        if not preprocessed_out:
            print("❌ FAILED at Step 1: Preprocessing returned None.")
            return None
        print(f"✅ Step 1 Complete ({seconds_since(step_start):.2f}s)")

        # --- Phase 2: Inference ---
        print("\n[Step 2/5] Inference (LayoutLMv3)...")
        step_start = time.time()
        raw_pages = run_inference_and_get_raw_words(input_pdf_path, layoutlmv3_model_path, preprocessed_out)
        if not raw_pages:
            print("❌ FAILED at Step 2: Inference returned no data.")
            return None
        # The decoder reads its input from disk, so persist the raw predictions first.
        with open(raw_output_path, 'w', encoding='utf-8') as raw_file:
            json.dump(raw_pages, raw_file, indent=4)
        print(f"✅ Step 2 Complete ({seconds_since(step_start):.2f}s)")

        # --- Phase 3: Decoding + alignment + linking ---
        print("\n[Step 3/5] Decoding (BIO to Structured JSON)...")
        step_start = time.time()
        items = convert_bio_to_structured_json_relaxed(raw_output_path, intermediate_path)
        if not items:
            print("❌ FAILED at Step 3: BIO conversion failed.")
            return None
        print("... Correcting misalignments and linking context ...")
        items = process_context_linking(correct_misaligned_options(items))
        print(f"✅ Step 3 Complete ({seconds_since(step_start):.2f}s)")

        # --- Phase 4: Base64 figures & LaTeX equations ---
        print("\n[Step 4/5] Finalizing Layout (Base64 Images & LaTeX)...")
        step_start = time.time()
        final_result = embed_images_as_base64_in_memory(items, FIGURE_EXTRACTION_DIR)
        if not final_result:
            print("❌ FAILED at Step 4: Final formatting failed.")
            return None
        print(f"✅ Step 4 Complete ({seconds_since(step_start):.2f}s)")

        # --- Phase 4.5: Question type classification ---
        print("\n[Step 4.5/5] Adding Question Type Classification...")
        step_start = time.time()
        final_result = add_question_type_validation(final_result)
        print(f"✅ Step 4.5 Complete ({seconds_since(step_start):.2f}s)")

        # --- Phase 5: Hierarchical tagging (skipped if the models fail to load) ---
        print("\n[Step 5/5] AI Classification (Subject/Concept Tagging)...")
        step_start = time.time()
        classifier = HierarchicalClassifier()
        if not classifier.load_models():
            print("⚠️ WARNING: Classifier models failed to load. Skipping tags.")
        else:
            final_result = post_process_json_with_inference(final_result, classifier)
            print(f"✅ Step 5 Complete: Tags added ({seconds_since(step_start):.2f}s)")

        # --- Post-processing: METADATA entries are not questions ---
        print("\n[Post-Processing] Removing METADATA entries...")
        count_before = len(final_result)
        final_result = [entry for entry in final_result if entry.get('type') != 'METADATA']
        print(f"✅ Removed {count_before - len(final_result)} METADATA entries. {len(final_result)} questions remain.")

    except Exception as e:
        print("\n‼️ FATAL PIPELINE EXCEPTION:")
        print(f"Error Message: {str(e)}")
        traceback.print_exc()
        return None

    print("\n" + banner)
    print(f"### PIPELINE COMPLETE | Total Time: {seconds_since(overall_start):.2f}s ###")
    print(banner)

    return final_result


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Complete Pipeline")
    parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
    parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH,
                        help="Model Path")
    # Accepted for CLI compatibility; the value is currently not used by the pipeline.
    parser.add_argument("--raw_preds_path", type=str, default='BIO_debug.json',
                        help="Debug path for raw BIO tag predictions (JSON).")
    args = parser.parse_args()

    pdf_name = os.path.splitext(os.path.basename(args.input_pdf))[0]
    final_output_path = os.path.abspath(f"{pdf_name}_final_output_embedded.json")

    final_json_data = run_document_pipeline(args.input_pdf, args.layoutlmv3_model_path)

    if not final_json_data:
        print("\n❌ Pipeline Failed.")
        sys.exit(1)

    # A plain json.dumps is sufficient: each backslash in a LaTeX string is escaped
    # exactly once, so json.load restores the original LaTeX unchanged.
    json_str = json.dumps(final_json_data, indent=2, ensure_ascii=False)
    with open(final_output_path, 'w', encoding='utf-8') as out_file:
        out_file.write(json_str)
    print(f"\n✅ Final Data Saved: {final_output_path}")