heerjtdev committed on
Commit 99e2da8 · verified · 1 Parent(s): b0b67f8

Update app.py

Files changed (1): app.py +527 -16
app.py CHANGED
@@ -585,6 +585,520 @@
585
 
586
 
587

588
  import base64
589
  from PIL import Image
590
  import re
@@ -603,14 +1117,12 @@ import io
603
  import json
604
 
605
  # ============================================================================
606
- # --- Global Setup and Configuration ---
607
  # ============================================================================
608
 
609
- # Configure logging to write to a string buffer for display in the report
610
  log_stream = io.StringIO()
611
  logging.basicConfig(level=logging.WARNING, stream=log_stream, format='%(levelname)s:%(message)s')
612
 
613
- # Patch torch.load to prevent weights_only error with older models
614
  _original_torch_load = torch.load
615
  def patched_torch_load(*args, **kwargs):
616
  kwargs["weights_only"] = False
@@ -620,7 +1132,6 @@ torch.load = patched_torch_load
620
  WEIGHTS_PATH = 'best.pt'
621
  SCALE_FACTOR = 2.0
622
 
623
- # --- OCR Model Initialization ---
624
  from transformers import TrOCRProcessor
625
  from optimum.onnxruntime import ORTModelForVision2Seq
626
 
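For context (an editorial aside, not part of this commit): newer PyTorch releases changed the default to torch.load(weights_only=True), which refuses fully pickled checkpoints such as an Ultralytics best.pt; the patched_torch_load shim referenced in the hunk header above forces weights_only=False. A minimal sketch of the effect, assuming best.pt is a trusted local checkpoint:

    from ultralytics import YOLO

    # With the weights_only=False patch in place, Ultralytics' internal
    # torch.load call accepts the pickled checkpoint instead of raising
    # an unpickling error. Only do this for checkpoint files you trust.
    model = YOLO('best.pt')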
@@ -635,14 +1146,13 @@ except Exception as e:
635
  ort_model = None
636
  OCR_MODEL_LOADED = False
637
 
638
- # Detection parameters
639
  CONF_THRESHOLD = 0.2
640
  TARGET_CLASSES = ['figure', 'equation']
641
  IOU_MERGE_THRESHOLD = 0.4
642
  IOA_SUPPRESSION_THRESHOLD = 0.7
643
 
644
  # ============================================================================
645
- # --- BOX COMBINATION LOGIC (FIXED) ---
646
  # ============================================================================
647
 
648
  def calculate_iou(box1, box2):
@@ -685,9 +1195,9 @@ def filter_nested_boxes(detections, ioa_threshold=0.80):
685
  return [detections[i] for i in keep_indices]
686
 
687
 
 
688
  def merge_overlapping_boxes(detections, iou_threshold):
689
  if not detections: return []
690
- # 1. Sort by confidence (YOLO standard)
691
  detections.sort(key=lambda d: d['conf'], reverse=True)
692
  merged_detections = []
693
  is_merged = [False] * len(detections)
@@ -709,16 +1219,15 @@ def merge_overlapping_boxes(detections, iou_threshold):
709
  is_merged[j] = True
710
  merged_detections.append({
711
  'coords': (merged_x1, merged_y1, merged_x2, merged_y2),
712
- # 'y1' is retained for clarity, though 'coords' contains it
713
  'y1': merged_y1,
714
  'class': current_class,
715
  'conf': detections[i]['conf']
716
  })
717
 
718
- # --- FIX IMPLEMENTATION: READING ORDER SORT ---
719
- # Sort primarily by y1 (vertical position), secondarily by x1 (horizontal position).
720
- # This correctly handles two-column layouts like Q.10 options (A), (B), (C), (D)
721
- merged_detections.sort(key=lambda d: (d['coords'][1], d['coords'][0]))
722
 
723
  return merged_detections
724
 
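As an illustration of the reading-order sort removed above (hypothetical coordinates, not from the commit): the (y1, x1) key walks boxes row by row, left to right, which is what made two-column option layouts come out as (A), (B), (C), (D):

    # Illustrative only: four option boxes laid out in two columns.
    boxes = [
        {'coords': (300, 100, 380, 120)},  # option (B): right column, top row
        {'coords': ( 50, 100, 130, 120)},  # option (A): left column, top row
        {'coords': (300, 160, 380, 180)},  # option (D): right column, second row
        {'coords': ( 50, 160, 130, 180)},  # option (C): left column, second row
    ]
    boxes.sort(key=lambda d: (d['coords'][1], d['coords'][0]))
    # Sorted order: (A), (B), (C), (D) -- top-to-bottom, then left-to-right.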
@@ -786,6 +1295,7 @@ def get_latex_from_base64(base64_string: str) -> str:
786
  return f"[TR_OCR_ERROR: {e}]"
787
 
788
 
 
789
  def run_yolo_detection_and_count(
790
  image: np.ndarray, model: YOLO, page_num: int,
791
  current_eq_count: int, current_fig_count: int
@@ -800,7 +1310,7 @@ def run_yolo_detection_and_count(
800
 
801
  detected_items: List[Dict[str, Union[Image.Image, str, Tuple[float,...]]]] = []
802
  yolo_detections = []
803
-
804
  try:
805
  results = model.predict(image, conf=CONF_THRESHOLD, verbose=False)
806
  if results and results[0].boxes:
@@ -817,10 +1327,11 @@ def run_yolo_detection_and_count(
817
  logging.error(f"ERROR: YOLO inference failed on page {page_num}: {e}")
818
  return [], eq_counter, fig_counter
819
 
 
820
  merged_detections = merge_overlapping_boxes(yolo_detections, IOU_MERGE_THRESHOLD)
821
  final_detections = filter_nested_boxes(merged_detections, IOA_SUPPRESSION_THRESHOLD)
822
 
823
- # Note: final_detections is now sorted by (y1, x1) in reading order.
824
 
825
  for det in final_detections:
826
  bbox = det["coords"]
@@ -1070,9 +1581,9 @@ if __name__ == "__main__":
1070
  output_structured_latex,
1071
  output_gallery
1072
  ],
1073
- title="πŸ“Š YOLO Detection & Math OCR Pipeline (Reading Order Fix)",
1074
  description=(
1075
- "Upload a PDF. YOLO detects equations/figures, and OCR converts equations to LaTeX. Now includes a fix for two-column reading order."
1076
  ),
1077
  )
1078


@@ -585,6 +585,520 @@
585

586

587
 
588
+
589
+
590
+ # import base64
591
+ # from PIL import Image
592
+ # import re
593
+ # import fitz # PyMuPDF
594
+ # import numpy as np
595
+ # import cv2
596
+ # import torch
597
+ # import torch.serialization
598
+ # import os
599
+ # import time
600
+ # from typing import Optional, Tuple, List, Dict, Any, Union
601
+ # from ultralytics import YOLO
602
+ # import logging
603
+ # import gradio as gr
604
+ # import io
605
+ # import json
606
+
607
+ # # ============================================================================
608
+ # # --- Global Setup and Configuration ---
609
+ # # ============================================================================
610
+
611
+ # # Configure logging to write to a string buffer for display in the report
612
+ # log_stream = io.StringIO()
613
+ # logging.basicConfig(level=logging.WARNING, stream=log_stream, format='%(levelname)s:%(message)s')
614
+
615
+ # # Patch torch.load to prevent weights_only error with older models
616
+ # _original_torch_load = torch.load
617
+ # def patched_torch_load(*args, **kwargs):
618
+ # kwargs["weights_only"] = False
619
+ # return _original_torch_load(*args, **kwargs)
620
+ # torch.load = patched_torch_load
621
+
622
+ # WEIGHTS_PATH = 'best.pt'
623
+ # SCALE_FACTOR = 2.0
624
+
625
+ # # --- OCR Model Initialization ---
626
+ # from transformers import TrOCRProcessor
627
+ # from optimum.onnxruntime import ORTModelForVision2Seq
628
+
629
+ # MODEL_NAME = 'breezedeus/pix2text-mfr-1.5'
630
+ # try:
631
+ # processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
632
+ # ort_model = ORTModelForVision2Seq.from_pretrained(MODEL_NAME, use_cache=False)
633
+ # OCR_MODEL_LOADED = True
634
+ # except Exception as e:
635
+ # logging.warning(f"OCR model loading failed: {e}")
636
+ # processor = None
637
+ # ort_model = None
638
+ # OCR_MODEL_LOADED = False
639
+
640
+ # # Detection parameters
641
+ # CONF_THRESHOLD = 0.2
642
+ # TARGET_CLASSES = ['figure', 'equation']
643
+ # IOU_MERGE_THRESHOLD = 0.4
644
+ # IOA_SUPPRESSION_THRESHOLD = 0.7
645
+
646
+ # # ============================================================================
647
+ # # --- BOX COMBINATION LOGIC (FIXED) ---
648
+ # # ============================================================================
649
+
650
+ # def calculate_iou(box1, box2):
651
+ # x1_a, y1_a, x2_a, y2_a = box1
652
+ # x1_b, y1_b, x2_b, y2_b = box2
653
+ # x_left = max(x1_a, x1_b)
654
+ # y_top = max(y1_a, y1_b)
655
+ # x_right = min(x2_a, x2_b)
656
+ # y_bottom = min(y2_a, y2_b)
657
+ # intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)
658
+ # box_a_area = (x2_a - x1_a) * (y2_a - y1_a)
659
+ # box_b_area = (x2_b - x1_b) * (y2_b - y1_b)
660
+ # union_area = float(box_a_area + box_b_area - intersection_area)
661
+ # return intersection_area / union_area if union_area > 0 else 0
662
+
663
+
664
+ # def filter_nested_boxes(detections, ioa_threshold=0.80):
665
+ # if not detections: return []
666
+ # for d in detections:
667
+ # x1, y1, x2, y2 = d['coords']
668
+ # d['area'] = (x2 - x1) * (y2 - y1)
669
+ # detections.sort(key=lambda x: x['area'], reverse=True)
670
+ # keep_indices = []
671
+ # is_suppressed = [False] * len(detections)
672
+ # for i in range(len(detections)):
673
+ # if is_suppressed[i]: continue
674
+ # keep_indices.append(i)
675
+ # box_a = detections[i]['coords']
676
+ # for j in range(i + 1, len(detections)):
677
+ # if is_suppressed[j]: continue
678
+ # box_b = detections[j]['coords']
679
+ # x_left = max(box_a[0], box_b[0])
680
+ # y_top = max(box_a[1], box_b[1])
681
+ # x_right = min(box_a[2], box_b[2])
682
+ # y_bottom = min(box_a[3], box_b[3])
683
+ # intersection = max(0, x_right - x_left) * max(0, y_bottom - y_top)
684
+ # area_b = detections[j]['area']
685
+ # if area_b > 0 and intersection / area_b > ioa_threshold:
686
+ # is_suppressed[j] = True
687
+ # return [detections[i] for i in keep_indices]
688
+
689
+
690
+ # def merge_overlapping_boxes(detections, iou_threshold):
691
+ # if not detections: return []
692
+ # # 1. Sort by confidence (YOLO standard)
693
+ # detections.sort(key=lambda d: d['conf'], reverse=True)
694
+ # merged_detections = []
695
+ # is_merged = [False] * len(detections)
696
+
697
+ # for i in range(len(detections)):
698
+ # if is_merged[i]: continue
699
+ # current_box = detections[i]['coords']
700
+ # current_class = detections[i]['class']
701
+ # merged_x1, merged_y1, merged_x2, merged_y2 = current_box
702
+ # for j in range(i + 1, len(detections)):
703
+ # if is_merged[j] or detections[j]['class'] != current_class: continue
704
+ # other_box = detections[j]['coords']
705
+ # iou = calculate_iou(current_box, other_box)
706
+ # if iou > iou_threshold:
707
+ # merged_x1 = min(merged_x1, other_box[0])
708
+ # merged_y1 = min(merged_y1, other_box[1])
709
+ # merged_x2 = max(merged_x2, other_box[2])
710
+ # merged_y2 = max(merged_y2, other_box[3])
711
+ # is_merged[j] = True
712
+ # merged_detections.append({
713
+ # 'coords': (merged_x1, merged_y1, merged_x2, merged_y2),
714
+ # # 'y1' is retained for clarity, though 'coords' contains it
715
+ # 'y1': merged_y1,
716
+ # 'class': current_class,
717
+ # 'conf': detections[i]['conf']
718
+ # })
719
+
720
+ # # --- FIX IMPLEMENTATION: READING ORDER SORT ---
721
+ # # Sort primarily by y1 (vertical position), secondarily by x1 (horizontal position).
722
+ # # This correctly handles two-column layouts like Q.10 options (A), (B), (C), (D)
723
+ # merged_detections.sort(key=lambda d: (d['coords'][1], d['coords'][0]))
724
+
725
+ # return merged_detections
726
+
727
+ # # ============================================================================
728
+ # # --- UTILITY FUNCTIONS (Retained) ---
729
+ # # ============================================================================
730
+
731
+ # def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray:
732
+ # """Converts a PyMuPDF Pixmap to a NumPy array for OpenCV/YOLO."""
733
+ # img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(
734
+ # (pix.h, pix.w, pix.n)
735
+ # )
736
+ # if pix.n == 4:
737
+ # img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
738
+ # elif pix.n == 1:
739
+ # img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
740
+ # return img
741
+
742
+
743
+ # def crop_and_convert_to_pil(image: np.ndarray, bbox: Tuple[float, float, float, float]) -> Image.Image:
744
+ # """Crops the numpy array and returns a PIL Image object."""
745
+ # x1, y1, x2, y2 = map(int, bbox)
746
+ # h, w, _ = image.shape
747
+
748
+ # x1 = max(0, x1)
749
+ # y1 = max(0, y1)
750
+ # x2 = min(w, x2)
751
+ # y2 = min(h, y2)
752
+
753
+ # crop_np = image[y1:y2, x1:x2]
754
+ # crop_pil = Image.fromarray(cv2.cvtColor(crop_np, cv2.COLOR_BGR2RGB))
755
+
756
+ # return crop_pil
757
+
758
+
759
+ # def pil_to_base64(img: Image.Image) -> str:
760
+ # """Converts a PIL Image object to a Base64 encoded string (PNG format) for OCR input."""
761
+ # buffer = io.BytesIO()
762
+ # img.save(buffer, format="PNG")
763
+ # return base64.b64encode(buffer.getvalue()).decode("utf-8")
764
+
765
+
766
+ # def get_latex_from_base64(base64_string: str) -> str:
767
+ # """Performs the OCR conversion using the globally loaded model."""
768
+ # if not OCR_MODEL_LOADED:
769
+ # return "[MODEL_ERROR: Model not loaded]"
770
+
771
+ # try:
772
+ # image_data = base64.b64decode(base64_string)
773
+ # image = Image.open(io.BytesIO(image_data)).convert('RGB')
774
+
775
+ # pixel_values = processor(images=image, return_tensors="pt").pixel_values
776
+ # generated_ids = ort_model.generate(pixel_values)
777
+ # raw_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
778
+
779
+ # if not raw_text:
780
+ # return "[OCR_WARNING: No formula found]"
781
+
782
+ # latex = raw_text[0]
783
+ # latex = re.sub(r'[\r\n]+', '', latex)
784
+
785
+ # return latex
786
+
787
+ # except Exception as e:
788
+ # return f"[TR_OCR_ERROR: {e}]"
789
+
790
+
791
+ # def run_yolo_detection_and_count(
792
+ # image: np.ndarray, model: YOLO, page_num: int,
793
+ # current_eq_count: int, current_fig_count: int
794
+ # ) -> Tuple[List[Dict[str, Union[Image.Image, str, Tuple[float,...]]]], int, int]:
795
+ # """
796
+ # Performs YOLO detection and returns a list of detected item dictionaries
797
+ # and the updated total counters.
798
+ # """
799
+
800
+ # eq_counter = current_eq_count
801
+ # fig_counter = current_fig_count
802
+
803
+ # detected_items: List[Dict[str, Union[Image.Image, str, Tuple[float,...]]]] = []
804
+ # yolo_detections = []
805
+
806
+ # try:
807
+ # results = model.predict(image, conf=CONF_THRESHOLD, verbose=False)
808
+ # if results and results[0].boxes:
809
+ # for box in results[0].boxes.data.tolist():
810
+ # x1, y1, x2, y2, conf, cls_id = box
811
+ # cls_name = model.names[int(cls_id)]
812
+ # if cls_name in TARGET_CLASSES:
813
+ # yolo_detections.append({
814
+ # 'coords': (x1, y1, x2, y2),
815
+ # 'class': cls_name,
816
+ # 'conf': conf
817
+ # })
818
+ # except Exception as e:
819
+ # logging.error(f"ERROR: YOLO inference failed on page {page_num}: {e}")
820
+ # return [], eq_counter, fig_counter
821
+
822
+ # merged_detections = merge_overlapping_boxes(yolo_detections, IOU_MERGE_THRESHOLD)
823
+ # final_detections = filter_nested_boxes(merged_detections, IOA_SUPPRESSION_THRESHOLD)
824
+
825
+ # # Note: final_detections is now sorted by (y1, x1) in reading order.
826
+
827
+ # for det in final_detections:
828
+ # bbox = det["coords"]
829
+ # crop_pil = crop_and_convert_to_pil(image, bbox)
830
+
831
+ # item = {
832
+ # "type": det["class"],
833
+ # "coords": bbox,
834
+ # "pil_image": crop_pil,
835
+ # }
836
+
837
+ # if det["class"] == "equation":
838
+ # eq_counter += 1
839
+ # item["id"] = f"EQUATION{eq_counter}"
840
+ # item["latex"] = ""
841
+ # elif det["class"] == "figure":
842
+ # fig_counter += 1
843
+ # item["id"] = f"FIGURE{fig_counter}"
844
+ # item["latex"] = "[FIGURE - No LaTeX]"
845
+
846
+ # detected_items.append(item)
847
+
848
+ # return detected_items, eq_counter, fig_counter
849
+
850
+
851
+ # # ============================================================================
852
+ # # --- MAIN DOCUMENT PROCESSING FUNCTION (Retained Logic) ---
853
+ # # ============================================================================
854
+
855
+ # def run_single_pdf_preprocessing(
856
+ # pdf_path: str
857
+ # ) -> Tuple[int, int, int, str, float, Dict[str, Union[int, str]], List[Tuple[Image.Image, str]]]:
858
+ # """
859
+ # Runs the pipeline, performs OCR, and returns final results.
860
+ # """
861
+
862
+ # log_stream.truncate(0)
863
+ # log_stream.seek(0)
864
+
865
+ # start_time = time.time()
866
+
867
+ # all_extracted_items: List[Dict[str, Union[Image.Image, str]]] = []
868
+
869
+ # total_figure_count = 0
870
+ # total_equation_count = 0
871
+
872
+
873
+ # # 1. Validation and Model Loading (YOLO)
874
+ # t0 = time.time()
875
+ # if not os.path.exists(pdf_path):
876
+ # report = f"❌ FATAL ERROR: Input PDF not found at {pdf_path}."
877
+ # return 0, 0, 0, report, time.time() - start_time, {}, []
878
+
879
+ # try:
880
+ # model = YOLO(WEIGHTS_PATH)
881
+ # logging.warning(f"INFO: Loaded YOLO model from: {WEIGHTS_PATH}")
882
+ # except Exception as e:
883
+ # report = f"❌ ERROR loading YOLO model: {e}\n(Ensure 'best.pt' is available and valid.)"
884
+ # return 0, 0, 0, report, time.time() - start_time, {}, []
885
+ # t1 = time.time()
886
+ # logging.warning(f"INFO: Model Loading Time: {t1-t0:.4f}s")
887
+
888
+ # # 2. PDF Loading (fitz)
889
+ # t2 = time.time()
890
+ # try:
891
+ # doc = fitz.open(pdf_path)
892
+ # total_pages = doc.page_count
893
+ # logging.warning(f"INFO: Opened PDF with {doc.page_count} pages")
894
+ # except Exception as e:
895
+ # report = f"❌ ERROR loading PDF file: {e}"
896
+ # return 0, 0, 0, report, time.time() - start_time, {}, []
897
+ # t3 = time.time()
898
+ # logging.warning(f"INFO: PDF Initialization Time: {t3-t2:.4f}s")
899
+
900
+ # mat = fitz.Matrix(SCALE_FACTOR, SCALE_FACTOR)
901
+
902
+ # # 3. Page Processing, Detection, and OCR Loop
903
+ # t4 = time.time()
904
+ # for page_num_0_based in range(doc.page_count):
905
+ # page_start_time = time.time()
906
+ # fitz_page = doc.load_page(page_num_0_based)
907
+ # page_num = page_num_0_based + 1
908
+
909
+ # # Render page to image for YOLO
910
+ # try:
911
+ # pix_start = time.time()
912
+ # pix = fitz_page.get_pixmap(matrix=mat)
913
+ # original_img = pixmap_to_numpy(pix)
914
+ # pix_time = time.time() - pix_start
915
+ # except Exception as e:
916
+ # logging.error(f"ERROR: Error converting page {page_num} to image: {e}. Skipping.")
917
+ # continue
918
+
919
+ # # YOLO Detection
920
+ # detect_start = time.time()
921
+ # (
922
+ # page_extracted_items,
923
+ # total_equation_count,
924
+ # total_figure_count
925
+ # ) = run_yolo_detection_and_count(
926
+ # original_img,
927
+ # model,
928
+ # page_num,
929
+ # total_equation_count,
930
+ # total_figure_count
931
+ # )
932
+ # detect_time = time.time() - detect_start
933
+
934
+ # # --- OCR/LaTeX Conversion and Logging ---
935
+ # ocr_total_time = 0
936
+ # page_equations = 0
937
+
938
+ # for item in page_extracted_items:
939
+ # if item["type"] == "equation":
940
+ # page_equations += 1
941
+ # ocr_start = time.time()
942
+
943
+ # b64_string = pil_to_base64(item["pil_image"])
944
+ # item["latex"] = get_latex_from_base64(b64_string)
945
+
946
+ # ocr_time = time.time() - ocr_start
947
+ # ocr_total_time += ocr_time
948
+
949
+ # logging.warning(f"LATEX: Page {page_num}, ID {item['id']} -> Time: {ocr_time:.4f}s, Formula: {item['latex'][:50]}...")
950
+
951
+ # all_extracted_items.extend(page_extracted_items)
952
+
953
+ # page_figures = sum(1 for item in page_extracted_items if item["type"] == "figure")
954
+
955
+ # page_total_time = time.time() - page_start_time
956
+ # logging.warning(f"SUMMARY: Page {page_num}: EQs={page_equations}, Figs={page_figures} | Page Time: {page_total_time:.4f}s (Detect={detect_time:.4f}s, OCR Total={ocr_total_time:.4f}s)")
957
+
958
+ # doc.close()
959
+ # t5 = time.time()
960
+ # detection_loop_time = t5 - t4
961
+ # logging.warning(f"INFO: Total Detection and OCR Loop Time ({total_pages} pages): {detection_loop_time:.4f}s")
962
+
963
+ # # 4. Final Report Generation and Gallery Formatting
964
+
965
+ # # Create the structured JSON output as requested by the user
966
+ # structured_latex_output = {
967
+ # "Total Pages": total_pages,
968
+ # "Total Equations": total_equation_count,
969
+ # }
970
+ # for item in all_extracted_items:
971
+ # if item["type"] == "equation":
972
+ # # Map EQUATION ID to LaTeX code
973
+ # structured_latex_output[item["id"]] = item["latex"]
974
+
975
+
976
+ # # Format the extracted items for the Gradio Gallery
977
+ # gallery_items: List[Tuple[Image.Image, str]] = []
978
+
979
+ # for item in all_extracted_items:
980
+ # image_label = item["id"]
981
+ # if item["type"] == "equation":
982
+ # image_label = f'{item["id"]}: {item["latex"]}'
983
+
984
+ # gallery_items.append((item["pil_image"], image_label))
985
+
986
+
987
+ # total_execution_time = t5 - start_time
988
+
989
+ # full_log = log_stream.getvalue()
990
+
991
+ # report = (
992
+ # f"βœ… **YOLO Counting & OCR Complete!**\n\n"
993
+ # f"**1) Total Pages Detected in PDF:** **{total_pages}**\n"
994
+ # f"**2) Total Equations Detected:** **{total_equation_count}**\n"
995
+ # f"**3) Total Figures Detected:** **{total_figure_count}**\n"
996
+ # f"---\n"
997
+ # f"**4) Total Execution Time:** **{total_execution_time:.4f}s**\n"
998
+ # f"### Full Processing Log\n"
999
+ # f"```text\n"
1000
+ # f"{full_log}"
1001
+ # f"\n```"
1002
+ # )
1003
+
1004
+ # # Return the new structured_latex_output instead of the page counts
1005
+ # return total_pages, total_equation_count, total_figure_count, report, total_execution_time, structured_latex_output, gallery_items
1006
+
1007
+
1008
+ # # ============================================================================
1009
+ # # --- GRADIO INTERFACE FUNCTION & DEFINITION (Retained) ---
1010
+ # # ============================================================================
1011
+
1012
+ # def gradio_process_pdf(pdf_file) -> Tuple[str, str, str, str, Dict[str, Union[int, str]], List[Tuple[Image.Image, str]]]:
1013
+ # """Gradio wrapper function to handle file upload and return results."""
1014
+ # if pdf_file is None:
1015
+ # return "N/A", "N/A", "N/A", "Please upload a PDF file.", {}, []
1016
+
1017
+ # pdf_path = pdf_file.name
1018
+
1019
+ # try:
1020
+ # (
1021
+ # num_pages,
1022
+ # num_equations,
1023
+ # num_figures,
1024
+ # report,
1025
+ # total_time,
1026
+ # structured_latex_output,
1027
+ # gallery_items
1028
+ # ) = run_single_pdf_preprocessing(pdf_path)
1029
+
1030
+
1031
+ # return str(num_pages), str(num_equations), str(num_figures), report, structured_latex_output, gallery_items
1032
+
1033
+
1034
+ # except Exception as e:
1035
+ # error_msg = f"An unexpected error occurred: {e}"
1036
+ # logging.error(f"FATAL: {error_msg}", exc_info=True)
1037
+ # full_log = log_stream.getvalue()
1038
+ # error_report = f"❌ CRITICAL ERROR:\n{error_msg}\n\n### Log up to Failure\n```text\n{full_log}\n```"
1039
+ # return "Error", "Error", "Error", error_report, {}, []
1040
+
1041
+
1042
+ # if __name__ == "__main__":
1043
+
1044
+ # if not os.path.exists(WEIGHTS_PATH):
1045
+ # logging.error(f"❌ FATAL ERROR: YOLO weight file '{WEIGHTS_PATH}' not found. Cannot run live inference.")
1046
+
1047
+ # input_file = gr.File(label="Upload PDF Document", type="filepath", file_types=[".pdf"])
1048
+
1049
+ # output_pages = gr.Textbox(label="Total Pages in PDF", interactive=False)
1050
+ # output_equations = gr.Textbox(label="Total Equations Detected", interactive=False)
1051
+ # output_figures = gr.Textbox(label="Total Figures Detected", interactive=False)
1052
+ # output_report = gr.Markdown(label="Processing Summary and Full Log")
1053
+
1054
+ # output_structured_latex = gr.JSON(label="Structured LaTeX Output (EQUATIONx : <latex code>)")
1055
+
1056
+ # output_gallery = gr.Gallery(
1057
+ # label="Detected Items (with Extracted LaTeX)",
1058
+ # columns=3,
1059
+ # height="auto",
1060
+ # object_fit="contain",
1061
+ # allow_preview=False
1062
+ # )
1063
+
1064
+ # interface = gr.Interface(
1065
+ # fn=gradio_process_pdf,
1066
+ # inputs=input_file,
1067
+ # outputs=[
1068
+ # output_pages,
1069
+ # output_equations,
1070
+ # output_figures,
1071
+ # output_report,
1072
+ # output_structured_latex,
1073
+ # output_gallery
1074
+ # ],
1075
+ # title="📊 YOLO Detection & Math OCR Pipeline (Reading Order Fix)",
1076
+ # description=(
1077
+ # "Upload a PDF. YOLO detects equations/figures, and OCR converts equations to LaTeX. Now includes a fix for two-column reading order."
1078
+ # ),
1079
+ # )
1080
+
1081
+ # print("\nStarting Gradio application...")
1082
+ # interface.launch(inbrowser=True)
1083
+
1084
+
1085
+
1086
+
1087
+
1088
+
1089
+
1090
+
1091
+
1092
+
1093
+
1094
+
1095
+
1096
+
1097
+
1098
+
1099
+
1100
+
1101
+
1102
  import base64
1103
  from PIL import Image
1104
  import re
 
1117
  import json
1118
 
1119
  # ============================================================================
1120
+ # --- Global Setup and Configuration (Retained) ---
1121
  # ============================================================================
1122
 
 
1123
  log_stream = io.StringIO()
1124
  logging.basicConfig(level=logging.WARNING, stream=log_stream, format='%(levelname)s:%(message)s')
1125
 
 
1126
  _original_torch_load = torch.load
1127
  def patched_torch_load(*args, **kwargs):
1128
  kwargs["weights_only"] = False
 
1132
  WEIGHTS_PATH = 'best.pt'
1133
  SCALE_FACTOR = 2.0
1134
 
 
1135
  from transformers import TrOCRProcessor
1136
  from optimum.onnxruntime import ORTModelForVision2Seq
1137
 
 
1146
  ort_model = None
1147
  OCR_MODEL_LOADED = False
1148
 
 
1149
  CONF_THRESHOLD = 0.2
1150
  TARGET_CLASSES = ['figure', 'equation']
1151
  IOU_MERGE_THRESHOLD = 0.4
1152
  IOA_SUPPRESSION_THRESHOLD = 0.7
1153
 
1154
  # ============================================================================
1155
+ # --- BOX COMBINATION LOGIC (PURE VERTICAL FIX) ---
1156
  # ============================================================================
1157
 
1158
  def calculate_iou(box1, box2):
 
1195
  return [detections[i] for i in keep_indices]
1196
 
1197
 
1198
+ # --- UPDATED: page_width argument removed ---
1199
  def merge_overlapping_boxes(detections, iou_threshold):
1200
  if not detections: return []
 
1201
  detections.sort(key=lambda d: d['conf'], reverse=True)
1202
  merged_detections = []
1203
  is_merged = [False] * len(detections)
 
1219
  is_merged[j] = True
1220
  merged_detections.append({
1221
  'coords': (merged_x1, merged_y1, merged_x2, merged_y2),
 
1222
  'y1': merged_y1,
1223
  'class': current_class,
1224
  'conf': detections[i]['conf']
1225
  })
1226
 
1227
+ # --- PURE VERTICAL FIX IMPLEMENTATION ---
1228
+ # Sort ONLY by the top y-coordinate (coords[1]).
1229
+ # This ignores horizontal position and any complex layout.
1230
+ merged_detections.sort(key=lambda d: d['coords'][1])
1231
 
1232
  return merged_detections
1233
 
 
1295
  return f"[TR_OCR_ERROR: {e}]"
1296
 
1297
 
1298
+ # --- UPDATED: page width argument removed from signature and call ---
1299
  def run_yolo_detection_and_count(
1300
  image: np.ndarray, model: YOLO, page_num: int,
1301
  current_eq_count: int, current_fig_count: int
 
1310
 
1311
  detected_items: List[Dict[str, Union[Image.Image, str, Tuple[float,...]]]] = []
1312
  yolo_detections = []
1313
+
1314
  try:
1315
  results = model.predict(image, conf=CONF_THRESHOLD, verbose=False)
1316
  if results and results[0].boxes:
 
1327
  logging.error(f"ERROR: YOLO inference failed on page {page_num}: {e}")
1328
  return [], eq_counter, fig_counter
1329
 
1330
+ # Call merge_overlapping_boxes without page_width
1331
  merged_detections = merge_overlapping_boxes(yolo_detections, IOU_MERGE_THRESHOLD)
1332
  final_detections = filter_nested_boxes(merged_detections, IOA_SUPPRESSION_THRESHOLD)
1333
 
1334
+ # Note: final_detections is now sorted purely by y1
1335
 
1336
  for det in final_detections:
1337
  bbox = det["coords"]
 
1581
  output_structured_latex,
1582
  output_gallery
1583
  ],
1584
+ title="📊 YOLO Detection & Math OCR Pipeline (Pure Vertical Sort)",
1585
  description=(
1586
+ "Upload a PDF. YOLO detects equations/figures, and OCR converts equations to LaTeX. The output is now strictly sorted by the top bounding box Y-coordinate."
1587
  ),
1588
  )
1589