heerjtdev commited on
Commit
bb2061e
·
verified ·
1 Parent(s): b8d54ea

Update working_yolo_pipeline.py

Browse files
Files changed (1) hide show
  1. working_yolo_pipeline.py +199 -103
working_yolo_pipeline.py CHANGED
@@ -139,42 +139,76 @@ from sklearn.metrics.pairwise import cosine_similarity
139
 
140
  #=============================================================================
141
  #-----EXPERIMENT LATEX
142
- #=============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- # --- NEW IMPORTS ---
145
- from pix2text import Pix2Text
146
  import logging
147
- # -------------------
 
 
 
 
 
148
 
149
- # ============================================================================
150
- # --- CONFIGURATION AND CONSTANTS ---
151
- # ... (Your existing constants like WEIGHTS_PATH, OCR_JSON_OUTPUT_DIR, etc.)
152
- # ============================================================================
153
 
154
  # ============================================================================
155
- # --- PIX2TEXT INITIALIZATION AND HELPER ---
156
  # ============================================================================
157
  # Set up logging to WARNING level to suppress excessive output from model libraries
158
  logging.basicConfig(level=logging.WARNING)
159
- logging.getLogger('pix2text').setLevel(logging.WARNING)
160
 
161
- # Initialize Pix2Text model globally (expensive operation, do it once)
162
- p2t = None
 
163
  try:
164
- # Use 'yolox_tiny' for faster inference AND configure PyTorch backend
165
- p2t = Pix2Text(
166
- analyzer_config={'model_name': 'yolox_tiny'},
167
- # ⬇️ ADD THESE LINES TO USE PYTORCH INSTEAD OF ONNX ⬇️
168
- text_config={
169
- 'rec_model_backend': 'pytorch',
170
- 'det_model_backend': 'pytorch'
171
- }
172
- )
173
- print("✅ Pix2Text model initialized successfully with PyTorch backend for equation conversion.")
174
  except Exception as e:
175
- print(f"❌ Error initializing Pix2Text model. Equations will not be converted: {e}")
176
- p2t = None
177
-
178
 
179
 
180
 
@@ -273,66 +307,11 @@ except Exception as e:
273
 
274
 
275
 
276
- def get_latex_from_base64(base64_string: str) -> str:
277
- """
278
- Decodes a Base64 image string, uses Pix2Text to recognize the formula,
279
- and returns the LaTeX code, stripped of all whitespace, as requested,
280
- and corrects unintended double backslashes.
281
- """
282
- if p2t is None:
283
- return "[P2T_ERROR: Model not initialized]"
284
-
285
- try:
286
- # 1. Decode Base64 to Image
287
- image_data = base64.b64decode(base64_string)
288
- image = Image.open(io.BytesIO(image_data))
289
-
290
- # 2. Recognize text and formulas
291
- # Use keep_original_image=False to save memory
292
- result = p2t.recognize(image, save_formula_images=False, use_analyzer=True, keep_original_image=False)
293
-
294
- # 3. Parse the result for LaTeX
295
- extracted_latex_parts = []
296
- if isinstance(result, list):
297
- for item in result:
298
- # Use .text for structured output, item itself for string output
299
- text = item.text if hasattr(item, 'text') else str(item)
300
- extracted_latex_parts.append(text)
301
- elif isinstance(result, str):
302
- extracted_latex_parts = [result]
303
-
304
- # Join with a space first, then clean all whitespace
305
- extracted_latex = " ".join(extracted_latex_parts).strip()
306
-
307
- # *** CORE CHANGE 1: Remove all spaces/line breaks ***
308
- cleaned_latex = extracted_latex.replace('\\\\', '\\')
309
- final_latex = re.sub(r'\s+', '', cleaned_latex)
310
-
311
- if not cleaned_latex:
312
- return "[P2T_WARNING: No formula found]"
313
-
314
- # *** CORE CHANGE 2: Fix unintended double backslashes for LaTeX rendering ***
315
- # This replaces every sequence of two literal backslashes ('\\') with one literal backslash ('\'),
316
- # ensuring LaTeX commands like '\frac' are correctly formed.
317
-
318
-
319
- # Return the clean and corrected LaTeX string.
320
- return final_latex
321
-
322
- except Exception as e:
323
- # Catch any unexpected errors
324
- print(f" ❌ Pix2Text Recognition failed: {e}")
325
- return f"[P2T_ERROR: Recognition failed: {e}]"
326
-
327
-
328
-
329
-
330
-
331
  # def get_latex_from_base64(base64_string: str) -> str:
332
  # """
333
- # Decodes a Base64 image string, uses Pix2Text to recognize the formula,
334
- # returns the LaTeX code stripped of all whitespace, and collapses unintended
335
- # repeated backslashes into a single backslash.
336
  # """
337
  # if p2t is None:
338
  # return "[P2T_ERROR: Model not initialized]"
@@ -341,37 +320,41 @@ def get_latex_from_base64(base64_string: str) -> str:
341
  # # 1. Decode Base64 to Image
342
  # image_data = base64.b64decode(base64_string)
343
  # image = Image.open(io.BytesIO(image_data))
344
-
345
  # # 2. Recognize text and formulas
346
- # result = p2t.recognize(
347
- # image, save_formula_images=False, use_analyzer=True, keep_original_image=False
348
- # )
349
-
350
  # # 3. Parse the result for LaTeX
351
  # extracted_latex_parts = []
352
  # if isinstance(result, list):
353
  # for item in result:
 
354
  # text = item.text if hasattr(item, 'text') else str(item)
355
  # extracted_latex_parts.append(text)
356
  # elif isinstance(result, str):
357
- # extracted_latex_parts = [result]
358
-
359
- # # Join then strip
360
  # extracted_latex = " ".join(extracted_latex_parts).strip()
361
-
362
- # # Remove all whitespace/newlines/tabs as requested
363
- # cleaned_latex = re.sub(r'\s+', '', extracted_latex)
364
-
 
365
  # if not cleaned_latex:
366
- # return "[P2T_WARNING: No formula found]"
367
-
368
- # # COLLAPSE any run of 2 or more backslashes into a single backslash.
369
- # # This handles inputs like '\\\\sqrt' or '\\\\\\frac' robustly.
370
- # final_latex = re.sub(r'\\{2,}', r'\\', cleaned_latex)
371
-
 
 
372
  # return final_latex
373
 
374
  # except Exception as e:
 
375
  # print(f" ❌ Pix2Text Recognition failed: {e}")
376
  # return f"[P2T_ERROR: Recognition failed: {e}]"
377
 
@@ -379,6 +362,58 @@ def get_latex_from_base64(base64_string: str) -> str:
379
 
380
 
381
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
 
383
 
384
  # # Initialize the YOLO model
@@ -2229,7 +2264,55 @@ def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, label
2229
 
2230
 
2231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2232
  if __name__ == "__main__":
 
 
 
 
 
 
2233
  parser = argparse.ArgumentParser(description="Complete Pipeline")
2234
  parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
2235
  parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path")
@@ -2258,10 +2341,23 @@ if __name__ == "__main__":
2258
  )
2259
  # -----------------------------
2260
 
 
2261
  if final_json_data:
 
 
 
 
 
 
 
 
 
2262
  with open(final_output_path, 'w', encoding='utf-8') as f:
2263
- json.dump(final_json_data, f, indent=2, ensure_ascii=False)
 
2264
  print(f"\n✅ Final Data Saved: {final_output_path}")
2265
  else:
2266
  print("\n❌ Pipeline Failed.")
2267
- sys.exit(1)
 
 
 
139
 
140
  #=============================================================================
141
  #-----EXPERIMENT LATEX
142
+ # #=============================================================================
143
+
144
+ # # --- NEW IMPORTS ---
145
+ # from pix2text import Pix2Text
146
+ # import logging
147
+ # # -------------------
148
+
149
+ # # ============================================================================
150
+ # # --- CONFIGURATION AND CONSTANTS ---
151
+ # # ... (Your existing constants like WEIGHTS_PATH, OCR_JSON_OUTPUT_DIR, etc.)
152
+ # # ============================================================================
153
+
154
+ # # ============================================================================
155
+ # # --- PIX2TEXT INITIALIZATION AND HELPER ---
156
+ # # ============================================================================
157
+ # # Set up logging to WARNING level to suppress excessive output from model libraries
158
+ # logging.basicConfig(level=logging.WARNING)
159
+ # logging.getLogger('pix2text').setLevel(logging.WARNING)
160
+
161
+ # # Initialize Pix2Text model globally (expensive operation, do it once)
162
+ # p2t = None
163
+ # try:
164
+ # # Use 'yolox_tiny' for faster inference AND configure PyTorch backend
165
+ # p2t = Pix2Text(
166
+ # analyzer_config={'model_name': 'yolox_tiny'},
167
+ # # ⬇️ ADD THESE LINES TO USE PYTORCH INSTEAD OF ONNX ⬇️
168
+ # text_config={
169
+ # 'rec_model_backend': 'pytorch',
170
+ # 'det_model_backend': 'pytorch'
171
+ # }
172
+ # )
173
+ # print("✅ Pix2Text model initialized successfully with PyTorch backend for equation conversion.")
174
+ # except Exception as e:
175
+ # print(f"❌ Error initializing Pix2Text model. Equations will not be converted: {e}")
176
+ # p2t = None
177
+
178
+
179
+
180
 
 
 
181
  import logging
182
+ from transformers import TrOCRProcessor
183
+ # NOTE: Using optimum.onnxruntime for faster inference, as suggested by your sample script.
184
+ # If you run into issues, you may need to fall back to the standard
185
+ # 'transformers.VisionEncoderDecoderModel' if ORTModelForVision2Seq is not found/working.
186
+ from optimum.onnxruntime import ORTModelForVision2Seq
187
+
188
 
 
 
 
 
189
 
190
  # ============================================================================
191
+ # --- TR-OCR/ORT MODEL INITIALIZATION ---
192
  # ============================================================================
193
  # Set up logging to WARNING level to suppress excessive output from model libraries
194
  logging.basicConfig(level=logging.WARNING)
 
195
 
196
+ processor = None
197
+ ort_model = None
198
+
199
  try:
200
+ MODEL_NAME = 'breezedeus/pix2text-mfr-1.5'
201
+ processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
202
+
203
+ # Initialize the model for ONNX Runtime
204
+ # NOTE: Set use_cache=False to avoid caching warnings/issues if reloading
205
+ ort_model = ORTModelForVision2Seq.from_pretrained(MODEL_NAME, use_cache=False)
206
+
207
+ print("✅ ORTModelForVision2Seq and TrOCRProcessor initialized successfully for equation conversion.")
 
 
208
  except Exception as e:
209
+ print(f"❌ Error initializing TrOCR/ORT model. Equations will not be converted: {e}")
210
+ processor = None
211
+ ort_model = None
212
 
213
 
214
 
 
307
 
308
 
309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  # def get_latex_from_base64(base64_string: str) -> str:
311
  # """
312
+ # Decodes a Base64 image string, uses Pix2Text to recognize the formula,
313
+ # and returns the LaTeX code, stripped of all whitespace, as requested,
314
+ # and corrects unintended double backslashes.
315
  # """
316
  # if p2t is None:
317
  # return "[P2T_ERROR: Model not initialized]"
 
320
  # # 1. Decode Base64 to Image
321
  # image_data = base64.b64decode(base64_string)
322
  # image = Image.open(io.BytesIO(image_data))
323
+
324
  # # 2. Recognize text and formulas
325
+ # # Use keep_original_image=False to save memory
326
+ # result = p2t.recognize(image, save_formula_images=False, use_analyzer=True, keep_original_image=False)
327
+
 
328
  # # 3. Parse the result for LaTeX
329
  # extracted_latex_parts = []
330
  # if isinstance(result, list):
331
  # for item in result:
332
+ # # Use .text for structured output, item itself for string output
333
  # text = item.text if hasattr(item, 'text') else str(item)
334
  # extracted_latex_parts.append(text)
335
  # elif isinstance(result, str):
336
+ # extracted_latex_parts = [result]
337
+
338
+ # # Join with a space first, then clean all whitespace
339
  # extracted_latex = " ".join(extracted_latex_parts).strip()
340
+
341
+ # # *** CORE CHANGE 1: Remove all spaces/line breaks ***
342
+ # cleaned_latex = extracted_latex.replace('\\\\', '\\')
343
+ # final_latex = re.sub(r'\s+', '', cleaned_latex)
344
+
345
  # if not cleaned_latex:
346
+ # return "[P2T_WARNING: No formula found]"
347
+
348
+ # # *** CORE CHANGE 2: Fix unintended double backslashes for LaTeX rendering ***
349
+ # # This replaces every sequence of two literal backslashes ('\\') with one literal backslash ('\'),
350
+ # # ensuring LaTeX commands like '\frac' are correctly formed.
351
+
352
+
353
+ # # Return the clean and corrected LaTeX string.
354
  # return final_latex
355
 
356
  # except Exception as e:
357
+ # # Catch any unexpected errors
358
  # print(f" ❌ Pix2Text Recognition failed: {e}")
359
  # return f"[P2T_ERROR: Recognition failed: {e}]"
360
 
 
362
 
363
 
364
 
365
+ def get_latex_from_base64(base64_string: str) -> str:
366
+ """
367
+ Decodes a Base64 image string and uses the pre-initialized TrOCR/ORT model
368
+ to recognize the formula. It cleans the output by removing spaces and
369
+ crucially, replacing double backslashes with single backslashes for correct LaTeX.
370
+ """
371
+ if ort_model is None or processor is None:
372
+ return "[MODEL_ERROR: Model not initialized]"
373
+
374
+ try:
375
+ # 1. Decode Base64 to Image
376
+ image_data = base64.b64decode(base64_string)
377
+ # We must ensure the image is RGB format for the model input
378
+ image = Image.open(io.BytesIO(image_data)).convert('RGB')
379
+
380
+ # 2. Preprocess the image
381
+ pixel_values = processor(images=image, return_tensors="pt").pixel_values
382
+
383
+ # 3. Text Generation (OCR)
384
+ generated_ids = ort_model.generate(pixel_values)
385
+ raw_generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
386
+
387
+ if not raw_generated_text:
388
+ return "[OCR_WARNING: No formula found]"
389
+
390
+ latex_string = raw_generated_text[0]
391
+
392
+ # --- 4. Post-processing and Cleanup ---
393
+
394
+ # A. Remove all spaces/line breaks
395
+ cleaned_latex = re.sub(r'\s+', '', latex_string)
396
+
397
+ # B. CRITICAL FIX: Replace double backslashes with single backslashes.
398
+ # This addresses the over-escaping issue.
399
+ final_output = cleaned_latex.replace('\\\\', '\\')
400
+
401
+ # Return the clean LaTeX string (e.g., $$a=\frac{F}{2m}$$)
402
+ return final_output
403
+
404
+ except Exception as e:
405
+ # Catch any unexpected errors
406
+ print(f" ❌ TR-OCR Recognition failed: {e}")
407
+ return f"[TR_OCR_ERROR: Recognition failed: {e}]"
408
+
409
+
410
+
411
+
412
+
413
+
414
+
415
+
416
+
417
 
418
 
419
  # # Initialize the YOLO model
 
2264
 
2265
 
2266
 
2267
+ # if __name__ == "__main__":
2268
+ # parser = argparse.ArgumentParser(description="Complete Pipeline")
2269
+ # parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
2270
+ # parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path")
2271
+ # parser.add_argument("--ls_output_path", type=str, default=None, help="Label Studio Output Path")
2272
+ # # --- ADDED ARGUMENT FOR DEBUGGING ---
2273
+ # parser.add_argument("--raw_preds_path", type=str, default='BIO_debug.json',
2274
+ # help="Debug path for raw BIO tag predictions (JSON).")
2275
+ # # ------------------------------------
2276
+ # args = parser.parse_args()
2277
+
2278
+ # pdf_name = os.path.splitext(os.path.basename(args.input_pdf))[0]
2279
+ # final_output_path = os.path.abspath(f"{pdf_name}_final_output_embedded.json")
2280
+ # ls_output_path = os.path.abspath(
2281
+ # args.ls_output_path if args.ls_output_path else f"{pdf_name}_label_studio_tasks.json")
2282
+ # # --- CALCULATE RAW PREDICTIONS OUTPUT PATH ---
2283
+ # # raw_predictions_output_path = os.path.abspath(
2284
+ # # args.raw_preds_path if args.raw_preds_path else f"{pdf_name}_raw_predictions_debug.json")
2285
+ # # ---------------------------------------------
2286
+
2287
+ # # --- UPDATED FUNCTION CALL ---
2288
+ # final_json_data = run_document_pipeline(
2289
+ # args.input_pdf,
2290
+ # args.layoutlmv3_model_path,
2291
+ # ls_output_path,
2292
+ # # raw_predictions_output_path # Pass the new argument
2293
+ # )
2294
+ # # -----------------------------
2295
+
2296
+ # if final_json_data:
2297
+ # with open(final_output_path, 'w', encoding='utf-8') as f:
2298
+ # json.dump(final_json_data, f, indent=2, ensure_ascii=False)
2299
+ # print(f"\n✅ Final Data Saved: {final_output_path}")
2300
+ # else:
2301
+ # print("\n❌ Pipeline Failed.")
2302
+ # sys.exit(1)
2303
+
2304
+
2305
+
2306
+
2307
+
2308
+
2309
  if __name__ == "__main__":
2310
+ # Ensure 'json', 'argparse', 'os', and 'sys' are imported at the top of your script
2311
+ # import json
2312
+ # import argparse
2313
+ # import os
2314
+ # import sys
2315
+
2316
  parser = argparse.ArgumentParser(description="Complete Pipeline")
2317
  parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
2318
  parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path")
 
2341
  )
2342
  # -----------------------------
2343
 
2344
+ # 🛑 CRITICAL FIX: CUSTOM JSON SAVING TO REMOVE DOUBLE BACKSLASHES 🛑
2345
  if final_json_data:
2346
+ # 1. Dump the Python object to a standard JSON string.
2347
+ # This uses json.dumps which correctly escapes single backslashes ('\') to ('\\').
2348
+ json_str = json.dumps(final_json_data, indent=2, ensure_ascii=False)
2349
+
2350
+ # 2. **UNDO ESCAPING:** Replace every instance of the JSON-escaped backslash ('\\')
2351
+ # with a single literal backslash ('\'). This forces the file content to be correct for LaTeX.
2352
+ final_output_content = json_str.replace('\\\\', '\\')
2353
+
2354
+ # 3. Write the corrected string content to the file.
2355
  with open(final_output_path, 'w', encoding='utf-8') as f:
2356
+ f.write(final_output_content)
2357
+
2358
  print(f"\n✅ Final Data Saved: {final_output_path}")
2359
  else:
2360
  print("\n❌ Pipeline Failed.")
2361
+ sys.exit(1)
2362
+
2363
+