from transformers import AutoModel, AutoTokenizer
from typing import Dict, Any, Optional, Union
import torch
import base64
from io import BytesIO
from PIL import Image
import os
import tempfile


class EndpointHandler:
    """Inference-endpoint handler that runs DeepSeek-OCR on a base64-encoded
    image and returns the recognized document as a Markdown string.

    Success path returns ``str`` (the markdown); every failure path returns
    ``{"error": ...}`` so callers can branch on the payload type.
    """

    def __init__(self, model_dir: str = 'deepseek-ai/DeepSeek-OCR'):
        """Load the DeepSeek-OCR tokenizer and model.

        Args:
            model_dir: Local checkpoint directory, or a Hugging Face hub id
                (the default).
        """
        model_path = model_dir
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            # BUGFIX: the original used bool(model_dir), which is True even
            # for the hub-id default and would forbid downloading it. Only
            # force local resolution when model_dir is an existing directory.
            local_files_only=os.path.isdir(model_dir),
        )

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")

        # Load in float32 (avoids dtype conflicts) and explicitly use eager
        # attention so flash-attention is never required.
        model_kwargs = {
            'trust_remote_code': True,
            'torch_dtype': torch.float32,
            '_attn_implementation': 'eager',
        }
        self.model = AutoModel.from_pretrained(model_path, **model_kwargs)
        self.model = self.model.eval()
        if self.device == 'cuda':
            self.model = self.model.cuda()

    @staticmethod
    def _extract_base64(data: Any) -> Optional[str]:
        """Pull the base64 payload out of the several request shapes accepted.

        Returns None when no candidate field is present.
        """
        # BUGFIX: the raw-string case must be checked FIRST. On a str,
        # `"inputs" in data` is a substring test, not a key lookup, so the
        # original's trailing isinstance branch was shadowed.
        if isinstance(data, str):
            return data
        # Case 1: base64 directly under "inputs"
        if "inputs" in data and isinstance(data["inputs"], str):
            return data["inputs"]
        # Case 2: base64 in nested inputs dictionary
        if "inputs" in data and isinstance(data["inputs"], dict):
            return data["inputs"].get("base64")
        # Case 3: direct base64 at root level
        if "base64" in data:
            return data["base64"]
        return None

    def __call__(self, data: Dict[str, Any]) -> Union[str, Dict[str, str]]:
        """Decode the request image, run OCR, and return markdown.

        Args:
            data: Request payload; see ``_extract_base64`` for accepted shapes.

        Returns:
            Markdown text on success, or ``{"error": ...}`` on any failure
            (never raises — this is the endpoint's top-level boundary).
        """
        try:
            base64_string = self._extract_base64(data)
            if not base64_string:
                # BUGFIX: this message was split across physical lines in the
                # original (a syntax error); rejoined into one literal.
                return {"error": "No base64 string found in input data. Available keys: " + str(data.keys())}

            print("Found base64 string, length:", len(base64_string))

            # Strip a data-URL prefix ("data:image/png;base64,...") if present.
            if ',' in base64_string:
                base64_string = base64_string.split(',')[1]
            image_data = base64.b64decode(base64_string)

            # Prompt steering the model toward Markdown output.
            # NOTE(review): DeepSeek-OCR examples usually prefix the prompt
            # with an "<image>" placeholder token — confirm whether infer()
            # injects it or it must be included here.
            prompt = "\n<|grounding|>Convert this document to markdown format using # headers, **bold** for important information, and Markdown table syntax (using | and -) instead of HTML."

            with tempfile.TemporaryDirectory() as temp_dir:
                image_path = os.path.join(temp_dir, "input_image.png")
                with open(image_path, "wb") as f:
                    f.write(image_data)
                print(f"Image saved to: {image_path}")

                # Verify the bytes decode as an image; normalize to RGB.
                try:
                    test_image = Image.open(image_path)
                    if test_image.mode != 'RGB':
                        test_image = test_image.convert('RGB')
                        test_image.save(image_path)  # save converted version
                    print(f"Image verified: {test_image.size}, mode: {test_image.mode}")
                except Exception as img_error:
                    return {"error": f"Invalid image: {str(img_error)}"}

                output_dir = os.path.join(temp_dir, "deepseek_out")
                os.makedirs(output_dir, exist_ok=True)

                # Run OCR inference; results are written into output_dir.
                result = self.model.infer(
                    self.tokenizer,
                    prompt=prompt,
                    image_file=image_path,  # a path on disk, not a PIL object
                    output_path=output_dir,
                    base_size=1024,
                    image_size=640,
                    crop_mode=True,
                    save_results=True,
                )

                # Return the first markdown file the model wrote.
                # sorted() makes the pick deterministic across filesystems.
                for fname in sorted(os.listdir(output_dir)):
                    print("File:\n", fname)
                    if fname.endswith((".md", ".mmd")):
                        md_path = os.path.join(output_dir, fname)
                        with open(md_path, 'r', encoding='utf-8') as f:
                            markdown = f.read()
                        print("Markdown output:\n", markdown)
                        return markdown

                # BUGFIX: the original fell off this loop and implicitly
                # returned None when no markdown file was produced.
                return {"error": "OCR produced no markdown output", "raw": str(result)}
        except Exception as e:
            # Top-level endpoint boundary: report instead of crashing.
            print(f"Error processing image: {e}")
            # Consistent error shape with the other failure paths (the
            # original returned a bare str(e) here).
            return {"error": str(e)}