import os

# IMPORTANT: set OpenMP/MKL thread caps BEFORE torch/numpy are imported,
# otherwise the BLAS thread pools are already spun up and conflict.
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Dict, Any, List
from transformers import AutoModel, AutoTokenizer
import torch
from PIL import Image
import io
import base64

app = FastAPI()

# Populated exactly once by the startup hook below; None means "still loading".
model = None
tokenizer = None


@app.on_event("startup")
async def load_model():
    """Load the model on startup.

    Downloads/loads DeepSeek-OCR plus its tokenizer, preferring Flash
    Attention 2 on Ampere-or-newer GPUs and falling back to standard
    attention everywhere else. The loaded objects are stored in the
    module-level ``model`` / ``tokenizer`` globals.
    """
    global model, tokenizer
    try:
        model_name = 'deepseek-ai/DeepSeek-OCR'
        print(f"Loading model: {model_name}")

        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        def _load_standard():
            # Shared fallback path: plain (non-flash) attention.
            print("Loading with standard attention (slower but more compatible)...")
            loaded = AutoModel.from_pretrained(
                model_name,
                trust_remote_code=True,
                use_safetensors=True
            )
            print("✓ Loaded with standard attention")
            return loaded

        # Flash Attention 2 needs an Ampere-or-newer GPU (compute capability >= 8.0).
        use_flash_attention = False
        if torch.cuda.is_available():
            compute_capability = torch.cuda.get_device_capability()
            print(f"GPU Compute Capability: {compute_capability}")
            if compute_capability[0] >= 8:
                use_flash_attention = True
                print("GPU supports Flash Attention 2.0")
            else:
                print(f"GPU does not support Flash Attention 2.0 (requires compute capability >= 8.0, got {compute_capability[0]}.{compute_capability[1]})")

        if use_flash_attention:
            try:
                print("Loading with flash_attention_2...")
                model = AutoModel.from_pretrained(
                    model_name,
                    attn_implementation='flash_attention_2',
                    trust_remote_code=True,
                    use_safetensors=True
                )
                print("✓ Loaded with flash_attention_2")
            except Exception as e:
                # Flash attention may fail at load time even on capable GPUs
                # (missing wheel, version mismatch) — degrade gracefully.
                print(f"Could not load with flash_attention_2: {e}")
                model = _load_standard()
        else:
            model = _load_standard()

        # Move to GPU (bfloat16) when available; CPU works but is very slow.
        if torch.cuda.is_available():
            model = model.eval().cuda().to(torch.bfloat16)
            print(f"✓ Model loaded on GPU: {torch.cuda.get_device_name(0)}")
        else:
            model = model.eval()
            print("⚠ Model loaded on CPU (will be slow)")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise


class ImageRequest(BaseModel):
    # Base64 encoded image payload.
    image: str
    # Default grounding prompt: markdown conversion with <|ref|>/<|det|> tags.
    prompt: str = "\n<|grounding|>Convert the document to markdown. "
    base_size: int = 1024
    image_size: int = 640
    crop_mode: bool = True
    # Enable compression/optimization (recommended: True per official docs).
    test_compress: bool = True
    # If True, only detect layout without detailed content extraction.
    layout_only: bool = False


@app.get("/")
def home():
    """Service banner: readiness flag plus the device the model runs on."""
    return {
        "message": "DeepSeek-OCR Image Extraction API",
        "status": "ready" if model is not None else "loading",
        "gpu_available": torch.cuda.is_available(),
        "device": str(torch.cuda.get_device_name(0)) if torch.cuda.is_available() else "CPU"
    }


@app.get("/health")
def health():
    """Lightweight health probe used by orchestrators."""
    return {
        "status": "healthy" if model is not None else "loading",
        "model_loaded": model is not None,
        "gpu_available": torch.cuda.is_available()
    }
@app.post("/extract")
async def extract_image(request: ImageRequest):
    """Extract text and bounding boxes from an image using DeepSeek-OCR.

    Decodes the base64 image, runs ``model.infer`` with a grounding prompt
    while capturing stdout (the model prints the tagged transcript as it
    decodes), then parses the ``<|ref|>``/``<|det|>`` output into typed
    extractions with pixel bounding boxes plus cropped image patches.

    Raises:
        HTTPException 503 while the model is loading, 400 for bad input,
        500 for unexpected failures.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is still loading. Please try again in a moment.")
    try:
        if not request.image:
            raise HTTPException(status_code=400, detail="No image provided")

        # Decode base64 image
        try:
            image_bytes = base64.b64decode(request.image)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid base64 image: {e}")

        # Save to temporary file (DeepSeek-OCR expects a file path)
        import tempfile
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            tmp_file.write(image_bytes)
            tmp_file_path = tmp_file.name

        try:
            # Original dimensions, used later to scale model coordinates.
            image = Image.open(io.BytesIO(image_bytes))
            img_width, img_height = image.size
            print(f"Processing image: {img_width}x{img_height}")

            # Directory where the model writes result.mmd / annotated images.
            output_path = tempfile.mkdtemp()

            # Use simpler prompt for layout-only mode
            prompt = request.prompt
            if request.layout_only:
                # NOTE(review): this span was corrupted in the recovered source;
                # the layout-only grounding prompt below is reconstructed —
                # confirm against the upstream revision.
                prompt = "\n<|grounding|>Locate the layout elements in the image. "

            # Capture stdout during inference: the model prints the grounded
            # transcript (with <|ref|>/<|det|> tags) while it runs, and that
            # transcript is often the only source that keeps the tags.
            import sys
            from io import StringIO
            old_stdout = sys.stdout
            sys.stdout = captured_output = StringIO()
            try:
                result = model.infer(
                    tokenizer,
                    prompt=prompt,
                    image_file=tmp_file_path,
                    output_path=output_path,
                    base_size=request.base_size,
                    image_size=request.image_size,
                    crop_mode=request.crop_mode,
                    save_results=True,
                    test_compress=request.test_compress
                )
            finally:
                sys.stdout = old_stdout
            raw_model_output = captured_output.getvalue()

            # Read whatever the model saved to disk.
            result_text, result_image_with_boxes, image_patches = read_saved_results(output_path, tmp_file_path)

            # Pick the best output source, preferring ones that still carry
            # grounding tags (required to recover bounding boxes).
            # 1. Raw stdout with grounding tags
            if '<|ref|>' in raw_model_output and '<|det|>' in raw_model_output:
                print(f"✓ Using raw stdout output - contains grounding tags ({len(raw_model_output)} chars)")
                result_text = raw_model_output
            # 2. Check if model returned result directly with grounding tags
            elif result and isinstance(result, str) and '<|ref|>' in result and '<|det|>' in result:
                print(f"✓ Using direct model return - contains grounding tags ({len(result)} chars)")
                result_text = result
            # 3. Check if saved file has grounding tags
            elif result_text and '<|ref|>' in result_text and '<|det|>' in result_text:
                print(f"✓ Using saved file - contains grounding tags ({len(result_text)} chars)")
                # result_text already set, no change needed
            # 4. Fallback: Use any available content (without grounding tags)
            elif result_text and len(result_text.strip()) > 50:
                print(f"⚠ Using saved file WITHOUT grounding tags ({len(result_text)} chars) - bounding boxes won't be available")
            elif result and isinstance(result, str) and len(result.strip()) > 50:
                print(f"⚠ Using direct model return WITHOUT grounding tags ({len(result)} chars)")
                result_text = result
            elif raw_model_output and len(raw_model_output.strip()) > 50:
                print(f"⚠ Using raw stdout WITHOUT grounding tags ({len(raw_model_output)} chars)")
                result_text = raw_model_output
            else:
                print("❌ WARNING: No usable output found from any source")
                result_text = result_text or ""

            print(f"Result preview: {result_text if result_text else 'No results found'}")
            print(f"Result image with boxes: {'Found' if result_image_with_boxes else 'Not found'}")
            print(f"Image patches: {len(image_patches)} patches found")

            # Parse the result with base_size for proper coordinate scaling
            extractions = parse_deepseek_result(result_text, img_width, img_height, request.base_size)
            print(f"✓ Parsed {len(extractions)} extractions from result")
            if extractions:
                # Show summary by type
                types_summary = {}
                for ext in extractions:
                    ext_type = ext.get('type', 'unknown')
                    types_summary[ext_type] = types_summary.get(ext_type, 0) + 1
                print(f"  Extraction types: {types_summary}")
            else:
                print("  ⚠ WARNING: No extractions parsed - check if result has grounding tags")

            # If layout_only mode, simplify the extractions
            if request.layout_only:
                layout_extractions = simplify_extractions_for_layout(extractions)
                print(f"Layout-only mode: Simplified {len(extractions)} extractions")
            else:
                layout_extractions = None

            # Extract patches organized by type (table, text, image).
            # IMPORTANT: crop from the annotated image (result_with_boxes) when
            # available, because coordinates are relative to the processed
            # image, not the original upload.
            image_for_cropping = image_bytes
            if result_image_with_boxes:
                try:
                    annotated_image_bytes = base64.b64decode(result_image_with_boxes)
                    # Verify it's a valid image
                    test_img = Image.open(io.BytesIO(annotated_image_bytes))
                    img_for_crop_width, img_for_crop_height = test_img.size
                    print(f"Using annotated image for cropping: {img_for_crop_width}x{img_for_crop_height} (original: {img_width}x{img_height})")
                    image_for_cropping = annotated_image_bytes
                    # Re-parse coordinates for the annotated image dimensions and
                    # add 200px padding around each box to avoid cutoff.
                    extractions = parse_deepseek_result(
                        result_text,
                        img_for_crop_width,
                        img_for_crop_height,
                        request.base_size,
                        scale_coords=True,  # Scale from base_size to annotated image size
                        padding=200  # Add 200px padding around each box
                    )
                    print(f"✓ Re-parsed {len(extractions)} extractions with 200px padding for annotated image")
                except Exception as e:
                    print(f"⚠ Could not use annotated image for cropping: {e}, falling back to original")

            patches_by_type = extract_patches_by_type(image_for_cropping, extractions)

            # Strip the special tags from the raw result to get plain text.
            import re
            clean_text = re.sub(r'<\|ref\|>.*?<\|/ref\|><\|det\|>\[\[[\d, ]+\]\]<\|/det\|>\n?', '', result_text)

            # Simplified list of bounding boxes for easy drawing.
            bounding_boxes = [
                {
                    "type": ext["type"],
                    "x1": ext["bbox"]["x1"],
                    "y1": ext["bbox"]["y1"],
                    "x2": ext["bbox"]["x2"],
                    "y2": ext["bbox"]["y2"]
                }
                for ext in extractions
            ]

            print(f"Extracted patches - tables: {len(patches_by_type['table'])}, text: {len(patches_by_type['text'])}, images: {len(patches_by_type['image'])}, other: {len(patches_by_type['other'])}")

            response_data = {
                "document_type": "image",
                "image_dimensions": {"width": img_width, "height": img_height},
                "layout_only_mode": request.layout_only,
                "bounding_boxes": bounding_boxes,  # Simplified list for drawing
                "num_extractions": len(extractions),
                # Counts
                "num_tables": len(patches_by_type["table"]),
                "num_texts": len(patches_by_type["text"]),
                "num_images_extracted": len(patches_by_type["image"])
            }

            # Add layout-only or full extractions based on mode
            if request.layout_only:
                response_data["layout_elements"] = layout_extractions
                # Add a structured summary for easy parsing
                response_data["layout_summary"] = {
                    "total_elements": len(layout_extractions),
                    "elements_by_type": {
                        "tables": [elem for elem in layout_extractions if elem["type"] == "table"],
                        "text_blocks": [elem for elem in layout_extractions if elem["type"] == "text"],
                        "images": [elem for elem in layout_extractions if elem["type"] == "image"],
                        "other": [elem for elem in layout_extractions if elem["type"] not in ["table", "text", "image"]]
                    },
                    "counts": {
                        "tables": len([e for e in layout_extractions if e["type"] == "table"]),
                        "text_blocks": len([e for e in layout_extractions if e["type"] == "text"]),
                        "images": len([e for e in layout_extractions if e["type"] == "image"]),
                        "other": len([e for e in layout_extractions if e["type"] not in ["table", "text", "image"]])
                    }
                }
                # Still include patches but without full content in extractions
                response_data["table_patches"] = patches_by_type["table"]
                response_data["text_patches"] = patches_by_type["text"]
                response_data["image_patches_extracted"] = patches_by_type["image"]
                response_data["other_patches"] = patches_by_type["other"]
            else:
                response_data["raw_result"] = result_text  # Full raw output with tags
                response_data["raw_text"] = clean_text.strip()  # Clean text without tags
                response_data["extractions"] = extractions  # Full extractions with text and bboxes
                # Patches organized by type
                response_data["table_patches"] = patches_by_type["table"]
                response_data["text_patches"] = patches_by_type["text"]
                response_data["image_patches_extracted"] = patches_by_type["image"]
                response_data["other_patches"] = patches_by_type["other"]

            # Add result image with bounding boxes if available
            if result_image_with_boxes:
                response_data["result_image_with_boxes"] = result_image_with_boxes

            # Add model's processed image patches if available
            if image_patches:
                response_data["model_image_patches"] = image_patches
                response_data["num_model_patches"] = len(image_patches)

            return response_data
        finally:
            # Clean up temporary file
            if os.path.exists(tmp_file_path):
                os.unlink(tmp_file_path)
    except HTTPException:
        # BUGFIX: preserve deliberate 4xx responses — previously the broad
        # handler below converted them into 500s.
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in extract_image: {error_details}")
        raise HTTPException(status_code=500, detail=str(e))


def read_saved_results(output_path: str, image_file: str) -> tuple:
    """
    Read the saved OCR results from the output directory.
    DeepSeek-OCR saves results as .mmd (markdown) files when save_results=True.

    Returns: (text_content, result_image_base64, image_patches_base64_list)
    """
    import glob
    print(f"Looking for results in: {output_path}")
    # Log the directory contents to ease debugging of unexpected layouts.
    all_files = glob.glob(os.path.join(output_path, "*"))
    print(f"All files in output_path: {all_files}")

    text_content = ""
    result_image_base64 = None
    image_patches = []

    # Look for .mmd (markdown) files first
    mmd_files = glob.glob(os.path.join(output_path, "*.mmd"))
    print(f"Found {len(mmd_files)} .mmd files: {mmd_files}")
    if mmd_files:
        # Read the markdown file (usually result.mmd)
        with open(mmd_files[0], 'r', encoding='utf-8') as f:
            text_content = f.read()
        print(f"Successfully read {len(text_content)} characters from {mmd_files[0]}")
    else:
        # Try .txt files as fallback
        txt_files = glob.glob(os.path.join(output_path, "*.txt"))
        if txt_files:
            with open(txt_files[0], 'r', encoding='utf-8') as f:
                text_content = f.read()
        else:
            print("Warning: No .mmd or .txt files found in output directory")

    # Read the result image with bounding boxes
    result_with_boxes = os.path.join(output_path, "result_with_boxes.jpg")
    if os.path.exists(result_with_boxes):
        try:
            with open(result_with_boxes, 'rb') as f:
                result_image_base64 = base64.b64encode(f.read()).decode('utf-8')
            print(f"Successfully read result_with_boxes.jpg")
        except Exception as e:
            print(f"Error reading result_with_boxes.jpg: {e}")

    # Read all image patches from the images directory
    images_dir = os.path.join(output_path, "images")
    if os.path.exists(images_dir) and os.path.isdir(images_dir):
        try:
            # Get all image files in the directory (jpg/png/jpeg)
            image_files = sorted(glob.glob(os.path.join(images_dir, "*.[jp][pn]g")) + glob.glob(os.path.join(images_dir, "*.jpeg")))
            print(f"Found {len(image_files)} image patches in images/ directory")
            for img_file in image_files:
                try:
                    with open(img_file, 'rb') as f:
                        img_base64 = base64.b64encode(f.read()).decode('utf-8')
                    image_patches.append({
                        "filename": os.path.basename(img_file),
                        "data": img_base64
                    })
                except Exception as e:
                    print(f"Error reading {img_file}: {e}")
        except Exception as e:
            print(f"Error reading images directory: {e}")

    return text_content, result_image_base64, image_patches
def extract_patches_by_type(image_bytes: bytes, extractions: List[Dict]) -> Dict[str, List[Dict]]:
    """
    Crop one image patch per extraction bounding box.

    Returns the base64-encoded PNG patches grouped by element type:
    "table", "text", "image", and everything else under "other".
    """
    from PIL import Image
    import io

    patches_by_type = {"table": [], "text": [], "image": [], "other": []}

    if not extractions:
        print("⚠ extract_patches_by_type: No extractions provided, returning empty patches")
        return patches_by_type

    print(f"→ Extracting {len(extractions)} patches from image...")
    try:
        source = Image.open(io.BytesIO(image_bytes))
        print(f"  Image size: {source.size}")

        for idx, extraction in enumerate(extractions):
            bbox = extraction["bbox"]
            ext_type = extraction["type"]

            # Degenerate boxes cannot be cropped — skip them.
            if bbox["width"] <= 0 or bbox["height"] <= 0:
                print(f"Skipping patch {idx} ({ext_type}): invalid bbox with width={bbox['width']}, height={bbox['height']}")
                continue

            try:
                region = source.crop((bbox["x1"], bbox["y1"], bbox["x2"], bbox["y2"]))
                buffer = io.BytesIO()
                region.save(buffer, format='PNG')
                encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')

                patch_data = {
                    "index": idx,
                    "type": ext_type,
                    "bbox": bbox,
                    "text_preview": extraction["text"][:100],
                    "data": encoded
                }
                # Unknown types fall into the catch-all bucket.
                bucket = ext_type if ext_type in patches_by_type else "other"
                patches_by_type[bucket].append(patch_data)

                # Success log for the first few patches only.
                if idx < 3:
                    print(f"  ✓ Extracted patch {idx}: {ext_type} at ({bbox['x1']},{bbox['y1']})-({bbox['x2']},{bbox['y2']})")
            except Exception as e:
                print(f"  ✗ Error cropping patch {idx} ({ext_type}): {e}")
                continue

        total_patches = sum(len(patches) for patches in patches_by_type.values())
        print(f"✓ Successfully extracted {total_patches} patches total")
    except Exception as e:
        print(f"✗ Error extracting patches: {e}")

    return patches_by_type


def simplify_extractions_for_layout(extractions: List[Dict]) -> List[Dict]:
    """
    Reduce full extractions to layout-only elements.

    Every element gets a uniform shape: type, normalized bbox
    (x1/y1/x2/y2/width/height), position (corners + center), dimensions
    (width/height/area), a short content preview and a description.
    """
    simplified = []
    for idx, ext in enumerate(extractions):
        raw = ext["bbox"]
        x1 = raw.get("x1", 0)
        y1 = raw.get("y1", 0)
        x2 = raw.get("x2", 0)
        y2 = raw.get("y2", 0)
        # Width/height fall back to the corner difference if missing.
        bbox = {
            "x1": x1,
            "y1": y1,
            "x2": x2,
            "y2": y2,
            "width": raw.get("width", x2 - x1),
            "height": raw.get("height", y2 - y1)
        }
        size_label = f"{bbox['width']}×{bbox['height']}px"
        kind = ext["type"]

        if kind == "table":
            # Tables: don't echo the complex HTML, just mark the slot.
            preview = "Table"
            description = f"Table element ({size_label})"
        elif kind == "text":
            # Text: keep the first 50 characters.
            snippet = ext["text"][:50].strip()
            preview = f"{snippet}..." if len(ext["text"]) > 50 else snippet
            description = f"Text block ({size_label})"
        elif kind == "image":
            preview = "Image"
            description = f"Image element ({size_label})"
        else:
            preview = ext["text"][:50] if ext["text"] else kind
            description = f"{kind.capitalize()} element ({size_label})"

        simplified.append({
            "id": idx,
            "type": kind,
            "bbox": bbox,
            "position": {
                "top_left": {"x": bbox["x1"], "y": bbox["y1"]},
                "bottom_right": {"x": bbox["x2"], "y": bbox["y2"]},
                "center": {
                    "x": (bbox["x1"] + bbox["x2"]) // 2,
                    "y": (bbox["y1"] + bbox["y2"]) // 2
                }
            },
            "dimensions": {
                "width": bbox["width"],
                "height": bbox["height"],
                "area": bbox["width"] * bbox["height"]
            },
            "content_preview": preview,
            "description": description
        })
    return simplified
def parse_deepseek_result(result: Any, img_width: int, img_height: int, base_size: int = 1024, scale_coords: bool = True, padding: int = 0) -> List[Dict]:
    """
    Parse DeepSeek-OCR output into typed extractions with pixel bounding boxes.

    Expected tag format:
        <|ref|>TYPE<|/ref|><|det|>[[x1, y1, x2, y2]]<|/det|> CONTENT

    Coordinates live in the model's square coordinate space
    (base_size x base_size) regardless of the image aspect ratio, so x and
    y are scaled independently to the target image.

    Args:
        result: The model output text
        img_width: Target image width
        img_height: Target image height
        base_size: Model's coordinate space size (usually 1024)
        scale_coords: Whether to scale coordinates (False if already in target space)
        padding: Pixels to add around each bounding box (while keeping in bounds)

    Returns:
        List of {"type", "text", "bbox"} dicts; matches with empty content
        are skipped.
    """
    import re

    extractions = []
    if not isinstance(result, str):
        return extractions

    if scale_coords:
        scale_x = img_width / base_size
        scale_y = img_height / base_size
        print(f"Image dimensions: {img_width}x{img_height}, base_size: {base_size}")
        print(f"Coordinate space: {base_size}x{base_size}, scale_x: {scale_x:.2f}, scale_y: {scale_y:.2f}")
    else:
        scale_x = scale_y = 1.0
        print(f"Using coordinates as-is (no scaling) for image: {img_width}x{img_height}")

    # Tag pattern: <|ref|>TYPE<|/ref|><|det|>[[x1, y1, x2, y2]]<|/det|>
    tag_re = re.compile(r'<\|ref\|>(.*?)<\|/ref\|><\|det\|>\[\[([\d, ]+)\]\]<\|/det\|>')
    matches = list(tag_re.finditer(result))

    for i, match in enumerate(matches):
        ref_type = match.group(1)  # text, table, image, etc.
        bbox_str = match.group(2)  # "x1, y1, x2, y2"

        try:
            coords = [int(part.strip()) for part in bbox_str.split(',')]
            if len(coords) == 4:
                # Scale into the target image's pixel space.
                left = int(coords[0] * scale_x)
                top = int(coords[1] * scale_y)
                right = int(coords[2] * scale_x)
                bottom = int(coords[3] * scale_y)

                # Expand by the requested padding before clamping.
                if padding > 0:
                    pre_pad = (left, top, right, bottom)
                    left -= padding
                    top -= padding
                    right += padding
                    bottom += padding
                    if i == 0:
                        # One-shot debug line so the padding effect shows in logs.
                        print(f" Padding applied: {padding}px around boxes (e.g., box 0: {pre_pad[0]},{pre_pad[1]},{pre_pad[2]},{pre_pad[3]} -> {left},{top},{right},{bottom})")

                # Keep coordinates within image bounds.
                left = max(0, min(left, img_width))
                top = max(0, min(top, img_height))
                right = max(0, min(right, img_width))
                bottom = max(0, min(bottom, img_height))

                bbox = {
                    "x1": left,
                    "y1": top,
                    "x2": right,
                    "y2": bottom,
                    "width": right - left,
                    "height": bottom - top,
                    # Keep the raw model coordinates for debugging.
                    "original_coords": {"x1": coords[0], "y1": coords[1], "x2": coords[2], "y2": coords[3]}
                }
            else:
                bbox = {"x1": 0, "y1": 0, "x2": 0, "y2": 0, "width": 0, "height": 0}
        except Exception as e:
            print(f"Error parsing bounding box: {e} for bounding box: {bbox_str} for type {ref_type}")
            bbox = {"x1": 0, "y1": 0, "x2": 0, "y2": 0, "width": 0, "height": 0}

        # Content runs from the end of this tag to the start of the next one
        # (or to the end of the string for the last match).
        content_end = matches[i + 1].start() if i + 1 < len(matches) else len(result)
        content = result[match.end():content_end].strip()

        # Drop empty or noise-only content.
        if content and content not in ['\n', '\n\n', '**']:
            extractions.append({
                "type": ref_type,
                "text": content,
                "bbox": bbox
            })

    return extractions
@app.post("/extract_simple")
async def extract_simple(request: ImageRequest):
    """
    Simplified endpoint that returns the raw DeepSeek-OCR output for
    inspection and format understanding.

    Same pipeline as /extract but with a flatter response: raw transcript,
    parsed extractions with 200px-padded boxes, and patches grouped by type.

    Raises:
        HTTPException 503 while the model is loading, 400 for bad input,
        500 for unexpected failures.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is still loading")
    try:
        if not request.image:
            raise HTTPException(status_code=400, detail="No image provided")

        image_bytes = base64.b64decode(request.image)

        # DeepSeek-OCR expects a file path, so persist to a temp file.
        import tempfile
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            tmp_file.write(image_bytes)
            tmp_file_path = tmp_file.name

        try:
            output_path = tempfile.mkdtemp()

            # Capture stdout to get the raw model output with grounding tags.
            import sys
            from io import StringIO
            old_stdout = sys.stdout
            sys.stdout = captured_output = StringIO()
            try:
                result = model.infer(
                    tokenizer,
                    prompt=request.prompt,
                    image_file=tmp_file_path,
                    output_path=output_path,
                    base_size=request.base_size,
                    image_size=request.image_size,
                    crop_mode=request.crop_mode,
                    save_results=True,
                    test_compress=request.test_compress
                )
            finally:
                sys.stdout = old_stdout
            raw_model_output = captured_output.getvalue()

            # Get image dimensions
            image = Image.open(io.BytesIO(image_bytes))
            img_width, img_height = image.size

            # Read the result from saved files
            result_text, result_image_with_boxes, image_patches = read_saved_results(output_path, tmp_file_path)

            # Prefer the raw stdout transcript when it carries grounding tags.
            if '<|ref|>' in raw_model_output and '<|det|>' in raw_model_output:
                result_text = raw_model_output

            # Crop from the annotated image when available: coordinates are
            # relative to the processed image, not the original upload.
            image_for_cropping = image_bytes
            if result_image_with_boxes:
                try:
                    annotated_image_bytes = base64.b64decode(result_image_with_boxes)
                    test_img = Image.open(io.BytesIO(annotated_image_bytes))
                    img_for_crop_width, img_for_crop_height = test_img.size
                    print(f"Using annotated image for cropping: {img_for_crop_width}x{img_for_crop_height}")
                    image_for_cropping = annotated_image_bytes
                    # Re-parse with annotated image dimensions and 200px padding
                    extractions = parse_deepseek_result(
                        result_text,
                        img_for_crop_width,
                        img_for_crop_height,
                        request.base_size,
                        scale_coords=True,
                        padding=200
                    )
                except Exception as e:
                    print(f"Could not use annotated image: {e}")
                    extractions = parse_deepseek_result(result_text, img_width, img_height, request.base_size, padding=200)
            else:
                extractions = parse_deepseek_result(result_text, img_width, img_height, request.base_size, padding=200)

            patches_by_type = extract_patches_by_type(image_for_cropping, extractions)

            response = {
                "result_type": str(type(result)),
                "result": result_text[:5000] if result_text else "No results found",
                "full_result": result_text,
                "output_path": output_path,
                "num_extractions": len(extractions),
                "num_tables": len(patches_by_type["table"]),
                "num_texts": len(patches_by_type["text"]),
                "num_images": len(patches_by_type["image"])
            }

            # Add images if available
            if result_image_with_boxes:
                response["result_image_with_boxes"] = result_image_with_boxes
            if image_patches:
                response["model_image_patches"] = image_patches
                response["num_model_patches"] = len(image_patches)

            # Add patches by type
            response["table_patches"] = patches_by_type["table"]
            response["text_patches"] = patches_by_type["text"]
            response["image_patches_extracted"] = patches_by_type["image"]
            response["other_patches"] = patches_by_type["other"]

            return response
        finally:
            if os.path.exists(tmp_file_path):
                os.unlink(tmp_file_path)
    except HTTPException:
        # BUGFIX: preserve deliberate 4xx responses — previously the broad
        # handler below converted them into 500s.
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error: {error_details}")
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)