heerjtdev commited on
Commit
c7915ba
·
verified ·
1 Parent(s): ae075a3

Update working_yolo_pipeline.py

Browse files
Files changed (1) hide show
  1. working_yolo_pipeline.py +286 -6
working_yolo_pipeline.py CHANGED
@@ -25,6 +25,95 @@ import shutil
25
  from sklearn.feature_extraction.text import CountVectorizer
26
  from sklearn.metrics.pairwise import cosine_similarity
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # ============================================================================
29
  # --- CONFIGURATION AND CONSTANTS ---
30
  # ============================================================================
@@ -1466,10 +1555,50 @@ def get_base64_for_file(filepath: str) -> str:
1466
  return ""
1467
 
1468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1469
  def embed_images_as_base64_in_memory(structured_data: List[Dict[str, Any]], figure_extraction_dir: str) -> List[
1470
  Dict[str, Any]]:
1471
  print("\n" + "=" * 80)
1472
- print("--- 4. STARTING IMAGE EMBEDDING (Base64) ---")
1473
  print("=" * 80)
1474
  if not structured_data: return []
1475
  image_files = glob.glob(os.path.join(figure_extraction_dir, "*.png"))
@@ -1482,25 +1611,57 @@ def embed_images_as_base64_in_memory(structured_data: List[Dict[str, Any]], figu
1482
  key = f"{match.group(1).upper()}{match.group(2)}"
1483
  image_lookup[key] = filepath
1484
  print(f" -> Found {len(image_lookup)} image components.")
 
1485
  final_structured_data = []
 
1486
  for item in structured_data:
1487
  text_fields = [item.get('question', ''), item.get('passage', '')]
1488
  if 'options' in item:
1489
  for opt_val in item['options'].values(): text_fields.append(opt_val)
1490
  if 'new_passage' in item: text_fields.append(item['new_passage'])
 
1491
  unique_tags_to_embed = set()
1492
  for text in text_fields:
1493
  if not text: continue
1494
  for match in tag_regex.finditer(text):
1495
  tag = match.group(0).upper()
1496
  if tag in image_lookup: unique_tags_to_embed.add(tag)
 
 
 
 
1497
  for tag in sorted(list(unique_tags_to_embed)):
1498
  filepath = image_lookup[tag]
 
1499
  base64_code = get_base64_for_file(filepath)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1500
  base_key = tag.replace(' ', '').lower()
1501
  item[base_key] = base64_code
 
1502
  final_structured_data.append(item)
1503
- print(f"✅ Image embedding complete.")
 
1504
  return final_structured_data
1505
 
1506
 
@@ -1508,7 +1669,76 @@ def embed_images_as_base64_in_memory(structured_data: List[Dict[str, Any]], figu
1508
  # --- MAIN FUNCTION ---
1509
  # ============================================================================
1510
 
1511
- def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, label_studio_output_path: str) -> Optional[
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1512
  List[Dict[str, Any]]]:
1513
  if not os.path.exists(input_pdf_path): return None
1514
 
@@ -1536,9 +1766,17 @@ def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, label
1536
  )
1537
  if not page_raw_predictions_list: return None
1538
 
 
 
1539
  with open(raw_output_path, 'w', encoding='utf-8') as f:
1540
  json.dump(page_raw_predictions_list, f, indent=4)
1541
 
 
 
 
 
 
 
1542
  # Phase 3: Decoding
1543
  structured_data_list = convert_bio_to_structured_json_relaxed(
1544
  raw_output_path, structured_intermediate_output_path
@@ -1552,7 +1790,7 @@ def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, label
1552
  except Exception as e:
1553
  print(f"❌ Error during Label Studio conversion: {e}")
1554
 
1555
- # Phase 4: Embedding
1556
  final_result = embed_images_as_base64_in_memory(structured_data_list, FIGURE_EXTRACTION_DIR)
1557
 
1558
  except Exception as e:
@@ -1575,19 +1813,61 @@ def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, label
1575
  return final_result
1576
 
1577
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1578
  if __name__ == "__main__":
1579
  parser = argparse.ArgumentParser(description="Complete Pipeline")
1580
  parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
1581
  parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path")
1582
  parser.add_argument("--ls_output_path", type=str, default=None, help="Label Studio Output Path")
 
 
 
 
1583
  args = parser.parse_args()
1584
 
1585
  pdf_name = os.path.splitext(os.path.basename(args.input_pdf))[0]
1586
  final_output_path = os.path.abspath(f"{pdf_name}_final_output_embedded.json")
1587
  ls_output_path = os.path.abspath(
1588
  args.ls_output_path if args.ls_output_path else f"{pdf_name}_label_studio_tasks.json")
1589
-
1590
- final_json_data = run_document_pipeline(args.input_pdf, args.layoutlmv3_model_path, ls_output_path)
 
 
 
 
 
 
 
 
 
 
 
1591
 
1592
  if final_json_data:
1593
  with open(final_output_path, 'w', encoding='utf-8') as f:
 
25
  from sklearn.feature_extraction.text import CountVectorizer
26
  from sklearn.metrics.pairwise import cosine_similarity
27
 
28
+
29
+
30
+
31
+
32
+
33
#=============================================================================
#-----EXPERIMENT LATEX
#=============================================================================


# --- NEW IMPORTS ---
# Guard the third-party import: if pix2text is not installed, the pipeline
# must still run (equations simply won't be converted) instead of crashing
# at module import time. The init block below already expects to degrade
# gracefully; an unguarded top-level import defeated that.
try:
    from pix2text import Pix2Text
except ImportError:
    Pix2Text = None
import logging
# -------------------

# ============================================================================
# --- PIX2TEXT INITIALIZATION AND HELPER ---
# ============================================================================
# Set up logging at WARNING level to suppress excessive output from the
# underlying model libraries.
logging.basicConfig(level=logging.WARNING)
logging.getLogger('pix2text').setLevel(logging.WARNING)

# Initialize the Pix2Text model globally (expensive operation, do it once).
# p2t stays None when the package is missing or initialization fails; the
# helper below checks for that and returns a sentinel instead of raising.
p2t = None
try:
    if Pix2Text is None:
        raise ImportError("pix2text package is not installed")
    # Use 'yolox_tiny' for potentially faster inference in a pipeline context
    p2t = Pix2Text(analyzer_config={'model_name': 'yolox_tiny'})
    print("✅ Pix2Text model initialized successfully for equation conversion.")
except Exception as e:
    print(f"❌ Error initializing Pix2Text model. Equations will not be converted: {e}")
    p2t = None


def get_latex_from_base64(base64_string: str) -> str:
    """
    Decode a Base64-encoded image, run Pix2Text formula recognition on it,
    and return the recognized LaTeX wrapped in ``$$ ... $$``.

    Returns a bracketed sentinel string (``[P2T_ERROR: ...]`` /
    ``[P2T_WARNING: ...]``) instead of raising when the model is unavailable,
    recognition fails, or no formula is found — callers substitute the result
    directly into document text, so this function must never raise.
    """
    if p2t is None:
        return "[P2T_ERROR: Model not initialized]"

    try:
        # 1. Decode Base64 to a PIL image.
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data))

        # 2. Recognize text and formulas.
        result = p2t.recognize(image, save_formula_images=False, use_analyzer=True)

        # 3. Collect LaTeX fragments. NOTE(review): the pix2text API may
        #    return either a plain string or a list of segments carrying a
        #    .text attribute — handle both shapes defensively.
        extracted_latex_parts = []
        if isinstance(result, list):
            for item in result:
                if hasattr(item, 'text'):
                    extracted_latex_parts.append(item.text)
                elif isinstance(item, str):
                    extracted_latex_parts.append(item)
        elif isinstance(result, str):
            extracted_latex_parts = [result]

        extracted_latex = " ".join(extracted_latex_parts).strip()

        if not extracted_latex:
            return "[P2T_WARNING: No formula found]"

        # Wrap the result in display-math LaTeX delimiters.
        return f"$${extracted_latex}$$"

    except Exception as e:
        # Never propagate: report and return a sentinel the caller can embed.
        print(f" ❌ Pix2Text Recognition failed: {e}")
        return f"[P2T_ERROR: Recognition failed: {e}]"
104
+
105
+
106
+
107
+
108
+
109
+
110
+ #=============================================================================
111
+ #-----EXPERIMENT LATEX
112
+ #=============================================================================
113
+
114
+
115
+
116
+
117
  # ============================================================================
118
  # --- CONFIGURATION AND CONSTANTS ---
119
  # ============================================================================
 
1555
  return ""
1556
 
1557
 
1558
+ # def embed_images_as_base64_in_memory(structured_data: List[Dict[str, Any]], figure_extraction_dir: str) -> List[
1559
+ # Dict[str, Any]]:
1560
+ # print("\n" + "=" * 80)
1561
+ # print("--- 4. STARTING IMAGE EMBEDDING (Base64) ---")
1562
+ # print("=" * 80)
1563
+ # if not structured_data: return []
1564
+ # image_files = glob.glob(os.path.join(figure_extraction_dir, "*.png"))
1565
+ # image_lookup = {}
1566
+ # tag_regex = re.compile(r'(figure|equation)(\d+)', re.IGNORECASE)
1567
+ # for filepath in image_files:
1568
+ # filename = os.path.basename(filepath)
1569
+ # match = re.search(r'_(figure|equation)(\d+)\.png$', filename, re.IGNORECASE)
1570
+ # if match:
1571
+ # key = f"{match.group(1).upper()}{match.group(2)}"
1572
+ # image_lookup[key] = filepath
1573
+ # print(f" -> Found {len(image_lookup)} image components.")
1574
+ # final_structured_data = []
1575
+ # for item in structured_data:
1576
+ # text_fields = [item.get('question', ''), item.get('passage', '')]
1577
+ # if 'options' in item:
1578
+ # for opt_val in item['options'].values(): text_fields.append(opt_val)
1579
+ # if 'new_passage' in item: text_fields.append(item['new_passage'])
1580
+ # unique_tags_to_embed = set()
1581
+ # for text in text_fields:
1582
+ # if not text: continue
1583
+ # for match in tag_regex.finditer(text):
1584
+ # tag = match.group(0).upper()
1585
+ # if tag in image_lookup: unique_tags_to_embed.add(tag)
1586
+ # for tag in sorted(list(unique_tags_to_embed)):
1587
+ # filepath = image_lookup[tag]
1588
+ # base64_code = get_base64_for_file(filepath)
1589
+ # base_key = tag.replace(' ', '').lower()
1590
+ # item[base_key] = base64_code
1591
+ # final_structured_data.append(item)
1592
+ # print(f"✅ Image embedding complete.")
1593
+ # return final_structured_data
1594
+
1595
+
1596
+
1597
+
1598
def embed_images_as_base64_in_memory(structured_data: List[Dict[str, Any]], figure_extraction_dir: str) -> List[
    Dict[str, Any]]:
    """
    Embed figure images as Base64 and convert equation images to LaTeX.

    Scans *figure_extraction_dir* for PNGs named ``*_figureN.png`` /
    ``*_equationN.png``, finds FIGURE/EQUATION tags referenced in each item's
    text fields ('question', 'passage', 'new_passage', and option values),
    then:
      * EQUATION tags (when the Pix2Text model is available) are replaced
        in-place in the text with the recognized LaTeX code;
      * FIGURE tags (and equations when Pix2Text is unavailable) get their
        Base64-encoded image stored on the item under a lowercase key
        (e.g. ``figure1``).

    Items are mutated in place; returns the resulting list, or [] when
    *structured_data* is empty.
    """
    print("\n" + "=" * 80)
    print("--- 4. STARTING IMAGE EMBEDDING (Base64) / EQUATION TO LATEX CONVERSION ---")
    print("=" * 80)
    if not structured_data: return []
    image_files = glob.glob(os.path.join(figure_extraction_dir, "*.png"))
    image_lookup = {}
    tag_regex = re.compile(r'(figure|equation)(\d+)', re.IGNORECASE)
    for filepath in image_files:
        filename = os.path.basename(filepath)
        match = re.search(r'_(figure|equation)(\d+)\.png$', filename, re.IGNORECASE)
        if match:
            # Canonical lookup key, e.g. 'FIGURE1' / 'EQUATION12'.
            key = f"{match.group(1).upper()}{match.group(2)}"
            image_lookup[key] = filepath
    print(f" -> Found {len(image_lookup)} image components.")

    final_structured_data = []

    # BUGFIX: initialized once for the whole run (it used to be reset inside
    # the per-item loop, so the summary line only counted the last item).
    tags_converted_to_latex = set()

    for item in structured_data:
        text_fields = [item.get('question', ''), item.get('passage', '')]
        if 'options' in item:
            for opt_val in item['options'].values(): text_fields.append(opt_val)
        if 'new_passage' in item: text_fields.append(item['new_passage'])

        unique_tags_to_embed = set()
        for text in text_fields:
            if not text: continue
            for match in tag_regex.finditer(text):
                tag = match.group(0).upper()
                if tag in image_lookup: unique_tags_to_embed.add(tag)

        for tag in sorted(list(unique_tags_to_embed)):
            filepath = image_lookup[tag]
            # Base64 is needed both for embedding and for LaTeX conversion.
            base64_code = get_base64_for_file(filepath)

            # --- PIX2TEXT/EQUATION CONVERSION LOGIC START ---
            if tag.startswith('EQUATION') and p2t is not None:
                print(f" -> Converting EQUATION {tag} to LaTeX...")
                latex_code = get_latex_from_base64(base64_code)

                # BUGFIX: tags were discovered case-insensitively above, but
                # the replacement used case-sensitive str.replace() on the
                # upper-cased tag, silently missing e.g. 'equation1' in the
                # text. Also, plain replace('EQUATION1', ...) would corrupt
                # 'EQUATION12'; the (?!\d) lookahead prevents that. A callable
                # replacement keeps backslashes in the LaTeX from being
                # interpreted as regex group escapes by re.sub().
                tag_pattern = re.compile(re.escape(tag) + r'(?!\d)', re.IGNORECASE)

                for key in ['question', 'passage', 'new_passage']:
                    if item.get(key):
                        item[key] = tag_pattern.sub(lambda m: latex_code, item[key])

                if 'options' in item:
                    for opt_key, opt_val in item['options'].items():
                        item['options'][opt_key] = tag_pattern.sub(lambda m: latex_code, opt_val)

                tags_converted_to_latex.add(tag)
                # Skip Base64 embedding for converted equations.
                continue
            # --- PIX2TEXT/EQUATION CONVERSION LOGIC END ---

            # Original logic (for figures): embed the Base64 code on the item.
            base_key = tag.replace(' ', '').lower()
            item[base_key] = base64_code

        final_structured_data.append(item)

    print(f"✅ Image embedding complete. {len(tags_converted_to_latex)} equations converted to LaTeX.")
    return final_structured_data
1666
 
1667
 
 
1669
  # --- MAIN FUNCTION ---
1670
  # ============================================================================
1671
 
1672
+ # def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, label_studio_output_path: str) -> Optional[
1673
+ # List[Dict[str, Any]]]:
1674
+ # if not os.path.exists(input_pdf_path): return None
1675
+
1676
+ # print("\n" + "#" * 80)
1677
+ # print("### STARTING OPTIMIZED FULL DOCUMENT ANALYSIS PIPELINE ###")
1678
+ # print("#" * 80)
1679
+
1680
+ # pdf_name = os.path.splitext(os.path.basename(input_pdf_path))[0]
1681
+ # temp_pipeline_dir = os.path.join(tempfile.gettempdir(), f"pipeline_run_{pdf_name}_{os.getpid()}")
1682
+ # os.makedirs(temp_pipeline_dir, exist_ok=True)
1683
+
1684
+ # preprocessed_json_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_preprocessed.json")
1685
+ # raw_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_raw_predictions.json")
1686
+ # structured_intermediate_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_structured_intermediate.json")
1687
+
1688
+ # final_result = None
1689
+ # try:
1690
+ # # Phase 1: Preprocessing with YOLO First + Masking
1691
+ # preprocessed_json_path_out = run_single_pdf_preprocessing(input_pdf_path, preprocessed_json_path)
1692
+ # if not preprocessed_json_path_out: return None
1693
+
1694
+ # # Phase 2: Inference
1695
+ # page_raw_predictions_list = run_inference_and_get_raw_words(
1696
+ # input_pdf_path, layoutlmv3_model_path, preprocessed_json_path_out
1697
+ # )
1698
+ # if not page_raw_predictions_list: return None
1699
+
1700
+ # with open(raw_output_path, 'w', encoding='utf-8') as f:
1701
+ # json.dump(page_raw_predictions_list, f, indent=4)
1702
+
1703
+ # # Phase 3: Decoding
1704
+ # structured_data_list = convert_bio_to_structured_json_relaxed(
1705
+ # raw_output_path, structured_intermediate_output_path
1706
+ # )
1707
+ # if not structured_data_list: return None
1708
+ # structured_data_list = correct_misaligned_options(structured_data_list)
1709
+ # structured_data_list = process_context_linking(structured_data_list)
1710
+
1711
+ # try:
1712
+ # convert_raw_predictions_to_label_studio(page_raw_predictions_list, label_studio_output_path)
1713
+ # except Exception as e:
1714
+ # print(f"❌ Error during Label Studio conversion: {e}")
1715
+
1716
+ # # Phase 4: Embedding
1717
+ # final_result = embed_images_as_base64_in_memory(structured_data_list, FIGURE_EXTRACTION_DIR)
1718
+
1719
+ # except Exception as e:
1720
+ # print(f"❌ FATAL ERROR: {e}")
1721
+ # import traceback
1722
+ # traceback.print_exc()
1723
+ # return None
1724
+
1725
+ # finally:
1726
+ # try:
1727
+ # for f in glob.glob(os.path.join(temp_pipeline_dir, '*')):
1728
+ # os.remove(f)
1729
+ # os.rmdir(temp_pipeline_dir)
1730
+ # except Exception:
1731
+ # pass
1732
+
1733
+ # print("\n" + "#" * 80)
1734
+ # print("### OPTIMIZED PIPELINE EXECUTION COMPLETE ###")
1735
+ # print("#" * 80)
1736
+ # return final_result
1737
+
1738
+
1739
+
1740
+
1741
+ def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, label_studio_output_path: str, raw_predictions_output_path: str) -> Optional[
1742
  List[Dict[str, Any]]]:
1743
  if not os.path.exists(input_pdf_path): return None
1744
 
 
1766
  )
1767
  if not page_raw_predictions_list: return None
1768
 
1769
+ # --- DEBUG STEP: SAVE RAW PREDICTIONS ---
1770
+ # Save raw predictions to the temporary file
1771
  with open(raw_output_path, 'w', encoding='utf-8') as f:
1772
  json.dump(page_raw_predictions_list, f, indent=4)
1773
 
1774
+ # Explicitly copy/save the raw predictions to the user-specified debug path
1775
+ if raw_predictions_output_path:
1776
+ shutil.copy(raw_output_path, raw_predictions_output_path)
1777
+ print(f"\n✅ DEBUG: Raw predictions saved to: {raw_predictions_output_path}")
1778
+ # ----------------------------------------
1779
+
1780
  # Phase 3: Decoding
1781
  structured_data_list = convert_bio_to_structured_json_relaxed(
1782
  raw_output_path, structured_intermediate_output_path
 
1790
  except Exception as e:
1791
  print(f"❌ Error during Label Studio conversion: {e}")
1792
 
1793
+ # Phase 4: Embedding / Equation to LaTeX Conversion
1794
  final_result = embed_images_as_base64_in_memory(structured_data_list, FIGURE_EXTRACTION_DIR)
1795
 
1796
  except Exception as e:
 
1813
  return final_result
1814
 
1815
 
1816
+
1817
+
1818
+ # if __name__ == "__main__":
1819
+ # parser = argparse.ArgumentParser(description="Complete Pipeline")
1820
+ # parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
1821
+ # parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path")
1822
+ # parser.add_argument("--ls_output_path", type=str, default=None, help="Label Studio Output Path")
1823
+ # args = parser.parse_args()
1824
+
1825
+ # pdf_name = os.path.splitext(os.path.basename(args.input_pdf))[0]
1826
+ # final_output_path = os.path.abspath(f"{pdf_name}_final_output_embedded.json")
1827
+ # ls_output_path = os.path.abspath(
1828
+ # args.ls_output_path if args.ls_output_path else f"{pdf_name}_label_studio_tasks.json")
1829
+
1830
+ # final_json_data = run_document_pipeline(args.input_pdf, args.layoutlmv3_model_path, ls_output_path)
1831
+
1832
+ # if final_json_data:
1833
+ # with open(final_output_path, 'w', encoding='utf-8') as f:
1834
+ # json.dump(final_json_data, f, indent=2, ensure_ascii=False)
1835
+ # print(f"\n✅ Final Data Saved: {final_output_path}")
1836
+ # else:
1837
+ # print("\n❌ Pipeline Failed.")
1838
+ # sys.exit(1)
1839
+
1840
+
1841
+
1842
+
1843
  if __name__ == "__main__":
1844
  parser = argparse.ArgumentParser(description="Complete Pipeline")
1845
  parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
1846
  parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path")
1847
  parser.add_argument("--ls_output_path", type=str, default=None, help="Label Studio Output Path")
1848
+ # --- ADDED ARGUMENT FOR DEBUGGING ---
1849
+ parser.add_argument("--raw_preds_path", type=str, default='BIO_debug.json',
1850
+ help="Debug path for raw BIO tag predictions (JSON).")
1851
+ # ------------------------------------
1852
  args = parser.parse_args()
1853
 
1854
  pdf_name = os.path.splitext(os.path.basename(args.input_pdf))[0]
1855
  final_output_path = os.path.abspath(f"{pdf_name}_final_output_embedded.json")
1856
  ls_output_path = os.path.abspath(
1857
  args.ls_output_path if args.ls_output_path else f"{pdf_name}_label_studio_tasks.json")
1858
+ # --- CALCULATE RAW PREDICTIONS OUTPUT PATH ---
1859
+ raw_predictions_output_path = os.path.abspath(
1860
+ args.raw_preds_path if args.raw_preds_path else f"{pdf_name}_raw_predictions_debug.json")
1861
+ # ---------------------------------------------
1862
+
1863
+ # --- UPDATED FUNCTION CALL ---
1864
+ final_json_data = run_document_pipeline(
1865
+ args.input_pdf,
1866
+ args.layoutlmv3_model_path,
1867
+ ls_output_path,
1868
+ raw_predictions_output_path # Pass the new argument
1869
+ )
1870
+ # -----------------------------
1871
 
1872
  if final_json_data:
1873
  with open(final_output_path, 'w', encoding='utf-8') as f: