from transformers import AutoModel, AutoTokenizer
from typing import Dict, List, Any, Optional
import torch
import base64
from io import BytesIO
from PIL import Image
import os


class EndpointHandler:
    """Inference-endpoint handler that runs DeepSeek-OCR on base64-encoded images.

    The model and tokenizer are loaded once at start-up; each call converts a
    single document image to Markdown via the model's ``infer`` API.
    """

    def __init__(self, model_dir: str = 'deepseek-ai/DeepSeek-OCR'):
        """Load the tokenizer and model.

        Args:
            model_dir: Either a local directory containing the model files or
                a Hugging Face Hub model id (default: ``deepseek-ai/DeepSeek-OCR``).
        """
        model_path = model_dir
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            # Only restrict to local files when an actual directory was given;
            # the previous bool(model_dir) was always True (the default is a
            # hub id string), which broke downloads from the Hub.
            local_files_only=os.path.isdir(model_path),
        )

        # Prefer GPU when available; bfloat16 is only used on CUDA.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        model_kwargs = {
            'trust_remote_code': True,
            'torch_dtype': torch.bfloat16 if self.device == 'cuda' else torch.float32,
        }

        # Enable flash attention only when the package is actually importable.
        # (The old code wrapped a plain dict assignment in try/except, which
        # can never fail — from_pretrained would then crash at load time if
        # flash_attn was missing.)
        if self.device == 'cuda':
            try:
                import flash_attn  # noqa: F401
                model_kwargs['_attn_implementation'] = 'flash_attention_2'
            except ImportError:
                pass  # fall back to the default attention implementation

        self.model = AutoModel.from_pretrained(model_path, **model_kwargs)
        self.model = self.model.eval()

        # Move to GPU when available.
        if self.device == 'cuda':
            self.model = self.model.cuda()

    def __call__(self, data: Dict[str, Any]) -> Optional[str]:
        """Run OCR on one request payload.

        Args:
            data: Request dict. ``data['inputs']['base64']`` must hold a
                base64-encoded image (a ``data:`` URL prefix is tolerated).
                An optional top-level ``output_path`` key, when provided,
                makes the model save its result files to that path.

        Returns:
            The Markdown produced by the model, or ``None`` on failure.
        """
        try:
            inputs = data.get("inputs")
            base64_string = inputs["base64"]
            # BUGFIX: output_path was referenced below without ever being
            # defined (NameError on every call, silently swallowed by the
            # broad except). Take it from the request, defaulting to no save.
            output_path = data.get("output_path")

            # Strip a data-URL prefix such as "data:image/png;base64," if present.
            if ',' in base64_string:
                base64_string = base64_string.split(',')[1]

            # Decode base64 to a PIL image.
            image_data = base64.b64decode(base64_string)
            image = Image.open(BytesIO(image_data))

            # Normalize to RGB (handles paletted PNG, RGBA, grayscale, ...).
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Prompt that asks the model for a Markdown conversion.
            prompt = "\n<|grounding|>Convert the document to markdown."

            result = self.model.infer(
                self.tokenizer,
                prompt=prompt,
                image_file=image,  # presumably infer accepts a PIL image here — TODO confirm against the model card
                output_path=output_path,
                base_size=1024,
                image_size=640,
                crop_mode=True,
                save_results=output_path is not None,
            )
            return result

        except Exception as e:
            # Endpoint boundary: report the failure and return None rather
            # than letting the serving stack see an unhandled exception.
            print(f"Error processing image: {e}")
            return None