Spaces:
Running
Running
| import logging | |
| import torch | |
| from transformers import TrOCRProcessor, VisionEncoderDecoderModel | |
| from PIL import Image | |
| import io | |
# Logger for this module, named after the module's import path (``__name__``).
logger: logging.Logger = logging.getLogger(__name__)
class AdvancedOCR:
    """
    Advanced OCR using Hugging Face Transformers (TrOCR).

    Specialized for handwritten text and mathematical expressions.

    The processor and model are loaded lazily, either on the first extraction
    call or through an explicit ``load_model()`` call, so constructing an
    instance is cheap.
    """

    def __init__(self, model_name: str = "microsoft/trocr-base-handwritten"):
        """
        Record the model to use. Nothing is downloaded or loaded here.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                TrOCR processor and the vision encoder-decoder model.
        """
        self.model_name = model_name
        self.processor = None
        self.model = None
        # Resolved in load_model(). It is declared here so the attribute exists
        # from construction onward rather than appearing only after a load.
        self.device = None
        # Lazy loading to avoid heavy startup time if not needed immediately
        self._loaded = False

    def load_model(self):
        """
        Load the processor and model into memory.

        The model is moved to ``"cuda"`` when ``torch.cuda.is_available()`` is
        true, and to ``"cpu"`` otherwise. Calling this again after a
        successful load does nothing.

        Raises:
            Exception: Whatever ``from_pretrained`` or ``.to()`` raised. The
                error is logged with its traceback, and any partially loaded
                components are discarded so a later call can retry cleanly.
        """
        if self._loaded:
            return
        try:
            logger.info("Loading TrOCR model: %s...", self.model_name)
            self.processor = TrOCRProcessor.from_pretrained(self.model_name)
            self.model = VisionEncoderDecoderModel.from_pretrained(self.model_name)
            # Move to GPU if available
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)
        except Exception as e:
            logger.exception("Failed to load TrOCR model: %s", e)
            # Don't keep a half-loaded state, e.g. a processor without a model.
            self.processor = None
            self.model = None
            self.device = None
            self._loaded = False
            raise
        else:
            self._loaded = True
            logger.info("TrOCR model loaded successfully.")

    def extract_handwriting(self, image: "Image.Image") -> str:
        """
        Extract handwritten text from a PIL Image.

        The model is loaded on first use. Images whose mode is not ``"RGB"``
        are converted to RGB before preprocessing.

        Args:
            image: The image to recognise.

        Returns:
            The decoded text (special tokens skipped). If preprocessing or
            generation raises, the error is logged with its traceback and
            ``""`` is returned.

        Raises:
            Exception: Propagated from ``load_model()`` when the model cannot
                be loaded. Only inference-time errors are turned into ``""``.
        """
        if not self._loaded:
            self.load_model()
        try:
            # Prepare image
            if image.mode != "RGB":
                image = image.convert("RGB")
            pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(self.device)
            # Generate text
            generated_ids = self.model.generate(pixel_values)
            # A single image was passed, so take the only decoded sequence.
            return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        except Exception as e:
            logger.exception("TrOCR extraction failed: %s", e)
            return ""

    def process_image_bytes(self, image_bytes: bytes) -> str:
        """
        Decode raw image bytes with PIL and extract handwritten text from them.

        Args:
            image_bytes: Encoded image data in any format PIL can open.

        Returns:
            The extracted text. Returns ``""`` if the bytes cannot be opened as
            an image or if any later step raises, including a model-load
            failure propagated from ``extract_handwriting``. The error is
            logged with its traceback.
        """
        try:
            # The context manager closes the PIL image once extraction is done.
            with Image.open(io.BytesIO(image_bytes)) as image:
                return self.extract_handwriting(image)
        except Exception as e:
            logger.exception("Error processing image bytes: %s", e)
            return ""