# mathstutor/app/tools/advanced_ocr.py
# Commit 565a379 by ghadgemadhuri92:
#   "agent tested with the prompt: Calculate 15 * 12 then add 50."
import logging
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import io
logger = logging.getLogger(__name__)


class AdvancedOCR:
    """
    Advanced OCR using Hugging Face Transformers (TrOCR).

    Specialized for handwritten text and mathematical expressions. The
    processor and model are loaded lazily: on the first extraction call, or
    by an explicit call to ``load_model``. This keeps constructing the
    object cheap.
    """

    def __init__(self, model_name: str = "microsoft/trocr-base-handwritten"):
        """
        Record the model name without loading anything yet.

        Args:
            model_name: Model id or path passed to ``from_pretrained`` for
                both the TrOCR processor and the model.
        """
        self.model_name = model_name
        self.processor = None
        self.model = None
        # Filled in by load_model() with "cuda" or "cpu". It is initialised
        # here so the attribute exists even before the model is loaded.
        self.device = None
        # Lazy loading to avoid heavy startup time if not needed immediately
        self._loaded = False

    def load_model(self):
        """
        Load the processor and model into memory and move the model to a device.

        This is a no-op if the model has already been loaded. CUDA is used
        when ``torch.cuda.is_available()`` is true; otherwise the CPU is used.

        Raises:
            Exception: Any error raised while loading or moving the model is
                logged with its traceback and then re-raised unchanged.
        """
        if self._loaded:
            return
        try:
            logger.info("Loading TrOCR model: %s...", self.model_name)
            self.processor = TrOCRProcessor.from_pretrained(self.model_name)
            self.model = VisionEncoderDecoderModel.from_pretrained(self.model_name)
            # Move to GPU if available
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)
            self._loaded = True
            logger.info("TrOCR model loaded successfully.")
        except Exception as e:
            logger.exception("Failed to load TrOCR model: %s", e)
            self._loaded = False
            # A bare raise re-raises the original exception with its traceback.
            raise

    def extract_handwriting(self, image: Image.Image) -> str:
        """
        Extract handwritten text from a PIL Image.

        The model is loaded on first use. Images that are not already RGB
        are converted to RGB before preprocessing.

        Args:
            image: Input image in any PIL mode.

        Returns:
            The decoded text. Returns ``""`` if preprocessing, generation or
            decoding fails; the error is logged.

        Raises:
            Exception: Errors from ``load_model`` are not caught here and
                propagate to the caller.
        """
        # load_model() is idempotent. It is called outside the try block so
        # a load failure raises instead of silently producing "".
        self.load_model()
        try:
            if image.mode != "RGB":
                image = image.convert("RGB")
            pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(self.device)
            # Generate text
            generated_ids = self.model.generate(pixel_values)
            # A single image goes in, so take the single decoded string.
            return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        except Exception as e:
            logger.exception("TrOCR extraction failed: %s", e)
            return ""

    def process_image_bytes(self, image_bytes: bytes) -> str:
        """
        Open raw image bytes with PIL and extract handwritten text from them.

        Args:
            image_bytes: Encoded image data in any format PIL can open.

        Returns:
            The extracted text. Returns ``""`` if the bytes cannot be opened
            as an image or if extraction fails for any reason, including a
            model-load failure. The error is logged.
        """
        try:
            # The context manager releases the PIL image once extraction
            # has finished with it.
            with Image.open(io.BytesIO(image_bytes)) as image:
                return self.extract_handwriting(image)
        except Exception as e:
            logger.exception("Error processing image bytes: %s", e)
            return ""