heerjtdev commited on
Commit
6a25d35
Β·
verified Β·
1 Parent(s): b652b08

Update working_yolo_pipeline.py

Browse files
Files changed (1) hide show
  1. working_yolo_pipeline.py +173 -171
working_yolo_pipeline.py CHANGED
@@ -178,77 +178,77 @@ from sklearn.metrics.pairwise import cosine_similarity
178
 
179
 
180
 
181
- # import logging
182
- # from transformers import TrOCRProcessor
183
- # # NOTE: Using optimum.onnxruntime for faster inference, as suggested by your sample script.
184
- # # If you run into issues, you may need to fall back to the standard
185
- # # 'transformers.VisionEncoderDecoderModel' if ORTModelForVision2Seq is not found/working.
186
- # from optimum.onnxruntime import ORTModelForVision2Seq
187
-
188
-
189
-
190
  import logging
191
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
192
- # NOTE: We are replacing the ORTModelForVision2Seq import due to the ModuleNotFoundError
193
- # from optimum.onnxruntime import ORTModelForVision2Seq <-- REMOVE THIS
194
-
195
-
196
- # # ============================================================================
197
- # # --- TR-OCR/ORT MODEL INITIALIZATION ---
198
- # # ============================================================================
199
- # # Set up logging to WARNING level to suppress excessive output from model libraries
200
- # logging.basicConfig(level=logging.WARNING)
201
-
202
- # processor = None
203
- # ort_model = None
204
-
205
- # try:
206
- # MODEL_NAME = 'breezedeus/pix2text-mfr-1.5'
207
- # processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
208
-
209
- # # Initialize the model for ONNX Runtime
210
- # # NOTE: Set use_cache=False to avoid caching warnings/issues if reloading
211
- # ort_model = ORTModelForVision2Seq.from_pretrained(MODEL_NAME, use_cache=False)
212
-
213
- # print("βœ… ORTModelForVision2Seq and TrOCRProcessor initialized successfully for equation conversion.")
214
- # except Exception as e:
215
- # print(f"❌ Error initializing TrOCR/ORT model. Equations will not be converted: {e}")
216
- # processor = None
217
- # ort_model = None
218
-
219
-
220
 
221
 
222
 
 
 
 
 
223
 
224
 
225
  # ============================================================================
226
- # --- TR-OCR/PYTORCH MODEL INITIALIZATION ---
227
  # ============================================================================
 
228
  logging.basicConfig(level=logging.WARNING)
229
 
230
  processor = None
231
- pt_model = None # Renaming the variable from 'ort_model' to 'pt_model' for clarity
232
 
233
  try:
234
  MODEL_NAME = 'breezedeus/pix2text-mfr-1.5'
235
  processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
236
 
237
- # Initialize the standard PyTorch model instead of the ORT model
238
- pt_model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
 
239
 
240
- # CRITICAL: Since you want CPU-ONLY, explicitly ensure the model is on CPU
241
- if torch.cuda.is_available():
242
- # Although you requested CPU-only, check if CUDA is available
243
- # and ensure you take the necessary steps to force CPU or use the correct runtime environment.
244
- # For simplicity, if torch is installed for CPU, it will default to CPU.
245
- pass
246
-
247
- print("βœ… VisionEncoderDecoderModel (PyTorch) and TrOCRProcessor initialized successfully for equation conversion.")
248
  except Exception as e:
249
- print(f"❌ Error initializing TrOCR/PyTorch model. Equations will not be converted: {e}")
250
  processor = None
251
- pt_model = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
 
254
 
@@ -402,62 +402,13 @@ except Exception as e:
402
 
403
 
404
 
405
- # def get_latex_from_base64(base64_string: str) -> str:
406
- # """
407
- # Decodes a Base64 image string and uses the pre-initialized TrOCR/ORT model
408
- # to recognize the formula. It cleans the output by removing spaces and
409
- # crucially, replacing double backslashes with single backslashes for correct LaTeX.
410
- # """
411
- # if ort_model is None or processor is None:
412
- # return "[MODEL_ERROR: Model not initialized]"
413
-
414
- # try:
415
- # # 1. Decode Base64 to Image
416
- # image_data = base64.b64decode(base64_string)
417
- # # We must ensure the image is RGB format for the model input
418
- # image = Image.open(io.BytesIO(image_data)).convert('RGB')
419
-
420
- # # 2. Preprocess the image
421
- # pixel_values = processor(images=image, return_tensors="pt").pixel_values
422
-
423
- # # 3. Text Generation (OCR)
424
- # generated_ids = ort_model.generate(pixel_values)
425
- # raw_generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
426
-
427
- # if not raw_generated_text:
428
- # return "[OCR_WARNING: No formula found]"
429
-
430
- # latex_string = raw_generated_text[0]
431
-
432
- # # --- 4. Post-processing and Cleanup ---
433
-
434
- # # A. Remove all spaces/line breaks
435
- # cleaned_latex = re.sub(r'\s+', '', latex_string)
436
-
437
- # # B. CRITICAL FIX: Replace double backslashes with single backslashes.
438
- # # This addresses the over-escaping issue.
439
- # final_output = cleaned_latex.replace('\\\\', '\\')
440
-
441
- # # Return the clean LaTeX string (e.g., $$a=\frac{F}{2m}$$)
442
- # return final_output
443
-
444
- # except Exception as e:
445
- # # Catch any unexpected errors
446
- # print(f" ❌ TR-OCR Recognition failed: {e}")
447
- # return f"[TR_OCR_ERROR: Recognition failed: {e}]"
448
-
449
-
450
-
451
-
452
-
453
  def get_latex_from_base64(base64_string: str) -> str:
454
  """
455
- Decodes a Base64 image string and uses the pre-initialized TrOCR/PyTorch model
456
- to recognize the formula. It cleans the output by removing spaces and
457
  crucially, replacing double backslashes with single backslashes for correct LaTeX.
458
  """
459
- # Check the new model variable
460
- if pt_model is None or processor is None:
461
  return "[MODEL_ERROR: Model not initialized]"
462
 
463
  try:
@@ -470,8 +421,7 @@ def get_latex_from_base64(base64_string: str) -> str:
470
  pixel_values = processor(images=image, return_tensors="pt").pixel_values
471
 
472
  # 3. Text Generation (OCR)
473
- # Use the PyTorch model's generate method
474
- generated_ids = pt_model.generate(pixel_values)
475
  raw_generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
476
 
477
  if not raw_generated_text:
@@ -485,17 +435,69 @@ def get_latex_from_base64(base64_string: str) -> str:
485
  cleaned_latex = re.sub(r'\s+', '', latex_string)
486
 
487
  # B. CRITICAL FIX: Replace double backslashes with single backslashes.
488
- final_output = cleaned_latex.replace('\\\\', '\\')
 
 
 
 
 
489
 
490
- return final_output
491
 
492
  except Exception as e:
 
493
  print(f" ❌ TR-OCR Recognition failed: {e}")
494
  return f"[TR_OCR_ERROR: Recognition failed: {e}]"
495
 
496
 
497
 
498
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
 
500
 
501
 
@@ -2351,55 +2353,7 @@ def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, label
2351
 
2352
 
2353
 
2354
- # if __name__ == "__main__":
2355
- # parser = argparse.ArgumentParser(description="Complete Pipeline")
2356
- # parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
2357
- # parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path")
2358
- # parser.add_argument("--ls_output_path", type=str, default=None, help="Label Studio Output Path")
2359
- # # --- ADDED ARGUMENT FOR DEBUGGING ---
2360
- # parser.add_argument("--raw_preds_path", type=str, default='BIO_debug.json',
2361
- # help="Debug path for raw BIO tag predictions (JSON).")
2362
- # # ------------------------------------
2363
- # args = parser.parse_args()
2364
-
2365
- # pdf_name = os.path.splitext(os.path.basename(args.input_pdf))[0]
2366
- # final_output_path = os.path.abspath(f"{pdf_name}_final_output_embedded.json")
2367
- # ls_output_path = os.path.abspath(
2368
- # args.ls_output_path if args.ls_output_path else f"{pdf_name}_label_studio_tasks.json")
2369
- # # --- CALCULATE RAW PREDICTIONS OUTPUT PATH ---
2370
- # # raw_predictions_output_path = os.path.abspath(
2371
- # # args.raw_preds_path if args.raw_preds_path else f"{pdf_name}_raw_predictions_debug.json")
2372
- # # ---------------------------------------------
2373
-
2374
- # # --- UPDATED FUNCTION CALL ---
2375
- # final_json_data = run_document_pipeline(
2376
- # args.input_pdf,
2377
- # args.layoutlmv3_model_path,
2378
- # ls_output_path,
2379
- # # raw_predictions_output_path # Pass the new argument
2380
- # )
2381
- # # -----------------------------
2382
-
2383
- # if final_json_data:
2384
- # with open(final_output_path, 'w', encoding='utf-8') as f:
2385
- # json.dump(final_json_data, f, indent=2, ensure_ascii=False)
2386
- # print(f"\nβœ… Final Data Saved: {final_output_path}")
2387
- # else:
2388
- # print("\n❌ Pipeline Failed.")
2389
- # sys.exit(1)
2390
-
2391
-
2392
-
2393
-
2394
-
2395
-
2396
  if __name__ == "__main__":
2397
- # Ensure 'json', 'argparse', 'os', and 'sys' are imported at the top of your script
2398
- # import json
2399
- # import argparse
2400
- # import os
2401
- # import sys
2402
-
2403
  parser = argparse.ArgumentParser(description="Complete Pipeline")
2404
  parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
2405
  parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path")
@@ -2421,30 +2375,78 @@ if __name__ == "__main__":
2421
 
2422
  # --- UPDATED FUNCTION CALL ---
2423
  final_json_data = run_document_pipeline(
2424
- args.input_pdf,
2425
- args.layoutlmv3_model_path,
2426
- ls_output_path,
2427
  # raw_predictions_output_path # Pass the new argument
2428
  )
2429
  # -----------------------------
2430
 
2431
- # πŸ›‘ CRITICAL FIX: CUSTOM JSON SAVING TO REMOVE DOUBLE BACKSLASHES πŸ›‘
2432
  if final_json_data:
2433
- # 1. Dump the Python object to a standard JSON string.
2434
- # This uses json.dumps which correctly escapes single backslashes ('\') to ('\\').
2435
- json_str = json.dumps(final_json_data, indent=2, ensure_ascii=False)
2436
-
2437
- # 2. **UNDO ESCAPING:** Replace every instance of the JSON-escaped backslash ('\\')
2438
- # with a single literal backslash ('\'). This forces the file content to be correct for LaTeX.
2439
- final_output_content = json_str.replace('\\\\', '\\')
2440
-
2441
- # 3. Write the corrected string content to the file.
2442
  with open(final_output_path, 'w', encoding='utf-8') as f:
2443
- f.write(final_output_content)
2444
-
2445
  print(f"\nβœ… Final Data Saved: {final_output_path}")
2446
  else:
2447
  print("\n❌ Pipeline Failed.")
2448
  sys.exit(1)
2449
 
2450
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
 
180
 
 
 
 
 
 
 
 
 
 
181
  import logging
182
+ from transformers import TrOCRProcessor
183
+ # NOTE: Using optimum.onnxruntime for faster inference, as suggested by your sample script.
184
+ # If you run into issues, you may need to fall back to the standard
185
+ # 'transformers.VisionEncoderDecoderModel' if ORTModelForVision2Seq is not found/working.
186
+ from optimum.onnxruntime import ORTModelForVision2Seq
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
 
189
 
190
+ # import logging
191
+ # from transformers import TrOCRProcessor, VisionEncoderDecoderModel
192
+ # # NOTE: We are replacing the ORTModelForVision2Seq import due to the ModuleNotFoundError
193
+ # # from optimum.onnxruntime import ORTModelForVision2Seq <-- REMOVE THIS
194
 
195
 
196
  # ============================================================================
197
+ # --- TR-OCR/ORT MODEL INITIALIZATION ---
198
  # ============================================================================
199
+ # Set up logging to WARNING level to suppress excessive output from model libraries
200
  logging.basicConfig(level=logging.WARNING)
201
 
202
  processor = None
203
+ ort_model = None
204
 
205
  try:
206
  MODEL_NAME = 'breezedeus/pix2text-mfr-1.5'
207
  processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
208
 
209
+ # Initialize the model for ONNX Runtime
210
+ # NOTE: Set use_cache=False to avoid caching warnings/issues if reloading
211
+ ort_model = ORTModelForVision2Seq.from_pretrained(MODEL_NAME, use_cache=False)
212
 
213
+ print("βœ… ORTModelForVision2Seq and TrOCRProcessor initialized successfully for equation conversion.")
 
 
 
 
 
 
 
214
  except Exception as e:
215
+ print(f"❌ Error initializing TrOCR/ORT model. Equations will not be converted: {e}")
216
  processor = None
217
+ ort_model = None
218
+
219
+
220
+
221
+
222
+
223
+
224
+ #
225
+ # # ============================================================================
226
+ # # --- TR-OCR/PYTORCH MODEL INITIALIZATION ---
227
+ # # ============================================================================
228
+ # logging.basicConfig(level=logging.WARNING)
229
+ #
230
+ # processor = None
231
+ # pt_model = None # Renaming the variable from 'ort_model' to 'pt_model' for clarity
232
+ #
233
+ # try:
234
+ # MODEL_NAME = 'breezedeus/pix2text-mfr-1.5'
235
+ # processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
236
+ #
237
+ # # Initialize the standard PyTorch model instead of the ORT model
238
+ # pt_model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
239
+ #
240
+ # # CRITICAL: Since you want CPU-ONLY, explicitly ensure the model is on CPU
241
+ # if torch.cuda.is_available():
242
+ # # Although you requested CPU-only, check if CUDA is available
243
+ # # and ensure you take the necessary steps to force CPU or use the correct runtime environment.
244
+ # # For simplicity, if torch is installed for CPU, it will default to CPU.
245
+ # pass
246
+ #
247
+ # print("βœ… VisionEncoderDecoderModel (PyTorch) and TrOCRProcessor initialized successfully for equation conversion.")
248
+ # except Exception as e:
249
+ # print(f"❌ Error initializing TrOCR/PyTorch model. Equations will not be converted: {e}")
250
+ # processor = None
251
+ # pt_model = None
252
 
253
 
254
 
 
402
 
403
 
404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  def get_latex_from_base64(base64_string: str) -> str:
406
  """
407
+ Decodes a Base64 image string and uses the pre-initialized TrOCR/ORT model
408
+ to recognize the formula. It cleans the output by removing spaces and
409
  crucially, replacing double backslashes with single backslashes for correct LaTeX.
410
  """
411
+ if ort_model is None or processor is None:
 
412
  return "[MODEL_ERROR: Model not initialized]"
413
 
414
  try:
 
421
  pixel_values = processor(images=image, return_tensors="pt").pixel_values
422
 
423
  # 3. Text Generation (OCR)
424
+ generated_ids = ort_model.generate(pixel_values)
 
425
  raw_generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
426
 
427
  if not raw_generated_text:
 
435
  cleaned_latex = re.sub(r'\s+', '', latex_string)
436
 
437
  # B. CRITICAL FIX: Replace double backslashes with single backslashes.
438
+ # This addresses the over-escaping issue.
439
+ # final_output = cleaned_latex.replace('\\\\', '\\')
440
+
441
+ # Return the clean LaTeX string (e.g., $$a=\frac{F}{2m}$$)
442
+ #return final_output
443
+ return cleaned_latex
444
 
 
445
 
446
  except Exception as e:
447
+ # Catch any unexpected errors
448
  print(f" ❌ TR-OCR Recognition failed: {e}")
449
  return f"[TR_OCR_ERROR: Recognition failed: {e}]"
450
 
451
 
452
 
453
 
454
+ #
455
+ # def get_latex_from_base64(base64_string: str) -> str:
456
+ # """
457
+ # Decodes a Base64 image string and uses the pre-initialized TrOCR/PyTorch model
458
+ # to recognize the formula. It cleans the output by removing spaces and
459
+ # crucially, replacing double backslashes with single backslashes for correct LaTeX.
460
+ # """
461
+ # # Check the new model variable
462
+ # if pt_model is None or processor is None:
463
+ # return "[MODEL_ERROR: Model not initialized]"
464
+ #
465
+ # try:
466
+ # # 1. Decode Base64 to Image
467
+ # image_data = base64.b64decode(base64_string)
468
+ # # We must ensure the image is RGB format for the model input
469
+ # image = Image.open(io.BytesIO(image_data)).convert('RGB')
470
+ #
471
+ # # 2. Preprocess the image
472
+ # pixel_values = processor(images=image, return_tensors="pt").pixel_values
473
+ #
474
+ # # 3. Text Generation (OCR)
475
+ # # Use the PyTorch model's generate method
476
+ # generated_ids = pt_model.generate(pixel_values)
477
+ # raw_generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
478
+ #
479
+ # if not raw_generated_text:
480
+ # return "[OCR_WARNING: No formula found]"
481
+ #
482
+ # latex_string = raw_generated_text[0]
483
+ #
484
+ # # --- 4. Post-processing and Cleanup ---
485
+ #
486
+ # # A. Remove all spaces/line breaks
487
+ # cleaned_latex = re.sub(r'\s+', '', latex_string)
488
+ #
489
+ # # B. CRITICAL FIX: Replace double backslashes with single backslashes.
490
+ # final_output = cleaned_latex.replace('\\\\', '\\')
491
+ #
492
+ # return final_output
493
+ #
494
+ # except Exception as e:
495
+ # print(f" ❌ TR-OCR Recognition failed: {e}")
496
+ # return f"[TR_OCR_ERROR: Recognition failed: {e}]"
497
+ #
498
+ #
499
+
500
+
501
 
502
 
503
 
 
2353
 
2354
 
2355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2356
  if __name__ == "__main__":
 
 
 
 
 
 
2357
  parser = argparse.ArgumentParser(description="Complete Pipeline")
2358
  parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
2359
  parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path")
 
2375
 
2376
  # --- UPDATED FUNCTION CALL ---
2377
  final_json_data = run_document_pipeline(
2378
+ args.input_pdf,
2379
+ args.layoutlmv3_model_path,
2380
+ ls_output_path,
2381
  # raw_predictions_output_path # Pass the new argument
2382
  )
2383
  # -----------------------------
2384
 
 
2385
  if final_json_data:
 
 
 
 
 
 
 
 
 
2386
  with open(final_output_path, 'w', encoding='utf-8') as f:
2387
+ json.dump(final_json_data, f, indent=2, ensure_ascii=False)
 
2388
  print(f"\nβœ… Final Data Saved: {final_output_path}")
2389
  else:
2390
  print("\n❌ Pipeline Failed.")
2391
  sys.exit(1)
2392
 
2393
 
2394
+
2395
+
2396
+
2397
+
2398
+ # if __name__ == "__main__":
2399
+ # # Ensure 'json', 'argparse', 'os', and 'sys' are imported at the top of your script
2400
+ # # import json
2401
+ # # import argparse
2402
+ # # import os
2403
+ # # import sys
2404
+ #
2405
+ # parser = argparse.ArgumentParser(description="Complete Pipeline")
2406
+ # parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
2407
+ # parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path")
2408
+ # parser.add_argument("--ls_output_path", type=str, default=None, help="Label Studio Output Path")
2409
+ # # --- ADDED ARGUMENT FOR DEBUGGING ---
2410
+ # parser.add_argument("--raw_preds_path", type=str, default='BIO_debug.json',
2411
+ # help="Debug path for raw BIO tag predictions (JSON).")
2412
+ # # ------------------------------------
2413
+ # args = parser.parse_args()
2414
+ #
2415
+ # pdf_name = os.path.splitext(os.path.basename(args.input_pdf))[0]
2416
+ # final_output_path = os.path.abspath(f"{pdf_name}_final_output_embedded.json")
2417
+ # ls_output_path = os.path.abspath(
2418
+ # args.ls_output_path if args.ls_output_path else f"{pdf_name}_label_studio_tasks.json")
2419
+ # # --- CALCULATE RAW PREDICTIONS OUTPUT PATH ---
2420
+ # # raw_predictions_output_path = os.path.abspath(
2421
+ # # args.raw_preds_path if args.raw_preds_path else f"{pdf_name}_raw_predictions_debug.json")
2422
+ # # ---------------------------------------------
2423
+ #
2424
+ # # --- UPDATED FUNCTION CALL ---
2425
+ # final_json_data = run_document_pipeline(
2426
+ # args.input_pdf,
2427
+ # args.layoutlmv3_model_path,
2428
+ # ls_output_path,
2429
+ # # raw_predictions_output_path # Pass the new argument
2430
+ # )
2431
+ # # -----------------------------
2432
+ #
2433
+ # # πŸ›‘ CRITICAL FIX: CUSTOM JSON SAVING TO REMOVE DOUBLE BACKSLASHES πŸ›‘
2434
+ # if final_json_data:
2435
+ # # 1. Dump the Python object to a standard JSON string.
2436
+ # # This uses json.dumps which correctly escapes single backslashes ('\') to ('\\').
2437
+ # json_str = json.dumps(final_json_data, indent=2, ensure_ascii=False)
2438
+ #
2439
+ # # 2. **UNDO ESCAPING:** Replace every instance of the JSON-escaped backslash ('\\')
2440
+ # # with a single literal backslash ('\'). This forces the file content to be correct for LaTeX.
2441
+ # final_output_content = json_str.replace('\\\\', '\\')
2442
+ #
2443
+ # # 3. Write the corrected string content to the file.
2444
+ # with open(final_output_path, 'w', encoding='utf-8') as f:
2445
+ # f.write(final_output_content)
2446
+ #
2447
+ # print(f"\nβœ… Final Data Saved: {final_output_path}")
2448
+ # else:
2449
+ # print("\n❌ Pipeline Failed.")
2450
+ # sys.exit(1)
2451
+
2452
+