Spaces:
Running
Running
Update working_yolo_pipeline.py
Browse files - working_yolo_pipeline.py +107 -20
working_yolo_pipeline.py
CHANGED
|
@@ -178,37 +178,77 @@ from sklearn.metrics.pairwise import cosine_similarity
|
|
| 178 |
|
| 179 |
|
| 180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
import logging
|
| 182 |
-
from transformers import TrOCRProcessor
|
| 183 |
-
# NOTE:
|
| 184 |
-
#
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
|
| 189 |
|
| 190 |
# ============================================================================
|
| 191 |
-
# --- TR-OCR/
|
| 192 |
# ============================================================================
|
| 193 |
-
# Set up logging to WARNING level to suppress excessive output from model libraries
|
| 194 |
logging.basicConfig(level=logging.WARNING)
|
| 195 |
|
| 196 |
processor = None
|
| 197 |
-
|
| 198 |
|
| 199 |
try:
|
| 200 |
MODEL_NAME = 'breezedeus/pix2text-mfr-1.5'
|
| 201 |
processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
|
| 202 |
|
| 203 |
-
# Initialize the model
|
| 204 |
-
|
| 205 |
-
ort_model = ORTModelForVision2Seq.from_pretrained(MODEL_NAME, use_cache=False)
|
| 206 |
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
except Exception as e:
|
| 209 |
-
print(f"❌ Error initializing TrOCR/
|
| 210 |
processor = None
|
| 211 |
-
|
| 212 |
|
| 213 |
|
| 214 |
|
|
@@ -362,13 +402,62 @@ except Exception as e:
|
|
| 362 |
|
| 363 |
|
| 364 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
def get_latex_from_base64(base64_string: str) -> str:
|
| 366 |
"""
|
| 367 |
-
Decodes a Base64 image string and uses the pre-initialized TrOCR/
|
| 368 |
to recognize the formula. It cleans the output by removing spaces and
|
| 369 |
crucially, replacing double backslashes with single backslashes for correct LaTeX.
|
| 370 |
"""
|
| 371 |
-
|
|
|
|
| 372 |
return "[MODEL_ERROR: Model not initialized]"
|
| 373 |
|
| 374 |
try:
|
|
@@ -381,7 +470,8 @@ def get_latex_from_base64(base64_string: str) -> str:
|
|
| 381 |
pixel_values = processor(images=image, return_tensors="pt").pixel_values
|
| 382 |
|
| 383 |
# 3. Text Generation (OCR)
|
| 384 |
-
|
|
|
|
| 385 |
raw_generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 386 |
|
| 387 |
if not raw_generated_text:
|
|
@@ -395,14 +485,11 @@ def get_latex_from_base64(base64_string: str) -> str:
|
|
| 395 |
cleaned_latex = re.sub(r'\s+', '', latex_string)
|
| 396 |
|
| 397 |
# B. CRITICAL FIX: Replace double backslashes with single backslashes.
|
| 398 |
-
# This addresses the over-escaping issue.
|
| 399 |
final_output = cleaned_latex.replace('\\\\', '\\')
|
| 400 |
|
| 401 |
-
# Return the clean LaTeX string (e.g., $$a=\frac{F}{2m}$$)
|
| 402 |
return final_output
|
| 403 |
|
| 404 |
except Exception as e:
|
| 405 |
-
# Catch any unexpected errors
|
| 406 |
print(f" ❌ TR-OCR Recognition failed: {e}")
|
| 407 |
return f"[TR_OCR_ERROR: Recognition failed: {e}]"
|
| 408 |
|
|
|
|
| 178 |
|
| 179 |
|
| 180 |
|
| 181 |
+
# import logging
|
| 182 |
+
# from transformers import TrOCRProcessor
|
| 183 |
+
# # NOTE: Using optimum.onnxruntime for faster inference, as suggested by your sample script.
|
| 184 |
+
# # If you run into issues, you may need to fall back to the standard
|
| 185 |
+
# # 'transformers.VisionEncoderDecoderModel' if ORTModelForVision2Seq is not found/working.
|
| 186 |
+
# from optimum.onnxruntime import ORTModelForVision2Seq
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
|
| 190 |
import logging
|
| 191 |
+
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
| 192 |
+
# NOTE: We are replacing the ORTModelForVision2Seq import due to the ModuleNotFoundError
|
| 193 |
+
# from optimum.onnxruntime import ORTModelForVision2Seq <-- REMOVE THIS
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# # ============================================================================
|
| 197 |
+
# # --- TR-OCR/ORT MODEL INITIALIZATION ---
|
| 198 |
+
# # ============================================================================
|
| 199 |
+
# # Set up logging to WARNING level to suppress excessive output from model libraries
|
| 200 |
+
# logging.basicConfig(level=logging.WARNING)
|
| 201 |
+
|
| 202 |
+
# processor = None
|
| 203 |
+
# ort_model = None
|
| 204 |
+
|
| 205 |
+
# try:
|
| 206 |
+
# MODEL_NAME = 'breezedeus/pix2text-mfr-1.5'
|
| 207 |
+
# processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
|
| 208 |
+
|
| 209 |
+
# # Initialize the model for ONNX Runtime
|
| 210 |
+
# # NOTE: Set use_cache=False to avoid caching warnings/issues if reloading
|
| 211 |
+
# ort_model = ORTModelForVision2Seq.from_pretrained(MODEL_NAME, use_cache=False)
|
| 212 |
+
|
| 213 |
+
# print("✅ ORTModelForVision2Seq and TrOCRProcessor initialized successfully for equation conversion.")
|
| 214 |
+
# except Exception as e:
|
| 215 |
+
# print(f"❌ Error initializing TrOCR/ORT model. Equations will not be converted: {e}")
|
| 216 |
+
# processor = None
|
| 217 |
+
# ort_model = None
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
|
| 222 |
|
| 223 |
|
| 224 |
|
| 225 |
# ============================================================================
|
| 226 |
+
# --- TR-OCR/PYTORCH MODEL INITIALIZATION ---
|
| 227 |
# ============================================================================
|
|
|
|
| 228 |
logging.basicConfig(level=logging.WARNING)
|
| 229 |
|
| 230 |
processor = None
|
| 231 |
+
pt_model = None # Renaming the variable from 'ort_model' to 'pt_model' for clarity
|
| 232 |
|
| 233 |
try:
|
| 234 |
MODEL_NAME = 'breezedeus/pix2text-mfr-1.5'
|
| 235 |
processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
|
| 236 |
|
| 237 |
+
# Initialize the standard PyTorch model instead of the ORT model
|
| 238 |
+
pt_model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
|
|
|
|
| 239 |
|
| 240 |
+
# CRITICAL: Since you want CPU-ONLY, explicitly ensure the model is on CPU
|
| 241 |
+
if torch.cuda.is_available():
|
| 242 |
+
# Although you requested CPU-only, check if CUDA is available
|
| 243 |
+
# and ensure you take the necessary steps to force CPU or use the correct runtime environment.
|
| 244 |
+
# For simplicity, if torch is installed for CPU, it will default to CPU.
|
| 245 |
+
pass
|
| 246 |
+
|
| 247 |
+
print("✅ VisionEncoderDecoderModel (PyTorch) and TrOCRProcessor initialized successfully for equation conversion.")
|
| 248 |
except Exception as e:
|
| 249 |
+
print(f"❌ Error initializing TrOCR/PyTorch model. Equations will not be converted: {e}")
|
| 250 |
processor = None
|
| 251 |
+
pt_model = None
|
| 252 |
|
| 253 |
|
| 254 |
|
|
|
|
| 402 |
|
| 403 |
|
| 404 |
|
| 405 |
+
# def get_latex_from_base64(base64_string: str) -> str:
|
| 406 |
+
# """
|
| 407 |
+
# Decodes a Base64 image string and uses the pre-initialized TrOCR/ORT model
|
| 408 |
+
# to recognize the formula. It cleans the output by removing spaces and
|
| 409 |
+
# crucially, replacing double backslashes with single backslashes for correct LaTeX.
|
| 410 |
+
# """
|
| 411 |
+
# if ort_model is None or processor is None:
|
| 412 |
+
# return "[MODEL_ERROR: Model not initialized]"
|
| 413 |
+
|
| 414 |
+
# try:
|
| 415 |
+
# # 1. Decode Base64 to Image
|
| 416 |
+
# image_data = base64.b64decode(base64_string)
|
| 417 |
+
# # We must ensure the image is RGB format for the model input
|
| 418 |
+
# image = Image.open(io.BytesIO(image_data)).convert('RGB')
|
| 419 |
+
|
| 420 |
+
# # 2. Preprocess the image
|
| 421 |
+
# pixel_values = processor(images=image, return_tensors="pt").pixel_values
|
| 422 |
+
|
| 423 |
+
# # 3. Text Generation (OCR)
|
| 424 |
+
# generated_ids = ort_model.generate(pixel_values)
|
| 425 |
+
# raw_generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 426 |
+
|
| 427 |
+
# if not raw_generated_text:
|
| 428 |
+
# return "[OCR_WARNING: No formula found]"
|
| 429 |
+
|
| 430 |
+
# latex_string = raw_generated_text[0]
|
| 431 |
+
|
| 432 |
+
# # --- 4. Post-processing and Cleanup ---
|
| 433 |
+
|
| 434 |
+
# # A. Remove all spaces/line breaks
|
| 435 |
+
# cleaned_latex = re.sub(r'\s+', '', latex_string)
|
| 436 |
+
|
| 437 |
+
# # B. CRITICAL FIX: Replace double backslashes with single backslashes.
|
| 438 |
+
# # This addresses the over-escaping issue.
|
| 439 |
+
# final_output = cleaned_latex.replace('\\\\', '\\')
|
| 440 |
+
|
| 441 |
+
# # Return the clean LaTeX string (e.g., $$a=\frac{F}{2m}$$)
|
| 442 |
+
# return final_output
|
| 443 |
+
|
| 444 |
+
# except Exception as e:
|
| 445 |
+
# # Catch any unexpected errors
|
| 446 |
+
# print(f" ❌ TR-OCR Recognition failed: {e}")
|
| 447 |
+
# return f"[TR_OCR_ERROR: Recognition failed: {e}]"
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
|
| 453 |
def get_latex_from_base64(base64_string: str) -> str:
|
| 454 |
"""
|
| 455 |
+
Decodes a Base64 image string and uses the pre-initialized TrOCR/PyTorch model
|
| 456 |
to recognize the formula. It cleans the output by removing spaces and
|
| 457 |
crucially, replacing double backslashes with single backslashes for correct LaTeX.
|
| 458 |
"""
|
| 459 |
+
# Check the new model variable
|
| 460 |
+
if pt_model is None or processor is None:
|
| 461 |
return "[MODEL_ERROR: Model not initialized]"
|
| 462 |
|
| 463 |
try:
|
|
|
|
| 470 |
pixel_values = processor(images=image, return_tensors="pt").pixel_values
|
| 471 |
|
| 472 |
# 3. Text Generation (OCR)
|
| 473 |
+
# Use the PyTorch model's generate method
|
| 474 |
+
generated_ids = pt_model.generate(pixel_values)
|
| 475 |
raw_generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 476 |
|
| 477 |
if not raw_generated_text:
|
|
|
|
| 485 |
cleaned_latex = re.sub(r'\s+', '', latex_string)
|
| 486 |
|
| 487 |
# B. CRITICAL FIX: Replace double backslashes with single backslashes.
|
|
|
|
| 488 |
final_output = cleaned_latex.replace('\\\\', '\\')
|
| 489 |
|
|
|
|
| 490 |
return final_output
|
| 491 |
|
| 492 |
except Exception as e:
|
|
|
|
| 493 |
print(f" ❌ TR-OCR Recognition failed: {e}")
|
| 494 |
return f"[TR_OCR_ERROR: Recognition failed: {e}]"
|
| 495 |
|