deepkansara-123 committed on
Commit
0f3c560
·
verified ·
1 Parent(s): 994b14b

Update working_yolo_pipeline.py

Browse files
Files changed (1) hide show
  1. working_yolo_pipeline.py +126 -80
working_yolo_pipeline.py CHANGED
@@ -2761,100 +2761,146 @@ def embed_images_as_base64_in_memory(structured_data: List[Dict[str, Any]], figu
2761
 
2762
  # def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str) -> Optional[
2763
  # List[Dict[str, Any]]]:
2764
def run_document_pipeline(
    input_pdf_path: str,
    layoutlmv3_model_path: str,
    structured_intermediate_output_path: Optional[str] = None,
) -> Optional[List[Dict[str, Any]]]:
    """Run the full document-analysis pipeline on a single PDF.

    Phases: YOLO-first preprocessing -> LayoutLMv3 inference -> BIO decoding ->
    option/context post-correction -> figure embedding -> hierarchical
    subject/concept tagging.

    Args:
        input_pdf_path: Path to the PDF to analyze.
        layoutlmv3_model_path: Path to the LayoutLMv3 token-classification model.
        structured_intermediate_output_path: Optional path where the structured
            intermediate JSON is written. If None, a per-run temp file is used.
            (BUG FIX: the original unconditionally overwrote this parameter, so
            a caller-supplied path was silently ignored.)

    Returns:
        The final structured result list, or None on any failure.
    """
    if not os.path.exists(input_pdf_path):
        return None

    print("\n" + "#" * 80)
    print("### STARTING OPTIMIZED FULL DOCUMENT ANALYSIS PIPELINE ###")
    print("#" * 80)

    pdf_name = os.path.splitext(os.path.basename(input_pdf_path))[0]
    temp_pipeline_dir = os.path.join(tempfile.gettempdir(), f"pipeline_run_{pdf_name}_{os.getpid()}")
    os.makedirs(temp_pipeline_dir, exist_ok=True)

    preprocessed_json_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_preprocessed.json")
    raw_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_raw_predictions.json")
    # Honor a caller-supplied intermediate path; only fall back to a temp file.
    if structured_intermediate_output_path is None:
        structured_intermediate_output_path = os.path.join(
            temp_pipeline_dir, f"{pdf_name}_structured_intermediate.json"
        )

    final_result = None
    try:
        # Phase 1: Preprocessing with YOLO first + masking.
        preprocessed_json_path_out = run_single_pdf_preprocessing(input_pdf_path, preprocessed_json_path)
        if not preprocessed_json_path_out:
            return None

        # Phase 2: LayoutLMv3 inference over the preprocessed pages.
        page_raw_predictions_list = run_inference_and_get_raw_words(
            input_pdf_path, layoutlmv3_model_path, preprocessed_json_path_out
        )
        if not page_raw_predictions_list:
            return None

        # DEBUG STEP: persist raw predictions; the decoder reads this file.
        with open(raw_output_path, 'w', encoding='utf-8') as f:
            json.dump(page_raw_predictions_list, f, indent=4)

        # Phase 3: Decode BIO tags into structured JSON, then post-correct.
        structured_data_list = convert_bio_to_structured_json_relaxed(
            raw_output_path, structured_intermediate_output_path
        )
        if not structured_data_list:
            return None
        structured_data_list = correct_misaligned_options(structured_data_list)
        structured_data_list = process_context_linking(structured_data_list)

        # Phase 4: Embed figures as base64 / convert equations to LaTeX.
        final_result = embed_images_as_base64_in_memory(structured_data_list, FIGURE_EXTRACTION_DIR)

        # ================================================================
        # FINAL STEP: HIERARCHICAL CLASSIFICATION TAGGING
        # ================================================================
        print("\n" + "=" * 80)
        print("--- FINAL STEP: HIERARCHICAL SUBJECT/CONCEPT TAGGING ---")
        print("=" * 80)

        # 1. Initialize and load the classifier.
        classifier = HierarchicalClassifier()
        if classifier.load_models():
            # 2. Run classification on the *final* result; the function
            #    modifies the list in place and returns it.
            final_result = post_process_json_with_inference(final_result, classifier)
            print("βœ… Classification complete. Tags added to final output.")
        else:
            print("❌ Classification model loading failed. Outputting un-tagged data.")

    except Exception as e:
        print(f"❌ FATAL ERROR: {e}")
        import traceback
        traceback.print_exc()
        return None

    finally:
        # Best-effort cleanup of the per-run temp directory; a caller-supplied
        # intermediate path outside the temp dir is left untouched.
        try:
            for f in glob.glob(os.path.join(temp_pipeline_dir, '*')):
                os.remove(f)
            os.rmdir(temp_pipeline_dir)
        except Exception:
            pass

    print("\n" + "#" * 80)
    print("### OPTIMIZED PIPELINE EXECUTION COMPLETE ###")
    print("#" * 80)
    return final_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2858
 
2859
 
2860
 
 
2761
 
2762
  # def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str) -> Optional[
2763
  # List[Dict[str, Any]]]:
2764
+ DEFAULT_LAYOUTLMV3_MODEL_PATH = "./models/layoutlmv3_model"
2765
+ WEIGHTS_PATH = "./weights/yolo_weights.pt"
2766
 
2767
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
2768
 
2769
+ torch.set_grad_enabled(False)
 
 
2770
 
2771
+ # ===============================
2772
+ # GLOBAL CACHED MODELS
2773
+ # ===============================
2774
+ _layoutlm_model = None
2775
+ _layoutlm_processor = None
2776
+ _yolo_model = None
2777
 
 
 
 
 
 
2778
 
2779
def load_models(layoutlm_path):
    """Load and cache the LayoutLMv3 and YOLO models (singleton pattern).

    Each model is constructed on the first call only; subsequent calls return
    the cached instances. On CUDA the LayoutLMv3 model is additionally cast to
    fp16 and wrapped with torch.compile.

    Args:
        layoutlm_path: Directory containing the LayoutLMv3 checkpoint.

    Returns:
        Tuple of (layoutlm_model, layoutlm_processor, yolo_model).
    """
    global _layoutlm_model, _layoutlm_processor, _yolo_model

    if _layoutlm_model is None:
        print("πŸ”Ή Loading LayoutLMv3...")
        _layoutlm_processor = AutoProcessor.from_pretrained(layoutlm_path)

        lm = LayoutLMv3ForTokenClassification.from_pretrained(layoutlm_path)
        lm = lm.to(DEVICE).eval()

        if DEVICE.type == "cuda":
            # Half precision + compilation for faster GPU inference.
            lm = torch.compile(lm.half())

        _layoutlm_model = lm

    if _yolo_model is None:
        print("πŸ”Ή Loading YOLO...")
        _yolo_model = YOLO(WEIGHTS_PATH)
        _yolo_model.model.eval()

    return _layoutlm_model, _layoutlm_processor, _yolo_model
2806
 
2807
 
2808
+ # ===============================
2809
+ # PDF UTILITIES
2810
+ # ===============================
2811
def load_pdf_images(pdf_path):
    """Render every page of a PDF to a PIL RGB image at 200 DPI.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        List of PIL.Image.Image objects, one per page, in page order.
    """
    images = []
    doc = fitz.open(pdf_path)
    try:
        for page in doc:
            pix = page.get_pixmap(dpi=200)
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            images.append(img)
    finally:
        # BUG FIX: the original never closed the document, leaking the
        # underlying file handle for every PDF processed.
        doc.close()
    return images
 
 
 
 
2821
 
2822
+
2823
# ===============================
# MAIN PIPELINE
# ===============================
def run_document_pipeline(pdf_path, layoutlm_path):
    """Run YOLO detection + LayoutLMv3 token classification over a PDF.

    For each page: resize to 1024x1024, detect regions with YOLO, feed the
    detected boxes (with placeholder words) to LayoutLMv3, and collect the
    argmax label predictions.

    Args:
        pdf_path: Path to the input PDF.
        layoutlm_path: Directory containing the LayoutLMv3 checkpoint.

    Returns:
        Dict with the PDF name, the number of pages that produced detections,
        and the per-page prediction records.
    """
    model, processor, yolo = load_models(layoutlm_path)

    images = load_pdf_images(pdf_path)

    results = []

    for page_idx, image in enumerate(images):
        # -------------------------------
        # YOLO DETECTION
        # -------------------------------
        image_resized = image.resize((1024, 1024))

        yolo_result = yolo.predict(
            image_resized,
            verbose=False,
            conf=0.25,
        )[0]

        boxes = []
        words = []

        for box in yolo_result.boxes.xyxy.cpu().numpy():
            x1, y1, x2, y2 = box
            boxes.append([
                int(x1), int(y1),
                int(x2), int(y2)
            ])
            # NOTE(review): one placeholder token per box — no OCR text here.
            words.append("text")

        if not boxes:
            continue

        # Normalize boxes into LayoutLM's 0-1000 coordinate space.
        # BUG FIX: clamp to [0, 1000] — YOLO boxes can slightly exceed the
        # image bounds, and out-of-range coordinates crash the model's
        # 2D position embedding lookup.
        w, h = image_resized.size
        norm_boxes = [
            [
                min(1000, max(0, int(1000 * b[0] / w))),
                min(1000, max(0, int(1000 * b[1] / h))),
                min(1000, max(0, int(1000 * b[2] / w))),
                min(1000, max(0, int(1000 * b[3] / h))),
            ]
            for b in boxes
        ]

        # -------------------------------
        # LAYOUTLM INFERENCE
        # -------------------------------
        encoding = processor(
            image_resized,
            words,
            boxes=norm_boxes,
            return_tensors="pt",
            truncation=True,
            padding=True,
        )

        # BUG FIX: match the model's parameter dtype. load_models casts the
        # model to fp16 on CUDA, so fp32 floating inputs (pixel_values) would
        # raise a dtype mismatch; integer tensors (input_ids, bbox,
        # attention_mask) must keep their dtype.
        model_dtype = next(model.parameters()).dtype
        encoding = {
            k: (v.to(DEVICE, dtype=model_dtype)
                if torch.is_floating_point(v) else v.to(DEVICE))
            for k, v in encoding.items()
        }

        with torch.no_grad():
            outputs = model(**encoding)

        predictions = outputs.logits.argmax(-1).cpu().tolist()

        results.append({
            "page": page_idx + 1,
            "num_boxes": len(boxes),
            "predictions": predictions,
        })

        if DEVICE.type == "cuda":
            torch.cuda.empty_cache()

    return {
        "pdf": Path(pdf_path).name,
        "pages_processed": len(results),
        "results": results,
    }
2904
 
2905
 
2906