heerjtdev committed on
Commit
2e5b054
·
verified ·
1 Parent(s): 9c47ab0

Update working_yolo_pipeline.py

Browse files
Files changed (1) hide show
  1. working_yolo_pipeline.py +722 -33
working_yolo_pipeline.py CHANGED
@@ -92,6 +92,60 @@ def sanitize_text(text: Optional[str]) -> str:
92
 
93
 
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  def get_latex_from_base64(base64_string: str) -> str:
96
  """
97
  Decodes a Base64 image string and uses the pre-initialized TrOCR/ORT model
@@ -118,6 +172,12 @@ def get_latex_from_base64(base64_string: str) -> str:
118
  return "[OCR_WARNING: No formula found]"
119
 
120
  latex_string = raw_generated_text[0]
 
 
 
 
 
 
121
 
122
  # --- 4. Post-processing and Cleanup ---
123
 
@@ -580,8 +640,53 @@ def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray:
580
 
581
 
582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
583
  def extract_native_words_and_convert(fitz_page, scale_factor: float = 2.0) -> list:
584
  raw_word_data = fitz_page.get_text("words")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
  converted_ocr_output = []
586
  DEFAULT_CONFIDENCE = 99.0
587
 
@@ -796,6 +901,275 @@ def post_process_json_with_inference(json_data, classifier):
796
 
797
 
798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
799
  def preprocess_and_ocr_page(original_img: np.ndarray, model, pdf_path: str,
800
  page_num: int, fitz_page: fitz.Page,
801
  pdf_name: str) -> Tuple[List[Dict[str, Any]], Optional[int]]:
@@ -968,6 +1342,21 @@ def preprocess_and_ocr_page(original_img: np.ndarray, model, pdf_path: str,
968
  config=custom_config
969
  )
970
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
971
  for i in range(len(hocr_data['level'])):
972
  text = hocr_data['text'][i] # Retrieve raw Tesseract text
973
 
@@ -1053,6 +1442,12 @@ def preprocess_and_ocr_page(original_img: np.ndarray, model, pdf_path: str,
1053
  return final_output, page_separator_x
1054
 
1055
 
 
 
 
 
 
 
1056
  def run_single_pdf_preprocessing(pdf_path: str, preprocessed_json_path: str) -> Optional[str]:
1057
  global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT
1058
 
@@ -1197,6 +1592,319 @@ def _merge_integrity(all_token_data: List[Dict[str, Any]],
1197
 
1198
 
1199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1200
  def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
1201
  preprocessed_json_path: str,
1202
  column_detection_params: Optional[Dict] = None) -> List[Dict[str, Any]]:
@@ -1271,6 +1979,20 @@ def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
1271
  "item_original_data": item
1272
  })
1273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1274
  if not all_token_data:
1275
  continue
1276
 
@@ -1348,19 +2070,12 @@ def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
1348
  model_outputs = model(input_ids, bbox, attention_mask)
1349
 
1350
  # --- Robust extraction: support several forward return types ---
1351
- # We'll try (in order):
1352
- # 1) model_outputs is (emissions_tensor, viterbi_list) -> use emissions for logits, keep decoded
1353
- # 2) model_outputs has .logits attribute (HF ModelOutput)
1354
- # 3) model_outputs is tuple/list containing a logits tensor
1355
- # 4) model_outputs is a tensor (assume logits)
1356
- # 5) model_outputs is a list-of-lists of ints (viterbi decoded) -> use that directly (no logits)
1357
  logits_tensor = None
1358
  decoded_labels_list = None
1359
 
1360
  # case 1: tuple/list with (emissions, viterbi)
1361
  if isinstance(model_outputs, (tuple, list)) and len(model_outputs) == 2:
1362
  a, b = model_outputs
1363
- # a might be tensor (emissions), b might be viterbi list
1364
  if isinstance(a, torch.Tensor):
1365
  logits_tensor = a
1366
  if isinstance(b, list):
@@ -1375,15 +2090,12 @@ def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
1375
  found_tensor = None
1376
  for item in model_outputs:
1377
  if isinstance(item, torch.Tensor):
1378
- # prefer 3D (batch, seq, labels)
1379
  if item.dim() == 3:
1380
  logits_tensor = item
1381
  break
1382
  if found_tensor is None:
1383
  found_tensor = item
1384
  if logits_tensor is None and found_tensor is not None:
1385
- # found_tensor may be (batch, seq, hidden) or (seq, hidden); we avoid guessing.
1386
- # Keep found_tensor only if it matches num_labels dimension
1387
  if found_tensor.dim() == 3 and found_tensor.shape[-1] == NUM_LABELS:
1388
  logits_tensor = found_tensor
1389
  elif found_tensor.dim() == 2 and found_tensor.shape[-1] == NUM_LABELS:
@@ -1395,12 +2107,10 @@ def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
1395
 
1396
  # case 5: model_outputs is a decoded viterbi list (common for CRF-only forward)
1397
  if decoded_labels_list is None and isinstance(model_outputs, list) and model_outputs and isinstance(model_outputs[0], list):
1398
- # assume model_outputs is already viterbi decoded: List[List[int]] with batch dim first
1399
  decoded_labels_list = model_outputs
1400
 
1401
  # If neither logits nor decoded exist, that's fatal
1402
  if logits_tensor is None and decoded_labels_list is None:
1403
- # helpful debug info
1404
  try:
1405
  elem_shapes = [ (type(x), getattr(x, 'shape', None)) for x in model_outputs ] if isinstance(model_outputs, (list, tuple)) else [(type(model_outputs), getattr(model_outputs, 'shape', None))]
1406
  except Exception:
@@ -1409,32 +2119,25 @@ def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
1409
 
1410
  # If we have logits_tensor, normalize shape to [seq_len, num_labels]
1411
  if logits_tensor is not None:
1412
- # If shape is [B, L, C] with B==1, squeeze batch
1413
  if logits_tensor.dim() == 3 and logits_tensor.shape[0] == 1:
1414
  preds_tensor = logits_tensor.squeeze(0) # [L, C]
1415
  else:
1416
  preds_tensor = logits_tensor # possibly [L, C] already
1417
 
1418
- # Safety: ensure we have at least seq_len x channels
1419
  if preds_tensor.dim() != 2:
1420
- # try to reshape or error
1421
  raise RuntimeError(f"Unexpected logits tensor shape: {tuple(preds_tensor.shape)}")
1422
- # We'll use preds_tensor[token_idx] to argmax
1423
  else:
1424
  preds_tensor = None # no logits available
1425
 
1426
  # If decoded labels provided, make a token-level list-of-ints aligned to tokenizer tokens
1427
  decoded_token_labels = None
1428
  if decoded_labels_list is not None:
1429
- # decoded_labels_list is batch-first; we used batch size 1
1430
- # if multiple sequences returned, take first
1431
  decoded_token_labels = decoded_labels_list[0] if isinstance(decoded_labels_list[0], list) else decoded_labels_list
1432
 
1433
  # Now map token-level predictions -> word-level predictions using word_ids
1434
  word_idx_to_pred_id = {}
1435
 
1436
  if preds_tensor is not None:
1437
- # We have logits. Use argmax of logits for each token id up to sequence_length
1438
  for token_idx, word_idx in enumerate(word_ids):
1439
  if token_idx >= sequence_length:
1440
  break
@@ -1443,26 +2146,14 @@ def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
1443
  pred_id = torch.argmax(preds_tensor[token_idx]).item()
1444
  word_idx_to_pred_id[word_idx] = pred_id
1445
  else:
1446
- # No logits, but we have decoded_token_labels from CRF (one label per token)
1447
- # We'll align decoded_token_labels to token positions.
1448
  if decoded_token_labels is None:
1449
- # should not happen due to earlier checks
1450
  raise RuntimeError("No logits and no decoded labels available for mapping.")
1451
- # decoded_token_labels length may be equal to content_token_length (no special tokens)
1452
- # or equal to sequence_length; try to align intelligently:
1453
- # Prefer using decoded_token_labels aligned to the tokenizer tokens (starting at token 1 for CLS)
1454
- # If decoded length == content_token_length, then manual_word_ids maps sub-token -> word idx for content tokens only.
1455
- # We'll iterate tokens and pick label accordingly.
1456
- # Build token_idx -> decoded_label mapping:
1457
- # We'll assume decoded_token_labels correspond to content tokens (no CLS/SEP). If decoded length == sequence_length, then shift by 0.
1458
  decoded_len = len(decoded_token_labels)
1459
- # Heuristic: if decoded_len == content_token_length -> alignment starts at token_idx 1 (skip CLS)
1460
  if decoded_len == content_token_length:
1461
  decoded_start = 1
1462
  elif decoded_len == sequence_length:
1463
  decoded_start = 0
1464
  else:
1465
- # fallback: prefer decoded_start=1 (most common)
1466
  decoded_start = 1
1467
 
1468
  for tok_idx_in_decoded, label_id in enumerate(decoded_token_labels):
@@ -1471,11 +2162,9 @@ def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
1471
  break
1472
  if tok_idx >= sequence_length:
1473
  break
1474
- # map this token to a word index if present
1475
  word_idx = word_ids[tok_idx] if tok_idx < len(word_ids) else None
1476
  if word_idx is not None and word_idx < len(sub_words):
1477
  if word_idx not in word_idx_to_pred_id:
1478
- # label_id may already be an int
1479
  word_idx_to_pred_id[word_idx] = int(label_id)
1480
 
1481
  # Finally convert mapped word preds -> page_raw_predictions entries
 
92
 
93
 
94
 
95
+ # def get_latex_from_base64(base64_string: str) -> str:
96
+ # """
97
+ # Decodes a Base64 image string and uses the pre-initialized TrOCR/ORT model
98
+ # to recognize the formula. It cleans the output by removing spaces and
99
+ # crucially, replacing double backslashes with single backslashes for correct LaTeX.
100
+ # """
101
+ # if ort_model is None or processor is None:
102
+ # return "[MODEL_ERROR: Model not initialized]"
103
+
104
+ # try:
105
+ # # 1. Decode Base64 to Image
106
+ # image_data = base64.b64decode(base64_string)
107
+ # # We must ensure the image is RGB format for the model input
108
+ # image = Image.open(io.BytesIO(image_data)).convert('RGB')
109
+
110
+ # # 2. Preprocess the image
111
+ # pixel_values = processor(images=image, return_tensors="pt").pixel_values
112
+
113
+ # # 3. Text Generation (OCR)
114
+ # generated_ids = ort_model.generate(pixel_values)
115
+ # raw_generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
116
+
117
+ # if not raw_generated_text:
118
+ # return "[OCR_WARNING: No formula found]"
119
+
120
+ # latex_string = raw_generated_text[0]
121
+
122
+ # # --- 4. Post-processing and Cleanup ---
123
+
124
+ # # # A. Remove all spaces/line breaks
125
+ # # cleaned_latex = re.sub(r'\s+', '', latex_string)
126
+ # cleaned_latex = re.sub(r'[\r\n]+', '', latex_string)
127
+
128
+ # # B. CRITICAL FIX: Replace double backslashes (\\) with single backslashes (\).
129
+ # # This corrects model output that already over-escaped the LaTeX commands.
130
+ # # Python literal: '\\\\' is replaced with '\\'.
131
+ # #cleaned_latex = cleaned_latex.replace('\\\\', '\\')
132
+
133
+ # return cleaned_latex
134
+
135
+
136
+ # except Exception as e:
137
+ # # Catch any unexpected errors
138
+ # print(f" ❌ TR-OCR Recognition failed: {e}")
139
+ # return f"[TR_OCR_ERROR: Recognition failed: {e}]"
140
+
141
+
142
+
143
+
144
+
145
+
146
+
147
+
148
+
149
  def get_latex_from_base64(base64_string: str) -> str:
150
  """
151
  Decodes a Base64 image string and uses the pre-initialized TrOCR/ORT model
 
172
  return "[OCR_WARNING: No formula found]"
173
 
174
  latex_string = raw_generated_text[0]
175
+
176
+ # ==============================================================================
177
+ # --- DEBUGGING BLOCK: CHECK TrOCR RAW OUTPUT ---
178
+ # ==============================================================================
179
+ print(f"[DEBUG] TrOCR Raw Output: '{latex_string}'")
180
+ # ==============================================================================
181
 
182
  # --- 4. Post-processing and Cleanup ---
183
 
 
640
 
641
 
642
 
643
+ # def extract_native_words_and_convert(fitz_page, scale_factor: float = 2.0) -> list:
644
+ # raw_word_data = fitz_page.get_text("words")
645
+ # converted_ocr_output = []
646
+ # DEFAULT_CONFIDENCE = 99.0
647
+
648
+ # for x1, y1, x2, y2, word, *rest in raw_word_data:
649
+ # # --- FIX: SANITIZE TEXT HERE ---
650
+ # cleaned_word = sanitize_text(word)
651
+ # if not cleaned_word.strip(): continue
652
+
653
+ # x1_pix = int(x1 * scale_factor)
654
+ # y1_pix = int(y1 * scale_factor)
655
+ # x2_pix = int(x2 * scale_factor)
656
+ # y2_pix = int(y2 * scale_factor)
657
+ # converted_ocr_output.append({
658
+ # 'type': 'text',
659
+ # 'word': cleaned_word, # Use the sanitized word
660
+ # 'confidence': DEFAULT_CONFIDENCE,
661
+ # 'bbox': [x1_pix, y1_pix, x2_pix, y2_pix],
662
+ # 'y0': y1_pix, 'x0': x1_pix
663
+ # })
664
+ # return converted_ocr_output
665
+
666
+
667
+
668
+
669
+
670
  def extract_native_words_and_convert(fitz_page, scale_factor: float = 2.0) -> list:
671
  raw_word_data = fitz_page.get_text("words")
672
+
673
+ # ==============================================================================
674
+ # --- DEBUGGING BLOCK: CHECK FIRST 50 NATIVE WORDS ---
675
+ # ==============================================================================
676
+ print(f"\n[DEBUG] Native Extraction (Page {fitz_page.number + 1}): Checking first 50 words...")
677
+ debug_count = 0
678
+ for item in raw_word_data:
679
+ if debug_count >= 50: break
680
+ # item format: (x0, y0, x1, y1, word, block_no, line_no, word_no)
681
+ word_text = item[4]
682
+
683
+ # Generate unicode hex codes for every character in the word
684
+ unicode_points = [f"\\u{ord(c):04x}" for c in word_text]
685
+ print(f" Word {debug_count}: '{word_text}' -> Codes: {unicode_points}")
686
+ debug_count += 1
687
+ print("----------------------------------------------------------------------\n")
688
+ # ==============================================================================
689
+
690
  converted_ocr_output = []
691
  DEFAULT_CONFIDENCE = 99.0
692
 
 
901
 
902
 
903
 
904
+ # def preprocess_and_ocr_page(original_img: np.ndarray, model, pdf_path: str,
905
+ # page_num: int, fitz_page: fitz.Page,
906
+ # pdf_name: str) -> Tuple[List[Dict[str, Any]], Optional[int]]:
907
+ # """
908
+ # OPTIMIZED FLOW:
909
+ # 1. Run YOLO to find Equations/Tables.
910
+ # 2. Mask raw text with YOLO boxes.
911
+ # 3. Run Column Detection on the MASKED data.
912
+ # 4. Proceed with OCR (Native or High-Res Tesseract Fallback) and Output.
913
+ # """
914
+ # global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT
915
+
916
+ # start_time_total = time.time()
917
+
918
+ # if original_img is None:
919
+ # print(f" ❌ Invalid image for page {page_num}.")
920
+ # return None, None
921
+
922
+ # # ====================================================================
923
+ # # --- STEP 1: YOLO DETECTION ---
924
+ # # ====================================================================
925
+ # start_time_yolo = time.time()
926
+ # results = model.predict(source=original_img, conf=CONF_THRESHOLD, imgsz=640, verbose=False)
927
+
928
+ # relevant_detections = []
929
+ # if results and results[0].boxes:
930
+ # for box in results[0].boxes:
931
+ # class_id = int(box.cls[0])
932
+ # class_name = model.names[class_id]
933
+ # if class_name in TARGET_CLASSES:
934
+ # x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
935
+ # relevant_detections.append(
936
+ # {'coords': (x1, y1, x2, y2), 'y1': y1, 'class': class_name, 'conf': float(box.conf[0])}
937
+ # )
938
+
939
+ # merged_detections = merge_overlapping_boxes(relevant_detections, IOU_MERGE_THRESHOLD)
940
+ # print(f" [LOG] YOLO found {len(merged_detections)} objects in {time.time() - start_time_yolo:.3f}s.")
941
+
942
+ # # ====================================================================
943
+ # # --- STEP 2: PREPARE DATA FOR COLUMN DETECTION (MASKING) ---
944
+ # # ====================================================================
945
+ # # Note: This uses the updated 'get_word_data_for_detection' which has its own optimizations
946
+ # raw_words_for_layout = get_word_data_for_detection(
947
+ # fitz_page, pdf_path, page_num,
948
+ # top_margin_percent=0.10, bottom_margin_percent=0.10
949
+ # )
950
+
951
+ # masked_word_data = merge_yolo_into_word_data(raw_words_for_layout, merged_detections, scale_factor=2.0)
952
+
953
+ # # ====================================================================
954
+ # # --- STEP 3: COLUMN DETECTION ---
955
+ # # ====================================================================
956
+ # page_width_pdf = fitz_page.rect.width
957
+ # page_height_pdf = fitz_page.rect.height
958
+
959
+ # column_detection_params = {
960
+ # 'cluster_bin_size': 2, 'cluster_smoothing': 2,
961
+ # 'cluster_min_width': 10, 'cluster_threshold_percentile': 85,
962
+ # }
963
+
964
+ # separators = calculate_x_gutters(masked_word_data, column_detection_params, page_height_pdf)
965
+
966
+ # page_separator_x = None
967
+ # if separators:
968
+ # central_min = page_width_pdf * 0.35
969
+ # central_max = page_width_pdf * 0.65
970
+ # central_separators = [s for s in separators if central_min <= s <= central_max]
971
+
972
+ # if central_separators:
973
+ # center_x = page_width_pdf / 2
974
+ # page_separator_x = min(central_separators, key=lambda x: abs(x - center_x))
975
+ # print(f" ✅ Column Split Confirmed at X={page_separator_x:.1f}")
976
+ # else:
977
+ # print(" ⚠️ Gutter found off-center. Ignoring.")
978
+ # else:
979
+ # print(" -> Single Column Layout Confirmed.")
980
+
981
+ # # ====================================================================
982
+ # # --- STEP 4: COMPONENT EXTRACTION (Save Images) ---
983
+ # # ====================================================================
984
+ # start_time_components = time.time()
985
+ # component_metadata = []
986
+ # fig_count_page = 0
987
+ # eq_count_page = 0
988
+
989
+ # for detection in merged_detections:
990
+ # x1, y1, x2, y2 = detection['coords']
991
+ # class_name = detection['class']
992
+
993
+ # if class_name == 'figure':
994
+ # GLOBAL_FIGURE_COUNT += 1
995
+ # counter = GLOBAL_FIGURE_COUNT
996
+ # component_word = f"FIGURE{counter}"
997
+ # fig_count_page += 1
998
+ # elif class_name == 'equation':
999
+ # GLOBAL_EQUATION_COUNT += 1
1000
+ # counter = GLOBAL_EQUATION_COUNT
1001
+ # component_word = f"EQUATION{counter}"
1002
+ # eq_count_page += 1
1003
+ # else:
1004
+ # continue
1005
+
1006
+ # component_crop = original_img[y1:y2, x1:x2]
1007
+ # component_filename = f"{pdf_name}_page{page_num}_{class_name}{counter}.png"
1008
+ # cv2.imwrite(os.path.join(FIGURE_EXTRACTION_DIR, component_filename), component_crop)
1009
+
1010
+ # y_midpoint = (y1 + y2) // 2
1011
+ # component_metadata.append({
1012
+ # 'type': class_name, 'word': component_word,
1013
+ # 'bbox': [int(x1), int(y1), int(x2), int(y2)],
1014
+ # 'y0': int(y_midpoint), 'x0': int(x1)
1015
+ # })
1016
+
1017
+ # # ====================================================================
1018
+ # # --- STEP 5: HYBRID OCR (Native Text + Cached Tesseract Fallback) ---
1019
+ # # ====================================================================
1020
+ # raw_ocr_output = []
1021
+ # scale_factor = 2.0 # Pipeline standard scale
1022
+
1023
+ # try:
1024
+ # # Try getting native text first
1025
+ # # NOTE: extract_native_words_and_convert MUST ALSO BE UPDATED TO USE sanitize_text
1026
+ # raw_ocr_output = extract_native_words_and_convert(fitz_page, scale_factor=scale_factor)
1027
+ # except Exception as e:
1028
+ # print(f" ❌ Native text extraction failed: {e}")
1029
+
1030
+ # # If native text is missing, fall back to OCR
1031
+ # if not raw_ocr_output:
1032
+ # if _ocr_cache.has_ocr(pdf_path, page_num):
1033
+ # print(f" ⚡ Using cached Tesseract OCR for page {page_num}")
1034
+ # cached_word_data = _ocr_cache.get_ocr(pdf_path, page_num)
1035
+ # for word_tuple in cached_word_data:
1036
+ # word_text, x1, y1, x2, y2 = word_tuple
1037
+
1038
+ # # Scale from PDF points to Pipeline Pixels (2.0)
1039
+ # x1_pix = int(x1 * scale_factor)
1040
+ # y1_pix = int(y1 * scale_factor)
1041
+ # x2_pix = int(x2 * scale_factor)
1042
+ # y2_pix = int(y2 * scale_factor)
1043
+
1044
+ # raw_ocr_output.append({
1045
+ # 'type': 'text', 'word': word_text, 'confidence': 95.0,
1046
+ # 'bbox': [x1_pix, y1_pix, x2_pix, y2_pix],
1047
+ # 'y0': y1_pix, 'x0': x1_pix
1048
+ # })
1049
+ # else:
1050
+ # # === START OF OPTIMIZED OCR BLOCK ===
1051
+ # try:
1052
+ # # 1. Re-render Page at High Resolution (Zoom 4.0 = ~300 DPI)
1053
+ # ocr_zoom = 4.0
1054
+ # pix_ocr = fitz_page.get_pixmap(matrix=fitz.Matrix(ocr_zoom, ocr_zoom))
1055
+
1056
+ # # Convert PyMuPDF Pixmap to OpenCV format
1057
+ # img_ocr_np = np.frombuffer(pix_ocr.samples, dtype=np.uint8).reshape(pix_ocr.height, pix_ocr.width,
1058
+ # pix_ocr.n)
1059
+ # if pix_ocr.n == 3:
1060
+ # img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGB2BGR)
1061
+ # elif pix_ocr.n == 4:
1062
+ # img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGBA2BGR)
1063
+
1064
+ # # 2. Preprocess (Binarization)
1065
+ # processed_img = preprocess_image_for_ocr(img_ocr_np)
1066
+
1067
+ # # 3. Run Tesseract with Optimized Configuration
1068
+ # custom_config = r'--oem 3 --psm 6'
1069
+
1070
+ # hocr_data = pytesseract.image_to_data(
1071
+ # processed_img,
1072
+ # output_type=pytesseract.Output.DICT,
1073
+ # config=custom_config
1074
+ # )
1075
+
1076
+ # for i in range(len(hocr_data['level'])):
1077
+ # text = hocr_data['text'][i] # Retrieve raw Tesseract text
1078
+
1079
+ # # --- FIX: SANITIZE TEXT AND THEN STRIP ---
1080
+ # cleaned_text = sanitize_text(text).strip()
1081
+
1082
+ # if cleaned_text and hocr_data['conf'][i] > -1:
1083
+ # # 4. Coordinate Mapping
1084
+ # scale_adjustment = scale_factor / ocr_zoom
1085
+
1086
+ # x1 = int(hocr_data['left'][i] * scale_adjustment)
1087
+ # y1 = int(hocr_data['top'][i] * scale_adjustment)
1088
+ # w = int(hocr_data['width'][i] * scale_adjustment)
1089
+ # h = int(hocr_data['height'][i] * scale_adjustment)
1090
+ # x2 = x1 + w
1091
+ # y2 = y1 + h
1092
+
1093
+ # raw_ocr_output.append({
1094
+ # 'type': 'text',
1095
+ # 'word': cleaned_text, # Use the sanitized word
1096
+ # 'confidence': float(hocr_data['conf'][i]),
1097
+ # 'bbox': [x1, y1, x2, y2],
1098
+ # 'y0': y1,
1099
+ # 'x0': x1
1100
+ # })
1101
+ # except Exception as e:
1102
+ # print(f" ❌ Tesseract OCR Error: {e}")
1103
+ # # === END OF OPTIMIZED OCR BLOCK ===
1104
+
1105
+ # # ====================================================================
1106
+ # # --- STEP 6: OCR CLEANING AND MERGING ---
1107
+ # # ====================================================================
1108
+ # items_to_sort = []
1109
+
1110
+ # for ocr_word in raw_ocr_output:
1111
+ # is_suppressed = False
1112
+ # for component in component_metadata:
1113
+ # # Do not include words that are inside figure/equation boxes
1114
+ # ioa = calculate_ioa(ocr_word['bbox'], component['bbox'])
1115
+ # if ioa > IOA_SUPPRESSION_THRESHOLD:
1116
+ # is_suppressed = True
1117
+ # break
1118
+ # if not is_suppressed:
1119
+ # items_to_sort.append(ocr_word)
1120
+
1121
+ # # Add figures/equations back into the flow as "words"
1122
+ # items_to_sort.extend(component_metadata)
1123
+
1124
+ # # ====================================================================
1125
+ # # --- STEP 7: LINE-BASED SORTING ---
1126
+ # # ====================================================================
1127
+ # items_to_sort.sort(key=lambda x: (x['y0'], x['x0']))
1128
+ # lines = []
1129
+
1130
+ # for item in items_to_sort:
1131
+ # placed = False
1132
+ # for line in lines:
1133
+ # y_ref = min(it['y0'] for it in line)
1134
+ # if abs(y_ref - item['y0']) < LINE_TOLERANCE:
1135
+ # line.append(item)
1136
+ # placed = True
1137
+ # break
1138
+ # if not placed and item['type'] in ['equation', 'figure']:
1139
+ # for line in lines:
1140
+ # y_ref = min(it['y0'] for it in line)
1141
+ # if abs(y_ref - item['y0']) < 20:
1142
+ # line.append(item)
1143
+ # placed = True
1144
+ # break
1145
+ # if not placed:
1146
+ # lines.append([item])
1147
+
1148
+ # for line in lines:
1149
+ # line.sort(key=lambda x: x['x0'])
1150
+
1151
+ # final_output = []
1152
+ # for line in lines:
1153
+ # for item in line:
1154
+ # data_item = {"word": item["word"], "bbox": item["bbox"], "type": item["type"]}
1155
+ # if 'tag' in item: data_item['tag'] = item['tag']
1156
+ # final_output.append(data_item)
1157
+
1158
+ # return final_output, page_separator_x
1159
+
1160
+
1161
+
1162
+
1163
+
1164
+
1165
+
1166
+
1167
+
1168
+
1169
+
1170
+
1171
+
1172
+
1173
  def preprocess_and_ocr_page(original_img: np.ndarray, model, pdf_path: str,
1174
  page_num: int, fitz_page: fitz.Page,
1175
  pdf_name: str) -> Tuple[List[Dict[str, Any]], Optional[int]]:
 
1342
  config=custom_config
1343
  )
1344
 
1345
+ # ==============================================================================
1346
+ # --- DEBUGGING BLOCK: CHECK FIRST 50 OCR WORDS ---
1347
+ # ==============================================================================
1348
+ print(f"\n[DEBUG] Tesseract OCR Fallback (Page {page_num}): Checking first 50 words...")
1349
+ debug_count = 0
1350
+ for i in range(len(hocr_data['level'])):
1351
+ text = hocr_data['text'][i].strip()
1352
+ if text:
1353
+ unicode_points = [f"\\u{ord(c):04x}" for c in text]
1354
+ print(f" OCR Word {debug_count}: '{text}' -> Codes: {unicode_points}")
1355
+ debug_count += 1
1356
+ if debug_count >= 50: break
1357
+ print("----------------------------------------------------------------------\n")
1358
+ # ==============================================================================
1359
+
1360
  for i in range(len(hocr_data['level'])):
1361
  text = hocr_data['text'][i] # Retrieve raw Tesseract text
1362
 
 
1442
  return final_output, page_separator_x
1443
 
1444
 
1445
+
1446
+
1447
+
1448
+
1449
+
1450
+
1451
  def run_single_pdf_preprocessing(pdf_path: str, preprocessed_json_path: str) -> Optional[str]:
1452
  global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT
1453
 
 
1592
 
1593
 
1594
 
1595
+ # def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
1596
+ # preprocessed_json_path: str,
1597
+ # column_detection_params: Optional[Dict] = None) -> List[Dict[str, Any]]:
1598
+ # print("\n" + "=" * 80)
1599
+ # print("--- 2. STARTING LAYOUTLMV3 INFERENCE PIPELINE (Raw Word Output) ---")
1600
+ # print("=" * 80)
1601
+
1602
+ # tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
1603
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1604
+ # print(f" -> Using device: {device}")
1605
+
1606
+ # try:
1607
+ # model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS)
1608
+ # checkpoint = torch.load(model_path, map_location=device)
1609
+ # model_state = checkpoint.get('model_state_dict', checkpoint)
1610
+ # # Apply patch for layoutlmv3 compatibility with saved state_dict
1611
+ # fixed_state_dict = {key.replace('layoutlm.', 'layoutlmv3.'): value for key, value in model_state.items()}
1612
+ # model.load_state_dict(fixed_state_dict)
1613
+ # model.to(device)
1614
+ # model.eval()
1615
+ # print(f"✅ LayoutLMv3 Model loaded successfully from {os.path.basename(model_path)}.")
1616
+ # except Exception as e:
1617
+ # print(f"❌ FATAL ERROR during LayoutLMv3 model loading: {e}")
1618
+ # return []
1619
+
1620
+ # try:
1621
+ # with open(preprocessed_json_path, 'r', encoding='utf-8') as f:
1622
+ # preprocessed_data = json.load(f)
1623
+ # print(f"✅ Loaded preprocessed data with {len(preprocessed_data)} pages.")
1624
+ # except Exception:
1625
+ # print("❌ Error loading preprocessed JSON.")
1626
+ # return []
1627
+
1628
+ # try:
1629
+ # doc = fitz.open(pdf_path)
1630
+ # except Exception:
1631
+ # print("❌ Error loading PDF.")
1632
+ # return []
1633
+
1634
+ # final_page_predictions = []
1635
+ # CHUNK_SIZE = 500
1636
+
1637
+ # for page_data in preprocessed_data:
1638
+ # page_num_1_based = page_data['page_number']
1639
+ # page_num_0_based = page_num_1_based - 1
1640
+ # page_raw_predictions = []
1641
+ # print(f"\n *** Processing Page {page_num_1_based} ({len(page_data['data'])} raw tokens) ***")
1642
+
1643
+ # fitz_page = doc.load_page(page_num_0_based)
1644
+ # page_width, page_height = fitz_page.rect.width, fitz_page.rect.height
1645
+ # print(f" -> Page dimensions: {page_width:.0f}x{page_height:.0f} (PDF points).")
1646
+
1647
+ # all_token_data = []
1648
+ # scale_factor = 2.0
1649
+
1650
+ # for item in page_data['data']:
1651
+ # raw_yolo_bbox = item['bbox']
1652
+ # bbox_pdf = [
1653
+ # int(raw_yolo_bbox[0] / scale_factor), int(raw_yolo_bbox[1] / scale_factor),
1654
+ # int(raw_yolo_bbox[2] / scale_factor), int(raw_yolo_bbox[3] / scale_factor)
1655
+ # ]
1656
+ # normalized_bbox = [
1657
+ # max(0, min(1000, int(1000 * bbox_pdf[0] / page_width))),
1658
+ # max(0, min(1000, int(1000 * bbox_pdf[1] / page_height))),
1659
+ # max(0, min(1000, int(1000 * bbox_pdf[2] / page_width))),
1660
+ # max(0, min(1000, int(1000 * bbox_pdf[3] / page_height)))
1661
+ # ]
1662
+ # all_token_data.append({
1663
+ # "word": item['word'],
1664
+ # "bbox_raw_pdf_space": bbox_pdf,
1665
+ # "bbox_normalized": normalized_bbox,
1666
+ # "item_original_data": item
1667
+ # })
1668
+
1669
+ # if not all_token_data:
1670
+ # continue
1671
+
1672
+ # column_separator_x = page_data.get('column_separator_x', None)
1673
+ # if column_separator_x is not None:
1674
+ # print(f" -> Using SAVED column separator: X={column_separator_x}")
1675
+ # else:
1676
+ # print(" -> No column separator found. Assuming single chunk.")
1677
+
1678
+ # token_chunks = _merge_integrity(all_token_data, column_separator_x)
1679
+ # total_chunks = len(token_chunks)
1680
+
1681
+ # for chunk_idx, chunk_tokens in enumerate(token_chunks):
1682
+ # if not chunk_tokens: continue
1683
+
1684
+ # # 1. Sanitize: Convert everything to strings and aggressively clean Unicode errors.
1685
+ # chunk_words = [
1686
+ # str(t['word']).encode('utf-8', errors='ignore').decode('utf-8')
1687
+ # for t in chunk_tokens
1688
+ # ]
1689
+ # chunk_normalized_bboxes = [t['bbox_normalized'] for t in chunk_tokens]
1690
+
1691
+ # total_sub_chunks = (len(chunk_words) + CHUNK_SIZE - 1) // CHUNK_SIZE
1692
+ # for i in range(0, len(chunk_words), CHUNK_SIZE):
1693
+ # sub_chunk_idx = i // CHUNK_SIZE + 1
1694
+ # sub_words = chunk_words[i:i + CHUNK_SIZE]
1695
+ # sub_bboxes = chunk_normalized_bboxes[i:i + CHUNK_SIZE]
1696
+ # sub_tokens_data = chunk_tokens[i:i + CHUNK_SIZE]
1697
+
1698
+ # print(f" -> Chunk {chunk_idx + 1}/{total_chunks}, Sub-chunk {sub_chunk_idx}/{total_sub_chunks}: {len(sub_words)} words. Running Inference...")
1699
+
1700
+ # # 2. Manual generation of word_ids
1701
+ # manual_word_ids = []
1702
+ # for current_word_idx, word in enumerate(sub_words):
1703
+ # sub_tokens = tokenizer.tokenize(word)
1704
+ # for _ in sub_tokens:
1705
+ # manual_word_ids.append(current_word_idx)
1706
+
1707
+ # encoded_input = tokenizer(
1708
+ # sub_words,
1709
+ # boxes=sub_bboxes,
1710
+ # truncation=True,
1711
+ # padding="max_length",
1712
+ # max_length=512,
1713
+ # is_split_into_words=True,
1714
+ # return_tensors="pt"
1715
+ # )
1716
+
1717
+ # # Check for empty sequence
1718
+ # if encoded_input['input_ids'].shape[0] == 0:
1719
+ # print(f" -> Warning: Sub-chunk {sub_chunk_idx} encoded to an empty sequence. Skipping.")
1720
+ # continue
1721
+
1722
+ # # 3. Finalize word_ids based on encoded output length
1723
+ # sequence_length = int(torch.sum(encoded_input['attention_mask']).item())
1724
+ # content_token_length = max(0, sequence_length - 2)
1725
+
1726
+ # manual_word_ids = manual_word_ids[:content_token_length]
1727
+
1728
+ # final_word_ids = [None] # CLS token (index 0)
1729
+ # final_word_ids.extend(manual_word_ids)
1730
+
1731
+ # if sequence_length > 1:
1732
+ # final_word_ids.append(None) # SEP token
1733
+
1734
+ # final_word_ids.extend([None] * (512 - len(final_word_ids)))
1735
+ # word_ids = final_word_ids[:512] # Final array for mapping
1736
+
1737
+ # # Inputs are already batched by the tokenizer as [1, 512]
1738
+ # input_ids = encoded_input['input_ids'].to(device)
1739
+ # bbox = encoded_input['bbox'].to(device)
1740
+ # attention_mask = encoded_input['attention_mask'].to(device)
1741
+
1742
+ # with torch.no_grad():
1743
+ # model_outputs = model(input_ids, bbox, attention_mask)
1744
+
1745
+ # # --- Robust extraction: support several forward return types ---
1746
+ # # We'll try (in order):
1747
+ # # 1) model_outputs is (emissions_tensor, viterbi_list) -> use emissions for logits, keep decoded
1748
+ # # 2) model_outputs has .logits attribute (HF ModelOutput)
1749
+ # # 3) model_outputs is tuple/list containing a logits tensor
1750
+ # # 4) model_outputs is a tensor (assume logits)
1751
+ # # 5) model_outputs is a list-of-lists of ints (viterbi decoded) -> use that directly (no logits)
1752
+ # logits_tensor = None
1753
+ # decoded_labels_list = None
1754
+
1755
+ # # case 1: tuple/list with (emissions, viterbi)
1756
+ # if isinstance(model_outputs, (tuple, list)) and len(model_outputs) == 2:
1757
+ # a, b = model_outputs
1758
+ # # a might be tensor (emissions), b might be viterbi list
1759
+ # if isinstance(a, torch.Tensor):
1760
+ # logits_tensor = a
1761
+ # if isinstance(b, list):
1762
+ # decoded_labels_list = b
1763
+
1764
+ # # case 2: HF ModelOutput with .logits
1765
+ # if logits_tensor is None and hasattr(model_outputs, 'logits') and isinstance(model_outputs.logits, torch.Tensor):
1766
+ # logits_tensor = model_outputs.logits
1767
+
1768
+ # # case 3: tuple/list - search for a 3D tensor (B, L, C)
1769
+ # if logits_tensor is None and isinstance(model_outputs, (tuple, list)):
1770
+ # found_tensor = None
1771
+ # for item in model_outputs:
1772
+ # if isinstance(item, torch.Tensor):
1773
+ # # prefer 3D (batch, seq, labels)
1774
+ # if item.dim() == 3:
1775
+ # logits_tensor = item
1776
+ # break
1777
+ # if found_tensor is None:
1778
+ # found_tensor = item
1779
+ # if logits_tensor is None and found_tensor is not None:
1780
+ # # found_tensor may be (batch, seq, hidden) or (seq, hidden); we avoid guessing.
1781
+ # # Keep found_tensor only if it matches num_labels dimension
1782
+ # if found_tensor.dim() == 3 and found_tensor.shape[-1] == NUM_LABELS:
1783
+ # logits_tensor = found_tensor
1784
+ # elif found_tensor.dim() == 2 and found_tensor.shape[-1] == NUM_LABELS:
1785
+ # logits_tensor = found_tensor.unsqueeze(0)
1786
+
1787
+ # # case 4: model_outputs directly a tensor
1788
+ # if logits_tensor is None and isinstance(model_outputs, torch.Tensor):
1789
+ # logits_tensor = model_outputs
1790
+
1791
+ # # case 5: model_outputs is a decoded viterbi list (common for CRF-only forward)
1792
+ # if decoded_labels_list is None and isinstance(model_outputs, list) and model_outputs and isinstance(model_outputs[0], list):
1793
+ # # assume model_outputs is already viterbi decoded: List[List[int]] with batch dim first
1794
+ # decoded_labels_list = model_outputs
1795
+
1796
+ # # If neither logits nor decoded exist, that's fatal
1797
+ # if logits_tensor is None and decoded_labels_list is None:
1798
+ # # helpful debug info
1799
+ # try:
1800
+ # elem_shapes = [ (type(x), getattr(x, 'shape', None)) for x in model_outputs ] if isinstance(model_outputs, (list, tuple)) else [(type(model_outputs), getattr(model_outputs, 'shape', None))]
1801
+ # except Exception:
1802
+ # elem_shapes = str(type(model_outputs))
1803
+ # raise RuntimeError(f"Model output of type {type(model_outputs)} did not contain a valid logits tensor or decoded viterbi. Contents: {elem_shapes}")
1804
+
1805
+ # # If we have logits_tensor, normalize shape to [seq_len, num_labels]
1806
+ # if logits_tensor is not None:
1807
+ # # If shape is [B, L, C] with B==1, squeeze batch
1808
+ # if logits_tensor.dim() == 3 and logits_tensor.shape[0] == 1:
1809
+ # preds_tensor = logits_tensor.squeeze(0) # [L, C]
1810
+ # else:
1811
+ # preds_tensor = logits_tensor # possibly [L, C] already
1812
+
1813
+ # # Safety: ensure we have at least seq_len x channels
1814
+ # if preds_tensor.dim() != 2:
1815
+ # # try to reshape or error
1816
+ # raise RuntimeError(f"Unexpected logits tensor shape: {tuple(preds_tensor.shape)}")
1817
+ # # We'll use preds_tensor[token_idx] to argmax
1818
+ # else:
1819
+ # preds_tensor = None # no logits available
1820
+
1821
+ # # If decoded labels provided, make a token-level list-of-ints aligned to tokenizer tokens
1822
+ # decoded_token_labels = None
1823
+ # if decoded_labels_list is not None:
1824
+ # # decoded_labels_list is batch-first; we used batch size 1
1825
+ # # if multiple sequences returned, take first
1826
+ # decoded_token_labels = decoded_labels_list[0] if isinstance(decoded_labels_list[0], list) else decoded_labels_list
1827
+
1828
+ # # Now map token-level predictions -> word-level predictions using word_ids
1829
+ # word_idx_to_pred_id = {}
1830
+
1831
+ # if preds_tensor is not None:
1832
+ # # We have logits. Use argmax of logits for each token id up to sequence_length
1833
+ # for token_idx, word_idx in enumerate(word_ids):
1834
+ # if token_idx >= sequence_length:
1835
+ # break
1836
+ # if word_idx is not None and word_idx < len(sub_words):
1837
+ # if word_idx not in word_idx_to_pred_id:
1838
+ # pred_id = torch.argmax(preds_tensor[token_idx]).item()
1839
+ # word_idx_to_pred_id[word_idx] = pred_id
1840
+ # else:
1841
+ # # No logits, but we have decoded_token_labels from CRF (one label per token)
1842
+ # # We'll align decoded_token_labels to token positions.
1843
+ # if decoded_token_labels is None:
1844
+ # # should not happen due to earlier checks
1845
+ # raise RuntimeError("No logits and no decoded labels available for mapping.")
1846
+ # # decoded_token_labels length may be equal to content_token_length (no special tokens)
1847
+ # # or equal to sequence_length; try to align intelligently:
1848
+ # # Prefer using decoded_token_labels aligned to the tokenizer tokens (starting at token 1 for CLS)
1849
+ # # If decoded length == content_token_length, then manual_word_ids maps sub-token -> word idx for content tokens only.
1850
+ # # We'll iterate tokens and pick label accordingly.
1851
+ # # Build token_idx -> decoded_label mapping:
1852
+ # # We'll assume decoded_token_labels correspond to content tokens (no CLS/SEP). If decoded length == sequence_length, then shift by 0.
1853
+ # decoded_len = len(decoded_token_labels)
1854
+ # # Heuristic: if decoded_len == content_token_length -> alignment starts at token_idx 1 (skip CLS)
1855
+ # if decoded_len == content_token_length:
1856
+ # decoded_start = 1
1857
+ # elif decoded_len == sequence_length:
1858
+ # decoded_start = 0
1859
+ # else:
1860
+ # # fallback: prefer decoded_start=1 (most common)
1861
+ # decoded_start = 1
1862
+
1863
+ # for tok_idx_in_decoded, label_id in enumerate(decoded_token_labels):
1864
+ # tok_idx = decoded_start + tok_idx_in_decoded
1865
+ # if tok_idx >= 512:
1866
+ # break
1867
+ # if tok_idx >= sequence_length:
1868
+ # break
1869
+ # # map this token to a word index if present
1870
+ # word_idx = word_ids[tok_idx] if tok_idx < len(word_ids) else None
1871
+ # if word_idx is not None and word_idx < len(sub_words):
1872
+ # if word_idx not in word_idx_to_pred_id:
1873
+ # # label_id may already be an int
1874
+ # word_idx_to_pred_id[word_idx] = int(label_id)
1875
+
1876
+ # # Finally convert mapped word preds -> page_raw_predictions entries
1877
+ # for current_word_idx in range(len(sub_words)):
1878
+ # pred_id = word_idx_to_pred_id.get(current_word_idx, 0) # default to 0
1879
+ # predicted_label = ID_TO_LABEL[pred_id]
1880
+ # original_token = sub_tokens_data[current_word_idx]
1881
+ # page_raw_predictions.append({
1882
+ # "word": original_token['word'],
1883
+ # "bbox": original_token['bbox_raw_pdf_space'],
1884
+ # "predicted_label": predicted_label,
1885
+ # "page_number": page_num_1_based
1886
+ # })
1887
+
1888
+ # if page_raw_predictions:
1889
+ # final_page_predictions.append({
1890
+ # "page_number": page_num_1_based,
1891
+ # "data": page_raw_predictions
1892
+ # })
1893
+ # print(f" *** Page {page_num_1_based} Finalized: {len(page_raw_predictions)} labeled words. ***")
1894
+
1895
+ # doc.close()
1896
+ # print("\n" + "=" * 80)
1897
+ # print("--- LAYOUTLMV3 INFERENCE COMPLETE ---")
1898
+ # print("=" * 80)
1899
+ # return final_page_predictions
1900
+
1901
+
1902
+
1903
+
1904
+
1905
+
1906
+
1907
+
1908
  def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
1909
  preprocessed_json_path: str,
1910
  column_detection_params: Optional[Dict] = None) -> List[Dict[str, Any]]:
 
1979
  "item_original_data": item
1980
  })
1981
 
1982
+ # ==============================================================================
1983
+ # --- DEBUGGING BLOCK: CHECK FIRST 50 TOKENS BEFORE INFERENCE ---
1984
+ # ==============================================================================
1985
+ print(f"\n[DEBUG] LayoutLMv3 Input (Page {page_num_1_based}): Checking first 50 tokens...")
1986
+ debug_count = 0
1987
+ for t in all_token_data:
1988
+ if debug_count >= 50: break
1989
+ w = t['word']
1990
+ unicode_points = [f"\\u{ord(c):04x}" for c in w]
1991
+ print(f" Token {debug_count}: '{w}' -> Codes: {unicode_points}")
1992
+ debug_count += 1
1993
+ print("----------------------------------------------------------------------\n")
1994
+ # ==============================================================================
1995
+
1996
  if not all_token_data:
1997
  continue
1998
 
 
2070
  model_outputs = model(input_ids, bbox, attention_mask)
2071
 
2072
  # --- Robust extraction: support several forward return types ---
 
 
 
 
 
 
2073
  logits_tensor = None
2074
  decoded_labels_list = None
2075
 
2076
  # case 1: tuple/list with (emissions, viterbi)
2077
  if isinstance(model_outputs, (tuple, list)) and len(model_outputs) == 2:
2078
  a, b = model_outputs
 
2079
  if isinstance(a, torch.Tensor):
2080
  logits_tensor = a
2081
  if isinstance(b, list):
 
2090
  found_tensor = None
2091
  for item in model_outputs:
2092
  if isinstance(item, torch.Tensor):
 
2093
  if item.dim() == 3:
2094
  logits_tensor = item
2095
  break
2096
  if found_tensor is None:
2097
  found_tensor = item
2098
  if logits_tensor is None and found_tensor is not None:
 
 
2099
  if found_tensor.dim() == 3 and found_tensor.shape[-1] == NUM_LABELS:
2100
  logits_tensor = found_tensor
2101
  elif found_tensor.dim() == 2 and found_tensor.shape[-1] == NUM_LABELS:
 
2107
 
2108
  # case 5: model_outputs is a decoded viterbi list (common for CRF-only forward)
2109
  if decoded_labels_list is None and isinstance(model_outputs, list) and model_outputs and isinstance(model_outputs[0], list):
 
2110
  decoded_labels_list = model_outputs
2111
 
2112
  # If neither logits nor decoded exist, that's fatal
2113
  if logits_tensor is None and decoded_labels_list is None:
 
2114
  try:
2115
  elem_shapes = [ (type(x), getattr(x, 'shape', None)) for x in model_outputs ] if isinstance(model_outputs, (list, tuple)) else [(type(model_outputs), getattr(model_outputs, 'shape', None))]
2116
  except Exception:
 
2119
 
2120
  # If we have logits_tensor, normalize shape to [seq_len, num_labels]
2121
  if logits_tensor is not None:
 
2122
  if logits_tensor.dim() == 3 and logits_tensor.shape[0] == 1:
2123
  preds_tensor = logits_tensor.squeeze(0) # [L, C]
2124
  else:
2125
  preds_tensor = logits_tensor # possibly [L, C] already
2126
 
 
2127
  if preds_tensor.dim() != 2:
 
2128
  raise RuntimeError(f"Unexpected logits tensor shape: {tuple(preds_tensor.shape)}")
 
2129
  else:
2130
  preds_tensor = None # no logits available
2131
 
2132
  # If decoded labels provided, make a token-level list-of-ints aligned to tokenizer tokens
2133
  decoded_token_labels = None
2134
  if decoded_labels_list is not None:
 
 
2135
  decoded_token_labels = decoded_labels_list[0] if isinstance(decoded_labels_list[0], list) else decoded_labels_list
2136
 
2137
  # Now map token-level predictions -> word-level predictions using word_ids
2138
  word_idx_to_pred_id = {}
2139
 
2140
  if preds_tensor is not None:
 
2141
  for token_idx, word_idx in enumerate(word_ids):
2142
  if token_idx >= sequence_length:
2143
  break
 
2146
  pred_id = torch.argmax(preds_tensor[token_idx]).item()
2147
  word_idx_to_pred_id[word_idx] = pred_id
2148
  else:
 
 
2149
  if decoded_token_labels is None:
 
2150
  raise RuntimeError("No logits and no decoded labels available for mapping.")
 
 
 
 
 
 
 
2151
  decoded_len = len(decoded_token_labels)
 
2152
  if decoded_len == content_token_length:
2153
  decoded_start = 1
2154
  elif decoded_len == sequence_length:
2155
  decoded_start = 0
2156
  else:
 
2157
  decoded_start = 1
2158
 
2159
  for tok_idx_in_decoded, label_id in enumerate(decoded_token_labels):
 
2162
  break
2163
  if tok_idx >= sequence_length:
2164
  break
 
2165
  word_idx = word_ids[tok_idx] if tok_idx < len(word_ids) else None
2166
  if word_idx is not None and word_idx < len(sub_words):
2167
  if word_idx not in word_idx_to_pred_id:
 
2168
  word_idx_to_pred_id[word_idx] = int(label_id)
2169
 
2170
  # Finally convert mapped word preds -> page_raw_predictions entries