heerjtdev commited on
Commit
b78516d
·
1 Parent(s): a802a9c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +1598 -0
app.py ADDED
@@ -0,0 +1,1598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import argparse
3
+ import os
4
+ import re
5
+ import torch
6
+ import torch.nn as nn
7
+ from TorchCRF import CRF
8
+ from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model, LayoutLMv3Config
9
+ from typing import List, Dict, Any, Optional, Union, Tuple
10
+ import fitz # PyMuPDF
11
+ import numpy as np
12
+ import cv2
13
+ from ultralytics import YOLO
14
+ import glob
15
+ import pytesseract
16
+ from PIL import Image
17
+ from scipy.signal import find_peaks
18
+ from scipy.ndimage import gaussian_filter1d
19
+ import sys
20
+ import io
21
+ import base64
22
+ import tempfile
23
+ import time
24
+ import shutil
25
+ from sklearn.feature_extraction.text import CountVectorizer
26
+ from sklearn.metrics.pairwise import cosine_similarity
27
+
28
+ # ============================================================================
29
+ # --- CONFIGURATION AND CONSTANTS ---
30
+ # ============================================================================
31
+
32
# NOTE: Update these paths to match your environment before running!
# YOLO weights file; loaded with YOLO(WEIGHTS_PATH) in run_single_pdf_preprocessing.
WEIGHTS_PATH = 'YOLO_MATH/yolo_split_data/runs/detect/math_figure_detector_v3/weights/best.pt'
# Presumably the default LayoutLMv3 checkpoint; not referenced in this part of the file.
DEFAULT_LAYOUTLMV3_MODEL_PATH = "97.pth"

# DIRECTORY CONFIGURATION
OCR_JSON_OUTPUT_DIR = './ocr_json_output_final'
# Cropped figure/equation PNGs are written here by preprocess_and_ocr_page.
FIGURE_EXTRACTION_DIR = './figure_extraction'
TEMP_IMAGE_DIR = './temp_pdf_images'

# Detection parameters
# Passed as `conf` to model.predict().
CONF_THRESHOLD = 0.2
# Only YOLO detections whose class name is in this list are kept.
TARGET_CLASSES = ['figure', 'equation']
# Same-class boxes whose IoU exceeds this value are merged into one box.
IOU_MERGE_THRESHOLD = 0.4
# A text word is dropped when more than this fraction of its area lies inside a component box.
IOA_SUPPRESSION_THRESHOLD = 0.7
# Maximum difference in 'y0' (pipeline pixels) for two items to share a line.
LINE_TOLERANCE = 15

# Similarity
# NOTE(review): these two values are not used in this part of the file; confirm their consumers.
SIMILARITY_THRESHOLD = 0.10
RESOLUTION_MARGIN = 0.05

# Global counters for sequential numbering across the entire PDF
# Incremented in preprocess_and_ocr_page and reset to 0 in run_single_pdf_preprocessing.
GLOBAL_FIGURE_COUNT = 0
GLOBAL_EQUATION_COUNT = 0

# LayoutLMv3 Labels
# Class-index -> BIO tag mapping (B- begins a span, I- continues it, "O" is outside).
ID_TO_LABEL = {
    0: "O",
    1: "B-QUESTION", 2: "I-QUESTION",
    3: "B-OPTION", 4: "I-OPTION",
    5: "B-ANSWER", 6: "I-ANSWER",
    7: "B-SECTION_HEADING", 8: "I-SECTION_HEADING",
    9: "B-PASSAGE", 10: "I-PASSAGE"
}
NUM_LABELS = len(ID_TO_LABEL)
66
+
67
+
68
+ # ============================================================================
69
+ # --- PERFORMANCE OPTIMIZATION: OCR CACHE ---
70
+ # ============================================================================
71
+
72
class OCRCache:
    """In-memory store of per-page OCR word lists, keyed by PDF path and page number.

    Lets a page that needed Tesseract be OCR'd once and have the result reused.
    """

    def __init__(self):
        # Maps "<pdf_path>:<page_num>" to whatever list was stored for that page.
        self.cache = {}

    def get_key(self, pdf_path: str, page_num: int) -> str:
        """Return the dictionary key identifying one page of one PDF."""
        return "{}:{}".format(pdf_path, page_num)

    def has_ocr(self, pdf_path: str, page_num: int) -> bool:
        """Return True when a result has been stored for this page."""
        key = self.get_key(pdf_path, page_num)
        return key in self.cache

    def get_ocr(self, pdf_path: str, page_num: int) -> Optional[list]:
        """Return the stored result for this page, or None when nothing is stored."""
        key = self.get_key(pdf_path, page_num)
        return self.cache.get(key)

    def set_ocr(self, pdf_path: str, page_num: int, ocr_data: list):
        """Store (or overwrite) the result for this page."""
        key = self.get_key(pdf_path, page_num)
        self.cache[key] = ocr_data

    def clear(self):
        """Forget every stored page."""
        self.cache.clear()
92
+
93
+
94
+ # Global OCR cache instance
95
+ _ocr_cache = OCRCache()
96
+
97
+
98
+ # ============================================================================
99
+ # --- PHASE 1: YOLO/OCR PREPROCESSING FUNCTIONS ---
100
+ # ============================================================================
101
+
102
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two (x1, y1, x2, y2) boxes.

    Returns 0 when the union area is not positive.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    # Overlap extents are clamped at zero so disjoint boxes give no intersection.
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    overlap = overlap_w * overlap_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = float(area_a + area_b - overlap)

    if union > 0:
        return overlap / union
    return 0
114
+
115
+
116
def calculate_ioa(box1, box2):
    """Return the fraction of box1's area that is covered by box2.

    Both boxes are (x1, y1, x2, y2). Returns 0 when box1 has no positive area.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    overlap = overlap_w * overlap_h

    # The denominator is the FIRST box only, so the measure is asymmetric.
    area_a = (ax2 - ax1) * (ay2 - ay1)
    if area_a > 0:
        return overlap / area_a
    return 0
126
+
127
+
128
def filter_nested_boxes(detections, ioa_threshold=0.80):
    """Drop detections that lie mostly inside a larger detection.

    Each detection is a dict with a 'coords' entry (x1, y1, x2, y2). A smaller
    box is removed when the share of its own area covered by a larger kept box
    exceeds ``ioa_threshold`` (0.80 by default).

    Side effects: every input dict gains an 'area' key and the input list is
    sorted in place, largest area first.

    Returns the surviving detections, largest first.
    """
    if not detections:
        return []

    for det in detections:
        left, top, right, bottom = det['coords']
        det['area'] = (right - left) * (bottom - top)

    # Largest first, so a container is always visited before what it contains.
    detections.sort(key=lambda det: det['area'], reverse=True)

    suppressed = set()
    survivors = []

    for outer_idx, outer in enumerate(detections):
        if outer_idx in suppressed:
            continue
        survivors.append(outer)
        outer_box = outer['coords']

        for inner_idx in range(outer_idx + 1, len(detections)):
            if inner_idx in suppressed:
                continue
            inner = detections[inner_idx]
            inner_box = inner['coords']

            overlap_w = min(outer_box[2], inner_box[2]) - max(outer_box[0], inner_box[0])
            overlap_h = min(outer_box[3], inner_box[3]) - max(outer_box[1], inner_box[1])
            if overlap_w >= 0 and overlap_h >= 0:
                overlap = overlap_w * overlap_h
            else:
                overlap = 0

            # 'inner' is never larger than 'outer' because of the sort above.
            inner_area = inner['area']
            if inner_area > 0 and overlap / inner_area > ioa_threshold:
                suppressed.add(inner_idx)

    return survivors
184
+
185
+
186
def merge_overlapping_boxes(detections, iou_threshold):
    """Merge same-class detections whose IoU with a higher-confidence box is large.

    Detections are dicts with 'coords' (x1, y1, x2, y2), 'class' and 'conf'.
    The input list is sorted in place by confidence, highest first. Each
    not-yet-merged detection becomes a seed; every later same-class detection
    whose IoU with the seed's ORIGINAL box exceeds ``iou_threshold`` is absorbed
    and the result box grows to the bounding box of the group.

    Returns new dicts with keys 'coords', 'y1', 'class' and 'conf' (the seed's
    confidence).
    """
    if not detections:
        return []

    detections.sort(key=lambda det: det['conf'], reverse=True)
    absorbed = [False] * len(detections)
    merged = []

    for seed_idx, seed in enumerate(detections):
        if absorbed[seed_idx]:
            continue

        seed_box = seed['coords']
        seed_class = seed['class']
        left, top, right, bottom = seed_box

        for other_idx in range(seed_idx + 1, len(detections)):
            other = detections[other_idx]
            if absorbed[other_idx] or other['class'] != seed_class:
                continue
            other_box = other['coords']
            if calculate_iou(seed_box, other_box) > iou_threshold:
                left = min(left, other_box[0])
                top = min(top, other_box[1])
                right = max(right, other_box[2])
                bottom = max(bottom, other_box[3])
                absorbed[other_idx] = True

        merged.append({
            'coords': (left, top, right, bottom),
            'y1': top, 'class': seed_class, 'conf': seed['conf']
        })

    return merged
211
+
212
+
213
def merge_yolo_into_word_data(raw_word_data: list, yolo_detections: list, scale_factor: float) -> list:
    """Replace words covered by YOLO boxes with one placeholder entry per box.

    ``raw_word_data`` holds (text, x1, y1, x2, y2) tuples. Each detection's
    'coords' are divided by ``scale_factor`` to bring them into the words'
    coordinate space. A word is dropped when its centre falls inside any scaled
    box; each scaled box is then appended as ("BLOCK_<i>", x1, y1, x2, y2).

    Returns ``raw_word_data`` itself (unchanged) when there are no detections,
    otherwise a new list.
    """
    if not yolo_detections:
        return raw_word_data

    # Detection boxes converted into the same units as the word boxes.
    masks = []
    for det in yolo_detections:
        bx1, by1, bx2, by2 = det['coords']
        masks.append((bx1 / scale_factor, by1 / scale_factor,
                      bx2 / scale_factor, by2 / scale_factor))

    def _centre_is_masked(word):
        centre_x = (word[1] + word[3]) / 2
        centre_y = (word[2] + word[4]) / 2
        return any(
            mx1 <= centre_x <= mx2 and my1 <= centre_y <= my2
            for mx1, my1, mx2, my2 in masks
        )

    result = [word for word in raw_word_data if not _centre_is_masked(word)]

    # The boxes themselves become solid pseudo-words.
    result.extend((f"BLOCK_{idx}", *mask) for idx, mask in enumerate(masks))
    return result
255
+
256
+
257
+ # ============================================================================
258
+ # --- OCR IMAGE PREPROCESSING AND COLUMN-GUTTER HELPERS ---
259
+ # ============================================================================
260
+
261
def preprocess_image_for_ocr(img_np):
    """Return an Otsu-binarized, single-channel version of ``img_np``.

    A 3-dimensional array is converted with cv2.COLOR_BGR2GRAY first; any other
    array is thresholded as-is. The threshold is chosen automatically by Otsu's
    method, and pixels above it become 255 while the rest become 0 (so dark text
    on a light page ends up black on white).
    """
    has_channel_axis = len(img_np.shape) == 3
    grayscale = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) if has_channel_axis else img_np

    # cv2.threshold returns (chosen_threshold, image); only the image is needed.
    binarized = cv2.threshold(grayscale, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
    return binarized
277
+
278
+
279
def calculate_vertical_gap_coverage(word_data: list, sep_x: int, page_height: float, gutter_width: int = 10) -> float:
    """Return the fraction of the text's vertical span left clear around ``sep_x``.

    ``word_data`` holds (text, x1, y1, x2, y2) tuples. The vertical span runs
    from the smallest to the largest y coordinate of any word and is sampled at
    one row per unit. A row counts as blocked when some word horizontally
    overlaps the band ``sep_x +/- gutter_width / 2`` and covers that row.

    ``page_height`` is accepted but not used by the computation.
    Returns 0.0 for empty input or a span of non-positive height.
    """
    if not word_data:
        return 0.0

    y_values = [w[2] for w in word_data] + [w[4] for w in word_data]
    span_top = min(y_values)
    span_height = max(y_values) - span_top
    if span_height <= 0:
        return 0.0

    # One flag per vertical unit; True means nothing crosses the gutter there.
    row_is_clear = np.ones(int(span_height) + 1, dtype=bool)
    row_count = len(row_is_clear)

    half_gutter = gutter_width / 2
    band_left = sep_x - half_gutter
    band_right = sep_x + half_gutter
    base_row = int(span_top)

    for _text, left, top, right, bottom in word_data:
        if not (right > band_left and left < band_right):
            continue  # word does not touch the gutter band
        first_row = max(0, int(top) - base_row)
        last_row = min(row_count, int(bottom) - base_row)
        if last_row > first_row:
            row_is_clear[first_row:last_row] = False

    return np.sum(row_is_clear) / row_count
314
+
315
+
316
def calculate_x_gutters(word_data: list, params: Dict, page_height: float) -> List[int]:
    """Find X positions of vertical gutters (column gaps) between words.

    A histogram of all word left/right edges is smoothed and inverted so that
    sparsely populated X ranges become peaks. Each peak is then validated:

    1. Bridging check - the summed height of words that straddle the candidate
       must not exceed 8% of the total text height.
    2. Coverage check - ``calculate_vertical_gap_coverage`` must report that at
       least 80% of the text height is clear around the candidate.

    Args:
        word_data: (text, x1, y1, x2, y2) tuples.
        params: optional tuning keys 'cluster_bin_size' (default 5),
            'cluster_smoothing' (1), 'cluster_min_width' (20) and
            'cluster_threshold_percentile' (85).
        page_height: forwarded to ``calculate_vertical_gap_coverage``.

    Returns:
        Sorted list of accepted separator X coordinates (left edge of the
        histogram bin of each accepted peak). Empty when nothing qualifies,
        including when the input has no vertical extent or no positive X.
    """
    if not word_data:
        return []

    x_points = []
    for item in word_data:
        x_points.extend([item[1], item[3]])
    max_x = max(x_points)

    # Total text height is the denominator of the bridging ratio.
    y_coords = [item[2] for item in word_data] + [item[4] for item in word_data]
    total_text_height = max(y_coords) - min(y_coords)
    if total_text_height <= 0:
        return []

    # With max_x <= 0 the bin count below would be zero and np.histogram raises.
    if max_x <= 0:
        return []

    bin_size = params.get('cluster_bin_size', 5)
    smoothing = params.get('cluster_smoothing', 1)
    min_width = params.get('cluster_min_width', 20)
    threshold_percentile = params.get('cluster_threshold_percentile', 85)

    num_bins = int(np.ceil(max_x / bin_size))
    hist, bin_edges = np.histogram(x_points, bins=num_bins, range=(0, max_x))
    smoothed_hist = gaussian_filter1d(hist.astype(float), sigma=smoothing)
    # Invert so that low-density X ranges (candidate gutters) become peaks.
    inverted_signal = np.max(smoothed_hist) - smoothed_hist

    peaks, _ = find_peaks(
        inverted_signal,
        height=np.max(inverted_signal) - np.percentile(smoothed_hist, threshold_percentile),
        # find_peaks rejects distance < 1, which happens when min_width < bin_size.
        distance=max(1, min_width / bin_size),
    )
    if not peaks.size:
        return []

    max_bridging_ratio = 0.08  # page numbers/headers may cross; paragraphs may not
    min_gap_coverage = 0.80

    final_separators = []
    for x_coord in (int(bin_edges[p]) for p in peaks):
        # CHECK 1: total height of words physically straddling the candidate line.
        bridging_height = sum(
            item[4] - item[2]
            for item in word_data
            if item[1] < x_coord and item[3] > x_coord
        )
        bridging_ratio = bridging_height / total_text_height

        if bridging_ratio > max_bridging_ratio:
            print(
                f" ❌ Separator X={x_coord} REJECTED: Bridging Ratio {bridging_ratio:.1%} "
                f"(>{max_bridging_ratio:.0%}) cuts through text.")
            continue

        # CHECK 2: the gap must stay clear for most of the text height.
        coverage = calculate_vertical_gap_coverage(word_data, x_coord, page_height, gutter_width=min_width)

        if coverage >= min_gap_coverage:
            final_separators.append(x_coord)
            print(f" -> Separator X={x_coord} ACCEPTED (Coverage: {coverage:.1%}, Bridging: {bridging_ratio:.1%})")
        else:
            print(f" ❌ Separator X={x_coord} REJECTED (Coverage: {coverage:.1%}, Bridging: {bridging_ratio:.1%})")

    return sorted(final_separators)
393
+
394
+
395
def get_word_data_for_detection(page: fitz.Page, pdf_path: str, page_num: int,
                                top_margin_percent=0.10, bottom_margin_percent=0.10) -> list:
    """Return the page's words as (text, x1, y1, x2, y2), excluding top/bottom margins.

    Word sources, in order of preference:
      1. the PDF's native text layer (``page.get_text("words")``);
      2. a previously cached OCR result for (pdf_path, page_num);
      3. a fresh Tesseract run on a 4x render of the page, which is then cached.

    OCR coordinates are divided by the zoom so all three sources share the
    page's own coordinate space. Words whose top is above
    ``top_margin_percent`` of the page height, or whose bottom is below
    ``1 - bottom_margin_percent`` of it, are dropped.

    Returns an empty list when the OCR step raises.
    """
    native_words = page.get_text("words")

    if len(native_words) > 0:
        # PyMuPDF yields (x0, y0, x1, y1, text, ...); move the text to the front.
        words = [(w[4], w[0], w[1], w[2], w[3]) for w in native_words]
    elif _ocr_cache.has_ocr(pdf_path, page_num):
        words = _ocr_cache.get_ocr(pdf_path, page_num)
    else:
        try:
            zoom = 4.0  # high-resolution render for better OCR accuracy
            pixmap = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))

            raster = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape(
                pixmap.height, pixmap.width, pixmap.n)
            if pixmap.n == 3:
                raster = cv2.cvtColor(raster, cv2.COLOR_RGB2BGR)
            elif pixmap.n == 4:
                raster = cv2.cvtColor(raster, cv2.COLOR_RGBA2BGR)

            binarized = preprocess_image_for_ocr(raster)

            # --oem 3: default engine; --psm 6: treat the page as one uniform text block.
            tess = pytesseract.image_to_data(
                binarized,
                output_type=pytesseract.Output.DICT,
                config=r'--oem 3 --psm 6',
            )

            words = []
            for idx in range(len(tess['level'])):
                token = tess['text'][idx].strip()
                if not token:
                    continue
                left = tess['left'][idx]
                top = tess['top'][idx]
                right = left + tess['width'][idx]
                bottom = top + tess['height'][idx]
                # Undo the render zoom to get back to page coordinates.
                words.append((token, left / zoom, top / zoom, right / zoom, bottom / zoom))

            _ocr_cache.set_ocr(pdf_path, page_num, words)
        except Exception as e:
            print(f" ❌ OCR Error in detection phase: {e}")
            return []

    # Keep only words lying fully inside the vertical band between the margins.
    page_height = page.rect.height
    band_top = page_height * top_margin_percent
    band_bottom = page_height * (1 - bottom_margin_percent)
    return [w for w in words if w[2] >= band_top and w[4] <= band_bottom]
453
+
454
+
455
def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray:
    """Convert a PyMuPDF pixmap into a (height, width, channels) uint8 array.

    3-channel data is converted RGB->BGR and 4-channel data RGBA->BGR (OpenCV's
    channel order); any other channel count is returned without conversion.
    """
    channels = pix.n
    frame = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, channels)

    colour_conversion = {3: cv2.COLOR_RGB2BGR, 4: cv2.COLOR_RGBA2BGR}.get(channels)
    if colour_conversion is not None:
        frame = cv2.cvtColor(frame, colour_conversion)
    return frame
463
+
464
+
465
def extract_native_words_and_convert(fitz_page, scale_factor: float = 2.0) -> list:
    """Read the page's native words and return them as pipeline word dicts.

    Each non-blank word from ``fitz_page.get_text("words")`` becomes a dict with
    'type' ('text'), 'word', 'confidence' (fixed at 99.0), 'bbox'
    ([x1, y1, x2, y2]) and 'y0'/'x0'. Coordinates are multiplied by
    ``scale_factor`` and truncated to int.
    """
    native_confidence = 99.0
    converted = []

    for x1, y1, x2, y2, word, *_ in fitz_page.get_text("words"):
        if not word.strip():
            continue

        left, top, right, bottom = (
            int(value * scale_factor) for value in (x1, y1, x2, y2)
        )
        converted.append({
            'type': 'text',
            'word': word,
            'confidence': native_confidence,
            'bbox': [left, top, right, bottom],
            'y0': top, 'x0': left
        })

    return converted
484
+
485
+
486
def _detect_page_components(original_img, model):
    """Run YOLO on the page image and return merged detections of the target classes."""
    started = time.time()
    results = model.predict(source=original_img, conf=CONF_THRESHOLD, imgsz=640, verbose=False)

    detections = []
    if results and results[0].boxes:
        for box in results[0].boxes:
            class_name = model.names[int(box.cls[0])]
            if class_name not in TARGET_CLASSES:
                continue
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
            detections.append(
                {'coords': (x1, y1, x2, y2), 'y1': y1, 'class': class_name, 'conf': float(box.conf[0])}
            )

    merged = merge_overlapping_boxes(detections, IOU_MERGE_THRESHOLD)
    print(f" [LOG] YOLO found {len(merged)} objects in {time.time() - started:.3f}s.")
    return merged


def _select_central_separator(separators, page_width):
    """Pick the gutter closest to the page centre, or None if none lies in the central 35-65% band."""
    if not separators:
        print(" -> Single Column Layout Confirmed.")
        return None

    central_min = page_width * 0.35
    central_max = page_width * 0.65
    central = [s for s in separators if central_min <= s <= central_max]
    if not central:
        print(" ⚠️ Gutter found off-center. Ignoring.")
        return None

    center_x = page_width / 2
    chosen = min(central, key=lambda x: abs(x - center_x))
    print(f" βœ… Column Split Confirmed at X={chosen:.1f}")
    return chosen


def _save_page_components(original_img, detections, pdf_name, page_num):
    """Crop each figure/equation to FIGURE_EXTRACTION_DIR and return its placeholder 'word' dict."""
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT

    img_h, img_w = original_img.shape[:2]
    metadata = []

    for detection in detections:
        class_name = detection['class']
        if class_name not in ('figure', 'equation'):
            continue

        # Clamp to the image: a negative index would silently slice from the far
        # edge, and an empty crop makes cv2.imwrite raise.
        x1, y1, x2, y2 = (int(v) for v in detection['coords'])
        x1, x2 = max(0, x1), min(img_w, x2)
        y1, y2 = max(0, y1), min(img_h, y2)
        if x2 <= x1 or y2 <= y1:
            print(f" ⚠️ Skipping empty {class_name} box on page {page_num}.")
            continue

        if class_name == 'figure':
            GLOBAL_FIGURE_COUNT += 1
            counter = GLOBAL_FIGURE_COUNT
            component_word = f"FIGURE{counter}"
        else:
            GLOBAL_EQUATION_COUNT += 1
            counter = GLOBAL_EQUATION_COUNT
            component_word = f"EQUATION{counter}"

        component_filename = f"{pdf_name}_page{page_num}_{class_name}{counter}.png"
        component_path = os.path.join(FIGURE_EXTRACTION_DIR, component_filename)
        if not cv2.imwrite(component_path, original_img[y1:y2, x1:x2]):
            print(f" ⚠️ Could not write component image: {component_path}")

        # Components are positioned in the reading order by their vertical midpoint.
        metadata.append({
            'type': class_name, 'word': component_word,
            'bbox': [x1, y1, x2, y2],
            'y0': (y1 + y2) // 2, 'x0': x1
        })

    return metadata


def _ocr_words_at_pipeline_scale(fitz_page, pdf_path, page_num, scale_factor):
    """Return OCR word dicts in pipeline pixels, from the OCR cache or a fresh Tesseract run."""
    words = []

    if _ocr_cache.has_ocr(pdf_path, page_num):
        print(f" ⚑ Using cached Tesseract OCR for page {page_num}")
        for word_text, x1, y1, x2, y2 in _ocr_cache.get_ocr(pdf_path, page_num):
            # Cached boxes are in page coordinates; scale up to pipeline pixels.
            x1_pix = int(x1 * scale_factor)
            y1_pix = int(y1 * scale_factor)
            x2_pix = int(x2 * scale_factor)
            y2_pix = int(y2 * scale_factor)
            words.append({
                'type': 'text', 'word': word_text, 'confidence': 95.0,
                'bbox': [x1_pix, y1_pix, x2_pix, y2_pix],
                'y0': y1_pix, 'x0': x1_pix
            })
        return words

    try:
        # Render separately at 4x purely for OCR accuracy.
        ocr_zoom = 4.0
        pix_ocr = fitz_page.get_pixmap(matrix=fitz.Matrix(ocr_zoom, ocr_zoom))
        processed_img = preprocess_image_for_ocr(pixmap_to_numpy(pix_ocr))

        # --oem 3: default engine; --psm 6: treat the page as one uniform text block.
        hocr_data = pytesseract.image_to_data(
            processed_img,
            output_type=pytesseract.Output.DICT,
            config=r'--oem 3 --psm 6'
        )

        # OCR ran at ocr_zoom but the pipeline works at scale_factor.
        scale_adjustment = scale_factor / ocr_zoom

        for i in range(len(hocr_data['level'])):
            text = hocr_data['text'][i].strip()
            if not text:
                continue
            # 'conf' may arrive as int, float or str depending on the pytesseract
            # version, so normalise before comparing.
            try:
                confidence = float(hocr_data['conf'][i])
            except (TypeError, ValueError):
                continue
            if not confidence > -1:
                continue

            x1 = int(hocr_data['left'][i] * scale_adjustment)
            y1 = int(hocr_data['top'][i] * scale_adjustment)
            x2 = x1 + int(hocr_data['width'][i] * scale_adjustment)
            y2 = y1 + int(hocr_data['height'][i] * scale_adjustment)

            words.append({
                'type': 'text',
                'word': text,
                'confidence': confidence,
                'bbox': [x1, y1, x2, y2],
                'y0': y1,
                'x0': x1
            })
    except Exception as e:
        print(f" ❌ Tesseract OCR Error: {e}")

    return words


def _group_items_into_lines(items):
    """Sort items top-to-bottom and group them into left-to-right ordered lines."""
    component_line_tolerance = 20  # components get a slightly looser second chance

    items.sort(key=lambda x: (x['y0'], x['x0']))
    lines = []

    def _first_line_within(item, tolerance):
        for line in lines:
            # Items arrive in ascending y0 order, so a line's first item holds
            # the line's minimum y0.
            if abs(line[0]['y0'] - item['y0']) < tolerance:
                return line
        return None

    for item in items:
        target = _first_line_within(item, LINE_TOLERANCE)
        if target is None and item['type'] in ['equation', 'figure']:
            target = _first_line_within(item, component_line_tolerance)
        if target is None:
            lines.append([item])
        else:
            target.append(item)

    for line in lines:
        line.sort(key=lambda x: x['x0'])
    return lines


def preprocess_and_ocr_page(original_img: np.ndarray, model, pdf_path: str,
                            page_num: int, fitz_page: fitz.Page,
                            pdf_name: str) -> Tuple[Optional[List[Dict[str, Any]]], Optional[int]]:
    """Turn one rendered PDF page into a reading-ordered list of words and components.

    Flow:
      1. Run YOLO on ``original_img`` to find figures/equations.
      2. Mask the page's words with those boxes.
      3. Detect a central column gutter on the masked words.
      4. Save each component crop and create a placeholder 'word' for it.
      5. Get text words: native PDF text first, otherwise cached or fresh Tesseract.
      6. Drop text words that lie inside a component box, then add the placeholders.
      7. Group everything into lines (top-to-bottom, then left-to-right).

    ``original_img`` is assumed to be the page rendered at 2x; all returned
    bounding boxes are in that pixel space.

    Returns:
        (items, separator_x) where each item is a dict with 'word', 'bbox',
        'type' (and 'tag' when present), and separator_x is the chosen column
        gutter in page coordinates or None. Returns (None, None) when
        ``original_img`` is None.
    """
    if original_img is None:
        print(f" ❌ Invalid image for page {page_num}.")
        return None, None

    scale_factor = 2.0  # pipeline pixels per page unit; must match the caller's render matrix

    # --- STEP 1: YOLO DETECTION ---
    merged_detections = _detect_page_components(original_img, model)

    # --- STEP 2: MASK WORDS WITH YOLO BOXES ---
    raw_words_for_layout = get_word_data_for_detection(
        fitz_page, pdf_path, page_num,
        top_margin_percent=0.10, bottom_margin_percent=0.10
    )
    masked_word_data = merge_yolo_into_word_data(
        raw_words_for_layout, merged_detections, scale_factor=scale_factor)

    # --- STEP 3: COLUMN DETECTION ---
    column_detection_params = {
        'cluster_bin_size': 2, 'cluster_smoothing': 2,
        'cluster_min_width': 10, 'cluster_threshold_percentile': 85,
    }
    separators = calculate_x_gutters(masked_word_data, column_detection_params, fitz_page.rect.height)
    page_separator_x = _select_central_separator(separators, fitz_page.rect.width)

    # --- STEP 4: COMPONENT EXTRACTION (Save Images) ---
    component_metadata = _save_page_components(original_img, merged_detections, pdf_name, page_num)

    # --- STEP 5: HYBRID OCR (Native Text + Tesseract Fallback) ---
    try:
        raw_ocr_output = extract_native_words_and_convert(fitz_page, scale_factor=scale_factor)
    except Exception as e:
        print(f" ❌ Native text extraction failed: {e}")
        raw_ocr_output = []

    if not raw_ocr_output:
        raw_ocr_output = _ocr_words_at_pipeline_scale(fitz_page, pdf_path, page_num, scale_factor)

    # --- STEP 6: OCR CLEANING AND MERGING ---
    # Text that sits inside a figure/equation box is represented by the placeholder instead.
    items_to_sort = [
        ocr_word for ocr_word in raw_ocr_output
        if not any(
            calculate_ioa(ocr_word['bbox'], component['bbox']) > IOA_SUPPRESSION_THRESHOLD
            for component in component_metadata
        )
    ]
    items_to_sort.extend(component_metadata)

    # --- STEP 7: LINE-BASED SORTING ---
    final_output = []
    for line in _group_items_into_lines(items_to_sort):
        for item in line:
            data_item = {"word": item["word"], "bbox": item["bbox"], "type": item["type"]}
            if 'tag' in item:
                data_item['tag'] = item['tag']
            final_output.append(data_item)

    return final_output, page_separator_x
742
+
743
+
744
def run_single_pdf_preprocessing(pdf_path: str, preprocessed_json_path: str) -> Optional[str]:
    """Run the YOLO/OCR preprocessing stage over every page of one PDF.

    Resets the global figure/equation counters and the OCR cache, renders each
    page at 2x, passes it to ``preprocess_and_ocr_page`` and writes the combined
    per-page results to ``preprocessed_json_path`` as JSON (a list of
    {"page_number", "data", "column_separator_x"} objects).

    Returns:
        ``preprocessed_json_path`` on success; None when the PDF is missing or
        cannot be opened, when no page produced data, or when the JSON cannot
        be written.
    """
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT

    GLOBAL_FIGURE_COUNT = 0
    GLOBAL_EQUATION_COUNT = 0
    _ocr_cache.clear()

    print("\n" + "=" * 80)
    print("--- 1. STARTING OPTIMIZED YOLO/OCR PREPROCESSING PIPELINE ---")
    print("=" * 80)

    if not os.path.exists(pdf_path):
        print(f"❌ FATAL ERROR: Input PDF not found at {pdf_path}.")
        return None

    # dirname() is '' for a bare filename, and os.makedirs('') raises.
    output_dir = os.path.dirname(preprocessed_json_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    os.makedirs(FIGURE_EXTRACTION_DIR, exist_ok=True)

    model = YOLO(WEIGHTS_PATH)
    pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]

    try:
        doc = fitz.open(pdf_path)
        print(f"βœ… Opened PDF: {pdf_name} ({doc.page_count} pages)")
    except Exception as e:
        print(f"❌ ERROR loading PDF file: {e}")
        return None

    all_pages_data = []
    mat = fitz.Matrix(2.0, 2.0)  # 2x render; preprocess_and_ocr_page uses this pixel space

    print("\n[STEP 1.2: ITERATING PAGES - IN-MEMORY PROCESSING]")

    try:
        for page_num_0_based in range(doc.page_count):
            page_num = page_num_0_based + 1
            print(f" -> Processing Page {page_num}/{doc.page_count}...")

            fitz_page = doc.load_page(page_num_0_based)

            try:
                pix = fitz_page.get_pixmap(matrix=mat)
                original_img = pixmap_to_numpy(pix)
            except Exception as e:
                print(f" ❌ Error converting page {page_num} to image: {e}")
                continue

            final_output, page_separator_x = preprocess_and_ocr_page(
                original_img,
                model,
                pdf_path,
                page_num,
                fitz_page,
                pdf_name
            )

            if final_output is not None:
                all_pages_data.append({
                    "page_number": page_num,
                    "data": final_output,
                    "column_separator_x": page_separator_x
                })
            else:
                print(f" ❌ Skipped page {page_num} due to processing error.")
    finally:
        # Release the document even if a page raises unexpectedly.
        doc.close()

    if not all_pages_data:
        print("❌ WARNING: No page data generated. Halting pipeline.")
        return None

    try:
        with open(preprocessed_json_path, 'w', encoding='utf-8') as f:
            json.dump(all_pages_data, f, indent=4)
        print(f"\n βœ… Combined structured OCR JSON saved to: {os.path.basename(preprocessed_json_path)}")
    except Exception as e:
        print(f"❌ ERROR saving combined JSON output: {e}")
        return None

    print("\n" + "=" * 80)
    print(f"--- YOLO/OCR PREPROCESSING COMPLETE ({len(all_pages_data)} pages processed) ---")
    print("=" * 80)

    return preprocessed_json_path
830
+
831
+
832
+ # ============================================================================
833
+ # --- PHASE 2: LAYOUTLMV3 INFERENCE FUNCTIONS ---
834
+ # ============================================================================
835
+
836
class LayoutLMv3ForTokenClassification(nn.Module):
    """LayoutLMv3 encoder followed by a linear emission layer and a CRF."""

    def __init__(self, num_labels: int = NUM_LABELS):
        """Build the encoder, the per-token classifier and the CRF for ``num_labels`` tags."""
        super().__init__()
        self.num_labels = num_labels
        backbone_name = "microsoft/layoutlmv3-base"
        config = LayoutLMv3Config.from_pretrained(backbone_name, num_labels=num_labels)
        self.layoutlmv3 = LayoutLMv3Model.from_pretrained(backbone_name, config=config)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.crf = CRF(num_labels)
        self.init_weights()

    def init_weights(self):
        """Xavier-initialise the classifier weight and zero its bias (if present)."""
        head = self.classifier
        nn.init.xavier_uniform_(head.weight)
        if head.bias is not None:
            nn.init.zeros_(head.bias)

    def forward(self, input_ids: torch.Tensor, bbox: torch.Tensor, attention_mask: torch.Tensor,
                labels: Optional[torch.Tensor] = None):
        """Return the CRF loss when ``labels`` is given, otherwise the Viterbi decoding.

        The attention mask (cast to bool) is passed to the CRF as its mask in
        both modes.
        """
        encoded = self.layoutlmv3(
            input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, return_dict=True
        )
        emissions = self.classifier(encoded.last_hidden_state)
        mask = attention_mask.bool()
        if labels is None:
            return self.crf.viterbi_decode(emissions, mask=mask)
        # The CRF score is negated and averaged to obtain the training loss.
        return -self.crf(emissions, labels, mask=mask).mean()
861
+
862
+
863
+ def _merge_integrity(all_token_data: List[Dict[str, Any]],
864
+ column_separator_x: Optional[int]) -> List[List[Dict[str, Any]]]:
865
+ """Splits the token data objects into column chunks based on a separator."""
866
+ if column_separator_x is None:
867
+ print(" -> No column separator. Treating as one chunk.")
868
+ return [all_token_data]
869
+
870
+ left_column_tokens, right_column_tokens = [], []
871
+ for token_data in all_token_data:
872
+ bbox_raw = token_data['bbox_raw_pdf_space']
873
+ center_x = (bbox_raw[0] + bbox_raw[2]) / 2
874
+ if center_x < column_separator_x:
875
+ left_column_tokens.append(token_data)
876
+ else:
877
+ right_column_tokens.append(token_data)
878
+
879
+ chunks = [c for c in [left_column_tokens, right_column_tokens] if c]
880
+ print(f" -> Data split into {len(chunks)} column chunk(s) using separator X={column_separator_x}.")
881
+ return chunks
882
+
883
+
884
def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
                                    preprocessed_json_path: str,
                                    column_detection_params: Optional[Dict] = None) -> List[Dict[str, Any]]:
    """Label every preprocessed word with the LayoutLMv3+CRF model.

    Args:
        pdf_path: PDF the preprocessed data came from (opened to read page sizes).
        model_path: Checkpoint file; either a raw state dict or a dict holding
            it under ``model_state_dict``.
        preprocessed_json_path: JSON written by the preprocessing phase.
        column_detection_params: Not read by this function; kept so existing
            callers keep working.

    Returns:
        One ``{"page_number", "data"}`` dict per page that produced
        predictions, where ``data`` is a list of
        ``{"word", "bbox", "predicted_label", "page_number"}`` records.
        An empty list is returned when the model, the JSON or the PDF cannot
        be loaded.
    """
    print("\n" + "=" * 80)
    print("--- 2. STARTING LAYOUTLMV3 INFERENCE PIPELINE (Raw Word Output) ---")
    print("=" * 80)

    tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" -> Using device: {device}")

    try:
        model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS)
        checkpoint = torch.load(model_path, map_location=device)
        model_state = checkpoint.get('model_state_dict', checkpoint)
        # Checkpoints saved with a 'layoutlm.' prefix are renamed to this class's attribute name.
        fixed_state_dict = {key.replace('layoutlm.', 'layoutlmv3.'): value for key, value in model_state.items()}
        model.load_state_dict(fixed_state_dict)
        model.to(device)
        model.eval()
        print(f"βœ… LayoutLMv3 Model loaded successfully from {os.path.basename(model_path)}.")
    except Exception as e:
        print(f"❌ FATAL ERROR during LayoutLMv3 model loading: {e}")
        return []

    try:
        with open(preprocessed_json_path, 'r', encoding='utf-8') as f:
            preprocessed_data = json.load(f)
        print(f"βœ… Loaded preprocessed data with {len(preprocessed_data)} pages.")
    except Exception as e:
        print(f"❌ Error loading preprocessed JSON: {e}")
        return []

    try:
        doc = fitz.open(pdf_path)
    except Exception as e:
        print(f"❌ Error loading PDF: {e}")
        return []

    final_page_predictions = []
    CHUNK_SIZE = 500  # words per tokenizer call
    # Preprocessing rasterises pages with fitz.Matrix(2.0, 2.0); dividing by the
    # same factor maps the stored boxes back to PDF points.
    scale_factor = 2.0

    try:
        for page_data in preprocessed_data:
            page_num_1_based = page_data['page_number']
            page_num_0_based = page_num_1_based - 1
            page_raw_predictions = []
            print(f"\n *** Processing Page {page_num_1_based} ({len(page_data['data'])} raw tokens) ***")

            fitz_page = doc.load_page(page_num_0_based)
            page_width, page_height = fitz_page.rect.width, fitz_page.rect.height
            print(f" -> Page dimensions: {page_width:.0f}x{page_height:.0f} (PDF points).")

            all_token_data = []
            for item in page_data['data']:
                raw_yolo_bbox = item['bbox']
                bbox_pdf = [int(coord / scale_factor) for coord in raw_yolo_bbox[:4]]
                # LayoutLMv3 expects boxes on a 0-1000 grid; clamp to stay in range.
                normalized_bbox = [
                    max(0, min(1000, int(1000 * bbox_pdf[0] / page_width))),
                    max(0, min(1000, int(1000 * bbox_pdf[1] / page_height))),
                    max(0, min(1000, int(1000 * bbox_pdf[2] / page_width))),
                    max(0, min(1000, int(1000 * bbox_pdf[3] / page_height)))
                ]
                all_token_data.append({
                    "word": item['word'],
                    "bbox_raw_pdf_space": bbox_pdf,
                    "bbox_normalized": normalized_bbox,
                    "item_original_data": item
                })

            if not all_token_data:
                continue

            column_separator_x = page_data.get('column_separator_x', None)
            if column_separator_x is not None:
                print(f" -> Using SAVED column separator: X={column_separator_x}")
            else:
                print(" -> No column separator found. Assuming single chunk.")

            token_chunks = _merge_integrity(all_token_data, column_separator_x)
            total_chunks = len(token_chunks)

            for chunk_idx, chunk_tokens in enumerate(token_chunks):
                if not chunk_tokens:
                    continue

                chunk_words = [t['word'] for t in chunk_tokens]
                chunk_normalized_bboxes = [t['bbox_normalized'] for t in chunk_tokens]

                total_sub_chunks = (len(chunk_words) + CHUNK_SIZE - 1) // CHUNK_SIZE
                for i in range(0, len(chunk_words), CHUNK_SIZE):
                    sub_chunk_idx = i // CHUNK_SIZE + 1
                    sub_words = chunk_words[i:i + CHUNK_SIZE]
                    sub_bboxes = chunk_normalized_bboxes[i:i + CHUNK_SIZE]
                    sub_tokens_data = chunk_tokens[i:i + CHUNK_SIZE]

                    print(
                        f" -> Chunk {chunk_idx + 1}/{total_chunks}, Sub-chunk {sub_chunk_idx}/{total_sub_chunks}: {len(sub_words)} words. Running Inference...")

                    encoded_input = tokenizer(
                        sub_words, boxes=sub_bboxes, truncation=True, padding="max_length",
                        max_length=512, return_tensors="pt"
                    )
                    input_ids = encoded_input['input_ids'].to(device)
                    bbox = encoded_input['bbox'].to(device)
                    attention_mask = encoded_input['attention_mask'].to(device)

                    with torch.no_grad():
                        predictions_int_list = model(input_ids, bbox, attention_mask)

                    if not predictions_int_list:
                        continue
                    predictions_int = predictions_int_list[0]
                    word_ids = encoded_input.word_ids()

                    # A word takes the prediction of its first sub-token.
                    word_idx_to_pred_id = {}
                    for token_idx, word_idx in enumerate(word_ids):
                        if word_idx is not None and word_idx < len(sub_words):
                            if word_idx not in word_idx_to_pred_id:
                                word_idx_to_pred_id[word_idx] = predictions_int[token_idx]

                    # Words cut off by truncation have no prediction and fall back to
                    # label id 0 below; report it instead of hiding it.
                    unlabeled_words = len(sub_words) - len(word_idx_to_pred_id)
                    if unlabeled_words > 0:
                        print(
                            f" -> WARNING: {unlabeled_words} word(s) received no prediction "
                            f"(truncated at 512 tokens); defaulting them to label id 0.")

                    for current_word_idx in range(len(sub_words)):
                        pred_id_or_tensor = word_idx_to_pred_id.get(current_word_idx, 0)
                        pred_id = pred_id_or_tensor.item() if torch.is_tensor(pred_id_or_tensor) else pred_id_or_tensor
                        predicted_label = ID_TO_LABEL[pred_id]
                        original_token = sub_tokens_data[current_word_idx]
                        page_raw_predictions.append({
                            "word": original_token['word'],
                            "bbox": original_token['bbox_raw_pdf_space'],
                            "predicted_label": predicted_label,
                            "page_number": page_num_1_based
                        })

            if page_raw_predictions:
                final_page_predictions.append({
                    "page_number": page_num_1_based,
                    "data": page_raw_predictions
                })
                print(f" *** Page {page_num_1_based} Finalized: {len(page_raw_predictions)} labeled words. ***")
    finally:
        # Close the PDF even if tokenization or inference raises mid-document.
        doc.close()

    print("\n" + "=" * 80)
    print("--- LAYOUTLMV3 INFERENCE COMPLETE ---")
    print("=" * 80)
    return final_page_predictions
1029
+
1030
+
1031
def create_label_studio_span(page_results, start_idx, end_idx, label):
    """Build one Label Studio ``labels`` result for words ``start_idx..end_idx``.

    Character offsets are relative to the page text formed by joining all
    words with single spaces; the bbox is the union of the words' boxes.
    """
    members = page_results[start_idx:end_idx + 1] if end_idx + 1 > start_idx else []
    span_text = " ".join(entry['word'] for entry in members)
    boxes = [entry['bbox'] for entry in members]

    left = min(box[0] for box in boxes)
    top = min(box[1] for box in boxes)
    right = max(box[2] for box in boxes)
    bottom = max(box[3] for box in boxes)

    # Offset of the span = lengths of the preceding words plus one space after each.
    preceding_words = [entry['word'] for entry in page_results][:start_idx]
    start_char = len(" ".join(preceding_words))
    if start_idx != 0:
        start_char += 1
    end_char = start_char + len(span_text)

    value = {
        "start": start_char,
        "end": end_char,
        "text": span_text,
        "labels": [label],
        "bbox": {"x": left, "y": top, "width": right - left, "height": bottom - top},
    }
    return {
        "from_name": "label",
        "to_name": "text",
        "type": "labels",
        "value": value,
        "score": 0.99,
    }
1051
+
1052
+
1053
def convert_raw_predictions_to_label_studio(page_data_list, output_path: str):
    """Turn per-page BIO predictions into Label Studio tasks and save them as JSON.

    Consecutive ``B-X`` / ``I-X`` words are merged into one span per entity;
    any other label closes the open span. One task is produced per non-empty
    page and the full task list is written to ``output_path``.
    """
    tasks = []
    print("\n[PHASE: LABEL STUDIO CONVERSION]")

    for page in page_data_list:
        page_number = page['page_number']
        page_results = page['data']
        if not page_results:
            continue

        words = [record['word'] for record in page_results]
        spans = []
        open_label = None
        open_start = None

        for idx, prediction in enumerate(page_results):
            label = prediction['predicted_label']
            tag = label.split('-', 1)[-1]
            begins_entity = label.startswith('B-')

            # A matching I- tag simply extends the span that is already open.
            if not begins_entity and label.startswith('I-') and open_label == tag:
                continue

            # Anything else ends the open span (if any) at the previous word.
            if open_label:
                spans.append(create_label_studio_span(page_results, open_start, idx - 1, open_label))

            if begins_entity:
                open_label, open_start = tag, idx
            else:
                open_label, open_start = None, None

        if open_label:
            spans.append(
                create_label_studio_span(page_results, open_start, len(page_results) - 1, open_label))

        tasks.append({
            "data": {
                "text": " ".join(words),
                "original_words": words,
                "original_bboxes": [record['bbox'] for record in page_results]
            },
            "annotations": [{"result": spans}],
            "meta": {"page_number": page_number}
        })

    with open(output_path, "w", encoding='utf-8') as f:
        json.dump(tasks, f, indent=2, ensure_ascii=False)
    print(f"\nβœ… Label Studio tasks saved to {output_path}.")
1099
+
1100
+
1101
+ # ============================================================================
1102
+ # --- PHASE 3: BIO TO STRUCTURED JSON DECODER ---
1103
+ # ============================================================================
1104
+
1105
+
1106
def convert_bio_to_structured_json_relaxed(input_path: str, output_path: str) -> Optional[List[Dict[str, Any]]]:
    """Decode per-word BIO predictions into question/option/answer/passage items.

    The per-page prediction lists in ``input_path`` are flattened and walked
    word by word. Every ``B-QUESTION`` opens a new item; words seen before
    the first question are collected into an optional ``METADATA`` item.
    PASSAGE words that directly follow option continuation words are stored
    under ``new_passage`` instead of ``passage``.

    Args:
        input_path: JSON file with a list of ``{"data": [...]}`` page dicts.
        output_path: Where the structured list is written (best effort).

    Returns:
        The structured list, or ``None`` if ``input_path`` cannot be loaded.
        A failure to write ``output_path`` is reported but does not change
        the return value.
    """
    print("\n" + "=" * 80)
    print("--- 3. STARTING BIO TO STRUCTURED JSON DECODING ---")
    print("=" * 80)
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            predictions_by_page = json.load(f)
    except Exception as e:
        print(f"❌ Error loading raw prediction file: {e}")
        return None

    predictions = []
    for page_item in predictions_by_page:
        if isinstance(page_item, dict) and 'data' in page_item:
            predictions.extend(page_item['data'])

    structured_data = []
    current_item = None
    current_option_key = None
    current_passage_buffer = []
    current_text_buffer = []
    first_question_started = False
    last_entity_type = None
    just_finished_i_option = False
    is_in_new_passage = False

    def finalize_passage_to_item(item, passage_buffer):
        """Append the buffered passage words to ``item['passage']`` and empty the buffer."""
        if passage_buffer:
            passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
            if item.get('passage'):
                item['passage'] += ' ' + passage_text
            else:
                item['passage'] = passage_text
            passage_buffer.clear()

    for item in predictions:
        word = item['word']
        label = item['predicted_label']
        entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
        current_text_buffer.append(word)
        previous_entity_type = last_entity_type
        is_passage_label = (entity_type == 'PASSAGE')

        # Before the first question only PASSAGE words are kept (for the METADATA item).
        if not first_question_started:
            if label != 'B-QUESTION' and not is_passage_label:
                just_finished_i_option = False
                is_in_new_passage = False
                continue
            if is_passage_label:
                current_passage_buffer.append(word)
                last_entity_type = 'PASSAGE'
                just_finished_i_option = False
                is_in_new_passage = False
                continue

        if label == 'B-QUESTION':
            if not first_question_started:
                # Everything before this word becomes the header text.
                header_text = ' '.join(current_text_buffer[:-1]).strip()
                if header_text or current_passage_buffer:
                    metadata_item = {'type': 'METADATA', 'passage': ''}
                    finalize_passage_to_item(metadata_item, current_passage_buffer)
                    if header_text:
                        metadata_item['text'] = header_text
                    structured_data.append(metadata_item)
                first_question_started = True
                current_text_buffer = [word]

            if current_item is not None:
                # Close the previous question; its text excludes the current word.
                finalize_passage_to_item(current_item, current_passage_buffer)
                current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
                structured_data.append(current_item)
                current_text_buffer = [word]

            current_item = {
                'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': ''
            }
            current_option_key = None
            last_entity_type = 'QUESTION'
            just_finished_i_option = False
            is_in_new_passage = False
            continue

        if current_item is not None:
            if is_in_new_passage:
                if 'new_passage' not in current_item:
                    current_item['new_passage'] = word
                else:
                    current_item['new_passage'] += f' {word}'

                # Any B- tag, or an I- tag of another entity, ends the new passage
                # (the current word has still been appended to it above).
                if label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE'):
                    is_in_new_passage = False
                if label.startswith(('B-', 'I-')):
                    last_entity_type = entity_type
                continue
            is_in_new_passage = False

            if label.startswith('B-'):
                if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']:
                    finalize_passage_to_item(current_item, current_passage_buffer)
                    current_passage_buffer = []
                last_entity_type = entity_type
                if entity_type == 'PASSAGE':
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        current_item['new_passage'] = word  # Initialize the new passage start
                        is_in_new_passage = True
                    else:
                        current_passage_buffer.append(word)
                elif entity_type == 'OPTION':
                    current_option_key = word
                    current_item['options'][current_option_key] = word
                    just_finished_i_option = False
                elif entity_type == 'ANSWER':
                    current_item['answer'] = word
                    current_option_key = None
                    just_finished_i_option = False
                elif entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                    just_finished_i_option = False

            elif label.startswith('I-'):
                if entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                elif entity_type == 'PASSAGE':
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        current_item['new_passage'] = word  # Initialize the new passage start
                        is_in_new_passage = True
                    else:
                        if not current_passage_buffer:
                            last_entity_type = 'PASSAGE'
                        current_passage_buffer.append(word)
                elif entity_type == 'OPTION' and current_option_key is not None:
                    current_item['options'][current_option_key] += f' {word}'
                    just_finished_i_option = True
                elif entity_type == 'ANSWER':
                    current_item['answer'] += f' {word}'
                just_finished_i_option = (entity_type == 'OPTION')

            elif label == 'O':
                # Untagged words right after a question are treated as part of it.
                if last_entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                just_finished_i_option = False

    if current_item is not None:
        finalize_passage_to_item(current_item, current_passage_buffer)
        current_item['text'] = ' '.join(current_text_buffer).strip()
        structured_data.append(current_item)

    for item in structured_data:
        # METADATA items built from passage words only have no 'text' key yet.
        item['text'] = re.sub(r'\s{2,}', ' ', item.get('text', '')).strip()
        if 'new_passage' in item:
            item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip()

    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(structured_data, f, indent=2, ensure_ascii=False)
    except Exception as e:
        # Writing the intermediate file stays best-effort, but is no longer silent.
        print(f"WARNING: could not write structured JSON to {output_path}: {e}")

    return structured_data
1263
+
1264
+
1265
def create_query_text(entry: Dict[str, Any]) -> str:
    """Join the question and all non-empty string option values with spaces.

    Option values are read from the ``options`` and ``options_text`` keys, in
    that order, and only when the value stored there is a dict.
    """
    parts = []
    question = entry.get("question")
    if question:
        parts.append(question)

    for field in ("options", "options_text"):
        candidates = entry.get(field)
        if not (candidates and isinstance(candidates, dict)):
            continue
        parts.extend(
            text for text in candidates.values()
            if text and isinstance(text, str)
        )
    return " ".join(parts)
1278
+
1279
+
1280
def calculate_similarity(doc1: str, doc2: str) -> float:
    """Return the cosine similarity of the word-count vectors of two texts.

    The leading token of every line (a run of word characters, digits or
    ``(`` with an optional trailing dot) is stripped before vectorising.
    Returns 0.0 for empty input, an empty vocabulary, or any error raised
    while vectorising.
    """
    if not (doc1 and doc2):
        return 0.0

    def strip_leading_marker(text):
        return re.sub(r'^\s*[\(\d\w]+\.?\s*', '', text, flags=re.MULTILINE)

    documents = [strip_leading_marker(doc1), strip_leading_marker(doc2)]

    try:
        counter = CountVectorizer(stop_words='english', lowercase=True, token_pattern=r'(?u)\b\w\w+\b')
        term_counts = counter.fit_transform(documents)
        # No vocabulary left after stop-word removal: nothing to compare.
        if term_counts.shape[1] == 0:
            return 0.0
        dense = term_counts.toarray()
        if len(dense) < 2:
            return 0.0
        return cosine_similarity(dense[0:1], dense[1:2])[0][0]
    except Exception:
        return 0.0
1305
+
1306
+
1307
def process_context_linking(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Attach passages to questions by text similarity, with context decay.

    Two contexts are tracked while walking the entries in order: the latest
    ``passage`` and the latest ``new_passage``. A question is linked to the
    new passage when it scores above the standard one by ``RESOLUTION_MARGIN``
    and reaches ``SIMILARITY_THRESHOLD``; otherwise to the standard passage
    when that reaches the threshold. After two consecutive misses against the
    standard passage it is dropped. A cleanup pass then removes weak
    self-links on entries that defined a passage, and single-question gaps
    between identical passages are filled in.

    The entries are modified in place and the same list is returned.
    """
    print("\n" + "=" * 80)
    print("--- STARTING CONTEXT LINKING (WITH DECAY LOGIC) ---")
    print("=" * 80)

    if not data:
        return []

    # --- PHASE 1: remember which entries carried their own passage text ---
    definer_indices = []
    for idx, entry in enumerate(data):
        if entry.get("passage") and entry["passage"].strip():
            definer_indices.append(idx)
        if entry.get("new_passage") and entry["new_passage"].strip():
            if idx not in definer_indices:
                definer_indices.append(idx)

    # --- PHASE 2: walk the entries, carrying the active contexts forward ---
    standard_context = None
    override_context = None
    misses = 0
    MAX_CONSECUTIVE_FAILURES = 2

    for i, entry in enumerate(data):
        item_type = entry.get("type", "Question")

        # Fresh explicit passage text replaces the context and resets the decay counter.
        if entry.get("passage") and entry["passage"].strip():
            standard_context = entry["passage"]
            misses = 0

        # A new_passage is a local override; it does not touch the decay counter.
        if entry.get("new_passage") and entry["new_passage"].strip():
            override_context = entry["new_passage"]

        if not (entry.get("question") and item_type != "METADATA"):
            continue

        combined_query = create_query_text(entry)
        # Queries this short are treated as noise.
        if len(combined_query.strip()) < 5:
            continue

        score_old = 0.0
        if standard_context:
            score_old = calculate_similarity(standard_context, combined_query)
        score_new = 0.0
        if override_context:
            score_new = calculate_similarity(override_context, combined_query)

        q_preview = entry['question'][:30] + '...'

        prefers_override = (
            override_context
            and (score_new > score_old + RESOLUTION_MARGIN)
            and (score_new >= SIMILARITY_THRESHOLD)
        )
        if prefers_override:
            # Matching the override says nothing about the standard passage,
            # so the miss counter is left alone.
            entry["passage"] = override_context
            print(f" [Linker] πŸš€ Q{i} ('{q_preview}') -> NEW PASSAGE (Score: {score_new:.3f})")
            continue

        if standard_context and (score_old >= SIMILARITY_THRESHOLD):
            entry["passage"] = standard_context
            print(f" [Linker] βœ… Q{i} ('{q_preview}') -> STANDARD PASSAGE (Score: {score_old:.3f})")
            misses = 0
            continue

        # Not linked: apply the decay logic.
        if not standard_context:
            print(f" [Linker] ⚠️ Q{i} NOT LINKED (No active context).")
            continue

        misses += 1
        print(
            f" [Linker] ⚠️ Q{i} NOT LINKED. (Failures: {misses}/{MAX_CONSECUTIVE_FAILURES})")
        if misses >= MAX_CONSECUTIVE_FAILURES:
            print(f" [Linker] πŸ—‘οΈ Context dropped due to {misses} consecutive misses.")
            standard_context = None
            misses = 0

    # --- PHASE 3: CLEANUP AND INTERPOLATION ---
    print(" [Linker] Running Cleanup & Interpolation...")

    # 3A. Drop passages that do not resemble the question they sit on.
    for i in definer_indices:
        entry = data[i]
        if not entry.get("question") or entry.get("type") == "METADATA":
            continue
        passage_to_check = entry.get("passage") or entry.get("new_passage")
        if not passage_to_check:
            continue
        self_sim = calculate_similarity(passage_to_check, create_query_text(entry))
        if self_sim < SIMILARITY_THRESHOLD:
            entry["passage"] = ""
            if "new_passage" in entry:
                entry["new_passage"] = ""
            print(f" [Cleanup] Removed weak link for Q{i}")

    # 3B. Fill a gap only when it is exactly one question wide, so the decay
    # logic above is not undone.
    for i in range(1, len(data) - 1):
        middle = data[i]
        if not (middle.get("question") and not middle.get("passage")):
            continue
        prev_p = data[i - 1].get("passage")
        next_p = data[i + 1].get("passage")
        if prev_p and next_p and (prev_p == next_p) and prev_p.strip():
            middle["passage"] = prev_p
            print(f" [Linker] πŸ₯ͺ Q{i} Interpolated from neighbors.")

    return data
1425
+
1426
+
1427
def correct_misaligned_options(structured_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Move a stray EQUATION/FIGURE tag from an option to the empty option before it.

    For each adjacent option pair: when the first option holds only its own
    key and the second holds exactly two tags, the first tag is moved to the
    first option and the second option keeps only the second tag. Items are
    modified in place and the same list is returned.
    """
    print("\n" + "=" * 80)
    print("--- 5. STARTING POST-PROCESSING: OPTION ALIGNMENT CORRECTION ---")
    print("=" * 80)
    tag_pattern = re.compile(r'(EQUATION\d+|FIGURE\d+)')
    corrected_count = 0

    for item in structured_data:
        if item.get('type') in ['METADATA']:
            continue
        options = item.get('options')
        if not options or len(options) < 2:
            continue

        keys = list(options.keys())
        for current_key, next_key in zip(keys, keys[1:]):
            current_value = options[current_key].strip()
            next_value = options[next_key].strip()

            # Text of the next option with its own key removed once.
            remainder = next_value.replace(next_key, '', 1).strip()
            tags = tag_pattern.findall(remainder)

            if current_value == current_key and len(tags) == 2:
                first_tag, second_tag = tags
                options[current_key] = f"{current_key} {first_tag}".strip()
                options[next_key] = f"{next_key} {second_tag}".strip()
                corrected_count += 1

    print(f"βœ… Option alignment correction finished. Total corrections: {corrected_count}.")
    return structured_data
1454
+
1455
+
1456
+ # ============================================================================
1457
+ # --- PHASE 4: IMAGE EMBEDDING (Base64) ---
1458
+ # ============================================================================
1459
+
1460
def get_base64_for_file(filepath: str) -> str:
    """Return the file's bytes as a Base64 string, or "" if it cannot be read."""
    try:
        with open(filepath, 'rb') as handle:
            raw_bytes = handle.read()
            encoded = base64.b64encode(raw_bytes)
            return encoded.decode('utf-8')
    except Exception as e:
        print(f" ❌ Error encoding file {filepath}: {e}")
        return ""
1467
+
1468
+
1469
def embed_images_as_base64_in_memory(structured_data: List[Dict[str, Any]], figure_extraction_dir: str) -> List[
    Dict[str, Any]]:
    """Embed referenced figure/equation images into each item as Base64.

    PNG files in ``figure_extraction_dir`` whose names end in
    ``_figure<N>.png`` or ``_equation<N>.png`` are indexed by the upper-cased
    tag (e.g. ``FIGURE3``). For every such tag found in an item's question,
    passage, option values or new_passage, the encoded image is stored on the
    item under the lower-cased tag. Items are modified in place.
    """
    print("\n" + "=" * 80)
    print("--- 4. STARTING IMAGE EMBEDDING (Base64) ---")
    print("=" * 80)
    if not structured_data:
        return []

    filename_regex = re.compile(r'_(figure|equation)(\d+)\.png$', re.IGNORECASE)
    tag_regex = re.compile(r'(figure|equation)(\d+)', re.IGNORECASE)

    # Map "FIGURE3" / "EQUATION7" -> path of the extracted PNG.
    image_lookup = {}
    for filepath in glob.glob(os.path.join(figure_extraction_dir, "*.png")):
        found = filename_regex.search(os.path.basename(filepath))
        if found:
            image_lookup[f"{found.group(1).upper()}{found.group(2)}"] = filepath
    print(f" -> Found {len(image_lookup)} image components.")

    final_structured_data = []
    for item in structured_data:
        text_fields = [item.get('question', ''), item.get('passage', '')]
        if 'options' in item:
            text_fields.extend(item['options'].values())
        if 'new_passage' in item:
            text_fields.append(item['new_passage'])

        referenced_tags = set()
        for text in text_fields:
            if not text:
                continue
            for hit in tag_regex.finditer(text):
                candidate = hit.group(0).upper()
                if candidate in image_lookup:
                    referenced_tags.add(candidate)

        for tag in sorted(referenced_tags):
            encoded_image = get_base64_for_file(image_lookup[tag])
            item[tag.replace(' ', '').lower()] = encoded_image
        final_structured_data.append(item)

    print(f"βœ… Image embedding complete.")
    return final_structured_data
1505
+
1506
+
1507
+ # ============================================================================
1508
+ # --- MAIN FUNCTION ---
1509
+ # ============================================================================
1510
+
1511
def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, label_studio_output_path: str) -> Optional[
    List[Dict[str, Any]]]:
    """Run the full pipeline: preprocessing, inference, decoding, linking, embedding.

    Intermediate JSON files are kept in a private scratch directory that is
    removed when the function returns, whether it succeeds or fails.

    Args:
        input_pdf_path: PDF to analyse.
        layoutlmv3_model_path: Checkpoint passed to the inference phase.
        label_studio_output_path: Where the Label Studio tasks are written;
            a failure in that conversion is reported but not fatal.

    Returns:
        The structured items with embedded images, or ``None`` when the PDF
        does not exist, a phase produces no data, or an exception occurs.
    """
    import shutil  # local import: only needed for scratch-directory cleanup

    if not os.path.exists(input_pdf_path):
        return None

    print("\n" + "#" * 80)
    print("### STARTING OPTIMIZED FULL DOCUMENT ANALYSIS PIPELINE ###")
    print("#" * 80)

    pdf_name = os.path.splitext(os.path.basename(input_pdf_path))[0]
    # mkdtemp gives every run its own directory. A name built only from the PDF
    # name and the PID is shared by concurrent runs inside one process, so one
    # run's cleanup could delete another run's intermediate files.
    temp_pipeline_dir = tempfile.mkdtemp(prefix=f"pipeline_run_{pdf_name}_")

    preprocessed_json_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_preprocessed.json")
    raw_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_raw_predictions.json")
    structured_intermediate_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_structured_intermediate.json")

    final_result = None
    try:
        # Phase 1: Preprocessing with YOLO First + Masking
        preprocessed_json_path_out = run_single_pdf_preprocessing(input_pdf_path, preprocessed_json_path)
        if not preprocessed_json_path_out:
            return None

        # Phase 2: Inference
        page_raw_predictions_list = run_inference_and_get_raw_words(
            input_pdf_path, layoutlmv3_model_path, preprocessed_json_path_out
        )
        if not page_raw_predictions_list:
            return None

        with open(raw_output_path, 'w', encoding='utf-8') as f:
            json.dump(page_raw_predictions_list, f, indent=4)

        # Phase 3: Decoding
        structured_data_list = convert_bio_to_structured_json_relaxed(
            raw_output_path, structured_intermediate_output_path
        )
        if not structured_data_list:
            return None
        structured_data_list = correct_misaligned_options(structured_data_list)
        structured_data_list = process_context_linking(structured_data_list)

        # The Label Studio export is optional output; keep going if it fails.
        try:
            convert_raw_predictions_to_label_studio(page_raw_predictions_list, label_studio_output_path)
        except Exception as e:
            print(f"❌ Error during Label Studio conversion: {e}")

        # Phase 4: Embedding
        final_result = embed_images_as_base64_in_memory(structured_data_list, FIGURE_EXTRACTION_DIR)

    except Exception as e:
        print(f"❌ FATAL ERROR: {e}")
        import traceback
        traceback.print_exc()
        return None

    finally:
        # Best-effort removal of the scratch directory and everything in it.
        shutil.rmtree(temp_pipeline_dir, ignore_errors=True)

    print("\n" + "#" * 80)
    print("### OPTIMIZED PIPELINE EXECUTION COMPLETE ###")
    print("#" * 80)
    return final_result
1576
+
1577
+
1578
if __name__ == "__main__":
    # Command-line entry point: run the pipeline on one PDF and write the
    # embedded result as JSON into the current working directory.
    parser = argparse.ArgumentParser(description="Complete Pipeline")
    parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
    parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path")
    parser.add_argument("--ls_output_path", type=str, default=None, help="Label Studio Output Path")
    args = parser.parse_args()

    # Output file names are derived from the PDF's base name (no extension).
    pdf_name = os.path.splitext(os.path.basename(args.input_pdf))[0]
    final_output_path = os.path.abspath(f"{pdf_name}_final_output_embedded.json")
    # Use the default Label Studio file name when no path was given.
    ls_output_path = os.path.abspath(
        args.ls_output_path if args.ls_output_path else f"{pdf_name}_label_studio_tasks.json")

    final_json_data = run_document_pipeline(args.input_pdf, args.layoutlmv3_model_path, ls_output_path)

    if final_json_data:
        with open(final_output_path, 'w', encoding='utf-8') as f:
            json.dump(final_json_data, f, indent=2, ensure_ascii=False)
        print(f"\n✅ Final Data Saved: {final_output_path}")
    else:
        # A missing or empty result is reported as a failure through a
        # non-zero exit status.
        print("\n❌ Pipeline Failed.")
        sys.exit(1)