# Scraped from a Hugging Face Space page (header read "Spaces: Sleeping").
| import os | |
| # IMPORTANT: Set OpenMP/MKL threads BEFORE importing torch/numpy | |
| # This must be done first to avoid threading conflicts | |
| os.environ['OMP_NUM_THREADS'] = '1' | |
| os.environ['MKL_NUM_THREADS'] = '1' | |
| os.environ['OPENBLAS_NUM_THREADS'] = '1' | |
| os.environ['NUMEXPR_NUM_THREADS'] = '1' | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from typing import Dict, Any, List | |
| from transformers import AutoModel, AutoTokenizer | |
| import torch | |
| from PIL import Image | |
| import io | |
| import base64 | |
app = FastAPI()

# Global variables for model
# (populated once by load_model(); route handlers read these to report readiness)
model = None
tokenizer = None
async def load_model():
    """Load the DeepSeek-OCR model and tokenizer into the module globals.

    Chooses flash_attention_2 when the GPU is Ampere or newer (compute
    capability >= 8.0), falling back to standard attention otherwise or when
    the flash-attn load fails. On CUDA machines the model is moved to the GPU
    in bfloat16; otherwise it stays on CPU (slow).

    NOTE(review): no startup hook is visible in this chunk — confirm this is
    registered with the app (e.g. @app.on_event("startup")).

    Raises:
        Exception: re-raises any error encountered while loading.
    """
    global model, tokenizer

    def _load_standard():
        # Shared fallback path (was duplicated in two branches): plain attention.
        print("Loading with standard attention (slower but more compatible)...")
        loaded = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_safetensors=True
        )
        print("✓ Loaded with standard attention")
        return loaded

    try:
        model_name = 'deepseek-ai/DeepSeek-OCR'
        print(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Flash Attention 2 requires Ampere (compute capability 8.0) or newer.
        use_flash_attention = False
        if torch.cuda.is_available():
            compute_capability = torch.cuda.get_device_capability()
            print(f"GPU Compute Capability: {compute_capability}")
            if compute_capability[0] >= 8:
                use_flash_attention = True
                print("GPU supports Flash Attention 2.0")
            else:
                print(f"GPU does not support Flash Attention 2.0 (requires compute capability >= 8.0, got {compute_capability[0]}.{compute_capability[1]})")
        # Load model with the appropriate attention implementation.
        if use_flash_attention:
            try:
                print("Loading with flash_attention_2...")
                model = AutoModel.from_pretrained(
                    model_name,
                    attn_implementation='flash_attention_2',
                    trust_remote_code=True,
                    use_safetensors=True
                )
                print("✓ Loaded with flash_attention_2")
            except Exception as e:
                print(f"Could not load with flash_attention_2: {e}")
                model = _load_standard()
        else:
            model = _load_standard()
        # Move to GPU in bfloat16 when available.
        if torch.cuda.is_available():
            model = model.eval().cuda().to(torch.bfloat16)
            print(f"✓ Model loaded on GPU: {torch.cuda.get_device_name(0)}")
        else:
            model = model.eval()
            print("⚠ Model loaded on CPU (will be slow)")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
class ImageRequest(BaseModel):
    """Request payload for the OCR extraction endpoints."""
    image: str  # Base64 encoded image
    prompt: str = "<image>\n<|grounding|>Convert the document to markdown. "
    base_size: int = 1024  # model coordinate-space size; used to scale bounding boxes
    image_size: int = 640
    crop_mode: bool = True
    test_compress: bool = True  # Enable compression/optimization (recommended: True per official docs)
    layout_only: bool = False  # If True, only detect layout without detailed content extraction
def home():
    """Landing payload: service readiness plus basic hardware info."""
    # NOTE(review): no route decorator visible in this chunk — confirm
    # @app.get("/") registration elsewhere.
    has_gpu = torch.cuda.is_available()
    payload = {
        "message": "DeepSeek-OCR Image Extraction API",
        "status": "loading" if model is None else "ready",
        "gpu_available": has_gpu,
    }
    payload["device"] = str(torch.cuda.get_device_name(0)) if has_gpu else "CPU"
    return payload
def health():
    """Health-check payload reporting model readiness and GPU presence."""
    # NOTE(review): no route decorator visible in this chunk — confirm
    # @app.get("/health") registration elsewhere.
    loaded = model is not None
    return {
        "status": "healthy" if loaded else "loading",
        "model_loaded": loaded,
        "gpu_available": torch.cuda.is_available(),
    }
async def extract_image(request: ImageRequest):
    """Extract text and bounding boxes from an image using DeepSeek-OCR.

    Decodes the base64 image, runs model.infer() while capturing stdout (the
    model prints its raw grounding-tagged output there), picks the best text
    source, parses grounded extractions, crops per-type patches, and builds
    the JSON response (full or layout-only).

    Fixes vs. previous revision:
    - the tempfile.mkdtemp() results directory is now removed in the finally
      block (it leaked one directory per request);
    - HTTPExceptions raised deliberately (400 bad input) are re-raised instead
      of being swallowed by the generic handler and converted to 500.

    Raises:
        HTTPException: 503 while the model is loading, 400 for bad input,
            500 for unexpected inference failures.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is still loading. Please try again in a moment.")
    try:
        if not request.image:
            raise HTTPException(status_code=400, detail="No image provided")
        # Decode base64 image
        try:
            image_bytes = base64.b64decode(request.image)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid base64 image: {e}")
        # Save to temporary file (DeepSeek-OCR expects a file path)
        import tempfile
        import shutil
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            tmp_file.write(image_bytes)
            tmp_file_path = tmp_file.name
        output_path = None  # created below; tracked so the finally block can remove it
        try:
            # Get image dimensions
            image = Image.open(io.BytesIO(image_bytes))
            img_width, img_height = image.size
            print(f"Processing image: {img_width}x{img_height}")
            # Run inference with grounding to get bounding boxes
            output_path = tempfile.mkdtemp()
            # Use simpler prompt for layout-only mode
            prompt = request.prompt
            if request.layout_only:
                # NOTE(review): this prompt string looks garbled ("...boxes.o ") —
                # confirm the intended wording before changing it.
                prompt = "<image>\n<Identify all objects, table, diagrams, and text and output them in bounding boxes.o "
                print("Using layout-only mode with structured bounding boxes")
            # Capture stdout to get the raw model output with grounding tags
            import sys
            from io import StringIO
            old_stdout = sys.stdout
            sys.stdout = captured_output = StringIO()
            try:
                # Call model.infer with parameters matching official documentation
                result = model.infer(
                    tokenizer,
                    prompt=prompt,
                    image_file=tmp_file_path,
                    output_path=output_path,
                    base_size=request.base_size,
                    image_size=request.image_size,
                    crop_mode=request.crop_mode,
                    save_results=True,
                    test_compress=request.test_compress
                )
            finally:
                # Restore stdout
                sys.stdout = old_stdout
            # Captured output contains the raw grounding tags
            raw_model_output = captured_output.getvalue()
            print(f"Extraction complete. Result type: {type(result)}")
            print(f"Result value: {result if result and len(str(result)) < 200 else (str(result)[:200] + '...' if result else 'None')}")
            print(f"Output path: {output_path}")
            print(f"Captured {len(raw_model_output)} characters from model output")
            if raw_model_output:
                print(f"Raw output preview: {raw_model_output[:500]}")
            # Read the result from files saved by save_results=True
            result_text, result_image_with_boxes, image_patches = read_saved_results(output_path, tmp_file_path)
            # Pick the best text source; always prefer content WITH grounding tags.
            if '<|ref|>' in raw_model_output and '<|det|>' in raw_model_output:
                print(f"✓ Using raw stdout output - contains grounding tags ({len(raw_model_output)} chars)")
                result_text = raw_model_output
            elif result and isinstance(result, str) and '<|ref|>' in result and '<|det|>' in result:
                print(f"✓ Using direct model return - contains grounding tags ({len(result)} chars)")
                result_text = result
            elif result_text and '<|ref|>' in result_text and '<|det|>' in result_text:
                # result_text already set from the saved file
                print(f"✓ Using saved file - contains grounding tags ({len(result_text)} chars)")
            elif result_text and len(result_text.strip()) > 50:
                print(f"⚠ Using saved file WITHOUT grounding tags ({len(result_text)} chars) - bounding boxes won't be available")
            elif result and isinstance(result, str) and len(result.strip()) > 50:
                print(f"⚠ Using direct model return WITHOUT grounding tags ({len(result)} chars)")
                result_text = result
            elif raw_model_output and len(raw_model_output.strip()) > 50:
                print(f"⚠ Using raw stdout WITHOUT grounding tags ({len(raw_model_output)} chars)")
                result_text = raw_model_output
            else:
                print("❌ WARNING: No usable output found from any source")
                result_text = result_text or ""
            print(f"Result preview: {result_text if result_text else 'No results found'}")
            print(f"Result image with boxes: {'Found' if result_image_with_boxes else 'Not found'}")
            print(f"Image patches: {len(image_patches)} patches found")
            # Parse with base_size so coordinates scale to the original image.
            extractions = parse_deepseek_result(result_text, img_width, img_height, request.base_size)
            print(f"✓ Parsed {len(extractions)} extractions from result")
            if extractions:
                # Summary of extraction counts by type, for the logs.
                types_summary = {}
                for ext in extractions:
                    ext_type = ext.get('type', 'unknown')
                    types_summary[ext_type] = types_summary.get(ext_type, 0) + 1
                print(f" Extraction types: {types_summary}")
            else:
                print(" ⚠ WARNING: No extractions parsed - check if result has grounding tags")
            # Layout-only mode returns simplified elements instead of full content.
            if request.layout_only:
                layout_extractions = simplify_extractions_for_layout(extractions)
                print(f"Layout-only mode: Simplified {len(extractions)} extractions")
            else:
                layout_extractions = None
            # Crop from the annotated image when available: box coordinates are
            # relative to the processed image, not the original upload.
            image_for_cropping = image_bytes
            if result_image_with_boxes:
                try:
                    annotated_image_bytes = base64.b64decode(result_image_with_boxes)
                    # Verify it decodes to a valid image before committing to it.
                    test_img = Image.open(io.BytesIO(annotated_image_bytes))
                    img_for_crop_width, img_for_crop_height = test_img.size
                    print(f"Using annotated image for cropping: {img_for_crop_width}x{img_for_crop_height} (original: {img_width}x{img_height})")
                    image_for_cropping = annotated_image_bytes
                    # Re-parse for the annotated image dimensions, adding 200px
                    # padding around each box to avoid cutoff.
                    extractions = parse_deepseek_result(
                        result_text,
                        img_for_crop_width,
                        img_for_crop_height,
                        request.base_size,
                        scale_coords=True,
                        padding=200
                    )
                    print(f"✓ Re-parsed {len(extractions)} extractions with 200px padding for annotated image")
                except Exception as e:
                    print(f"⚠ Could not use annotated image for cropping: {e}, falling back to original")
            patches_by_type = extract_patches_by_type(image_for_cropping, extractions)
            # Strip the grounding tags to produce plain text.
            import re
            clean_text = re.sub(r'<\|ref\|>.*?<\|/ref\|><\|det\|>\[\[[\d, ]+\]\]<\|/det\|>\n?', '', result_text)
            # Simplified list of boxes for easy client-side drawing.
            bounding_boxes = [
                {
                    "type": ext["type"],
                    "x1": ext["bbox"]["x1"],
                    "y1": ext["bbox"]["y1"],
                    "x2": ext["bbox"]["x2"],
                    "y2": ext["bbox"]["y2"]
                }
                for ext in extractions
            ]
            print(f"Extracted patches - tables: {len(patches_by_type['table'])}, text: {len(patches_by_type['text'])}, images: {len(patches_by_type['image'])}, other: {len(patches_by_type['other'])}")
            response_data = {
                "document_type": "image",
                "image_dimensions": {
                    "width": img_width,
                    "height": img_height
                },
                "layout_only_mode": request.layout_only,
                "bounding_boxes": bounding_boxes,  # Simplified list for drawing
                "num_extractions": len(extractions),
                # Counts
                "num_tables": len(patches_by_type["table"]),
                "num_texts": len(patches_by_type["text"]),
                "num_images_extracted": len(patches_by_type["image"])
            }
            # Add layout-only or full extractions based on mode
            if request.layout_only:
                response_data["layout_elements"] = layout_extractions
                # Structured summary for easy parsing.
                response_data["layout_summary"] = {
                    "total_elements": len(layout_extractions),
                    "elements_by_type": {
                        "tables": [elem for elem in layout_extractions if elem["type"] == "table"],
                        "text_blocks": [elem for elem in layout_extractions if elem["type"] == "text"],
                        "images": [elem for elem in layout_extractions if elem["type"] == "image"],
                        "other": [elem for elem in layout_extractions if elem["type"] not in ["table", "text", "image"]]
                    },
                    "counts": {
                        "tables": len([e for e in layout_extractions if e["type"] == "table"]),
                        "text_blocks": len([e for e in layout_extractions if e["type"] == "text"]),
                        "images": len([e for e in layout_extractions if e["type"] == "image"]),
                        "other": len([e for e in layout_extractions if e["type"] not in ["table", "text", "image"]])
                    }
                }
                # Still include patches but without full content in extractions
                response_data["table_patches"] = patches_by_type["table"]
                response_data["text_patches"] = patches_by_type["text"]
                response_data["image_patches_extracted"] = patches_by_type["image"]
                response_data["other_patches"] = patches_by_type["other"]
            else:
                response_data["raw_result"] = result_text  # Full raw output with tags
                response_data["raw_text"] = clean_text.strip()  # Clean text without tags
                response_data["extractions"] = extractions  # Full extractions with text and bboxes
                # Patches organized by type
                response_data["table_patches"] = patches_by_type["table"]
                response_data["text_patches"] = patches_by_type["text"]
                response_data["image_patches_extracted"] = patches_by_type["image"]
                response_data["other_patches"] = patches_by_type["other"]
            # Add result image with bounding boxes if available
            if result_image_with_boxes:
                response_data["result_image_with_boxes"] = result_image_with_boxes
            # Add model's processed image patches if available
            if image_patches:
                response_data["model_image_patches"] = image_patches
                response_data["num_model_patches"] = len(image_patches)
            return response_data
        finally:
            # Clean up temporary file
            if os.path.exists(tmp_file_path):
                os.unlink(tmp_file_path)
            # Also remove the mkdtemp() results directory (previously leaked).
            if output_path and os.path.isdir(output_path):
                shutil.rmtree(output_path, ignore_errors=True)
    except HTTPException:
        # Propagate deliberate HTTP errors (400/503) instead of rewrapping as 500.
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in extract_image: {error_details}")
        raise HTTPException(status_code=500, detail=str(e))
def read_saved_results(output_path: str, image_file: str) -> tuple:
    """
    Collect DeepSeek-OCR artifacts saved under ``output_path``.

    Looks for a .mmd markdown file (falling back to .txt), the annotated
    ``result_with_boxes.jpg`` image, and any patch images in ``images/``.
    ``image_file`` is accepted for interface compatibility but unused.
    Returns: (text_content, result_image_base64, image_patches_base64_list)
    """
    import glob

    def _encode_file(path):
        # Base64-encode a binary file's contents as UTF-8 text.
        with open(path, 'rb') as fh:
            return base64.b64encode(fh.read()).decode('utf-8')

    print(f"Looking for results in: {output_path}")
    all_files = glob.glob(os.path.join(output_path, "*"))
    print(f"All files in output_path: {all_files}")

    text_content = ""
    result_image_base64 = None
    image_patches = []

    # Prefer .mmd (markdown) output; fall back to .txt.
    mmd_files = glob.glob(os.path.join(output_path, "*.mmd"))
    print(f"Found {len(mmd_files)} .mmd files: {mmd_files}")
    if mmd_files:
        with open(mmd_files[0], 'r', encoding='utf-8') as fh:
            text_content = fh.read()
        print(f"Successfully read {len(text_content)} characters from {mmd_files[0]}")
    else:
        txt_files = glob.glob(os.path.join(output_path, "*.txt"))
        if txt_files:
            with open(txt_files[0], 'r', encoding='utf-8') as fh:
                text_content = fh.read()
        else:
            print("Warning: No .mmd or .txt files found in output directory")

    # Annotated result image, if the model wrote one.
    boxes_path = os.path.join(output_path, "result_with_boxes.jpg")
    if os.path.exists(boxes_path):
        try:
            result_image_base64 = _encode_file(boxes_path)
            print(f"Successfully read result_with_boxes.jpg")
        except Exception as e:
            print(f"Error reading result_with_boxes.jpg: {e}")

    # Patch images written to the images/ subdirectory.
    images_dir = os.path.join(output_path, "images")
    if os.path.isdir(images_dir):
        try:
            patch_files = sorted(glob.glob(os.path.join(images_dir, "*.[jp][pn]g")) +
                                 glob.glob(os.path.join(images_dir, "*.jpeg")))
            print(f"Found {len(patch_files)} image patches in images/ directory")
            for patch_path in patch_files:
                try:
                    image_patches.append({
                        "filename": os.path.basename(patch_path),
                        "data": _encode_file(patch_path)
                    })
                except Exception as e:
                    print(f"Error reading {patch_path}: {e}")
        except Exception as e:
            print(f"Error reading images directory: {e}")

    return text_content, result_image_base64, image_patches
def extract_patches_by_type(image_bytes: bytes, extractions: List[Dict]) -> Dict[str, List[Dict]]:
    """
    Crop one image patch per extraction bounding box and bucket the patches
    by element type. Unknown types land in "other"; invalid (zero-area)
    boxes are skipped.
    """
    from PIL import Image
    import io
    patches_by_type = {"table": [], "text": [], "image": [], "other": []}
    if not extractions:
        print("⚠ extract_patches_by_type: No extractions provided, returning empty patches")
        return patches_by_type
    print(f"→ Extracting {len(extractions)} patches from image...")
    try:
        source_image = Image.open(io.BytesIO(image_bytes))
        print(f" Image size: {source_image.size}")
        for position, item in enumerate(extractions):
            box = item["bbox"]
            kind = item["type"]
            # Zero/negative extent means the bbox was unparseable upstream.
            if box["width"] <= 0 or box["height"] <= 0:
                print(f"Skipping patch {position} ({kind}): invalid bbox with width={box['width']}, height={box['height']}")
                continue
            try:
                crop = source_image.crop((box["x1"], box["y1"], box["x2"], box["y2"]))
                out = io.BytesIO()
                crop.save(out, format='PNG')
                encoded = base64.b64encode(out.getvalue()).decode('utf-8')
                snippet = item["text"]
                if len(snippet) > 100:
                    snippet = snippet[:100]
                entry = {
                    "index": position,
                    "type": kind,
                    "bbox": box,
                    "text_preview": snippet,
                    "data": encoded
                }
                bucket = kind if kind in patches_by_type else "other"
                patches_by_type[bucket].append(entry)
                # Log the first few successes only, to keep output readable.
                if position < 3:
                    print(f" ✓ Extracted patch {position}: {kind} at ({box['x1']},{box['y1']})-({box['x2']},{box['y2']})")
            except Exception as e:
                print(f" ✗ Error cropping patch {position} ({kind}): {e}")
                continue
        total = sum(len(bucketed) for bucketed in patches_by_type.values())
        print(f"✓ Successfully extracted {total} patches total")
    except Exception as e:
        print(f"✗ Error extracting patches: {e}")
    return patches_by_type
def simplify_extractions_for_layout(extractions: List[Dict]) -> List[Dict]:
    """
    Reduce full extractions to lightweight layout descriptors.

    Every element gets the same shape: id, type, normalized bbox
    (x1/y1/x2/y2/width/height), position (corners + center), dimensions
    (width/height/area), a short content_preview and a human-readable
    description. Width/height fall back to x2-x1 / y2-y1 when absent.
    """
    simplified = []
    for idx, ext in enumerate(extractions):
        raw = ext["bbox"]
        x1 = raw.get("x1", 0)
        y1 = raw.get("y1", 0)
        x2 = raw.get("x2", 0)
        y2 = raw.get("y2", 0)
        box = {
            "x1": x1,
            "y1": y1,
            "x2": x2,
            "y2": y2,
            "width": raw.get("width", x2 - x1),
            "height": raw.get("height", y2 - y1)
        }
        size_note = f"{box['width']}×{box['height']}px"
        kind = ext["type"]
        if kind == "table":
            # Tables: omit the complex HTML, just flag the element.
            preview = "Table"
            description = f"Table element ({size_note})"
        elif kind == "text":
            # Text: keep the first 50 characters.
            clipped = ext["text"][:50].strip()
            preview = f"{clipped}..." if len(ext["text"]) > 50 else clipped
            description = f"Text block ({size_note})"
        elif kind == "image":
            preview = "Image"
            description = f"Image element ({size_note})"
        else:
            preview = ext["text"][:50] if ext["text"] else kind
            description = f"{kind.capitalize()} element ({size_note})"
        simplified.append({
            "id": idx,
            "type": kind,
            "bbox": box,
            "position": {
                "top_left": {"x": x1, "y": y1},
                "bottom_right": {"x": x2, "y": y2},
                "center": {
                    "x": (x1 + x2) // 2,
                    "y": (y1 + y2) // 2
                }
            },
            "dimensions": {
                "width": box["width"],
                "height": box["height"],
                "area": box["width"] * box["height"]
            },
            "content_preview": preview,
            "description": description
        })
    return simplified
def parse_deepseek_result(result: Any, img_width: int, img_height: int, base_size: int = 1024, scale_coords: bool = True, padding: int = 0) -> List[Dict]:
    """
    Parse DeepSeek-OCR grounding output into typed text/bbox extractions.

    The model emits ``<|ref|>TYPE<|/ref|><|det|>[[x1, y1, x2, y2]]<|/det|>``
    followed by that element's content. Coordinates live in a square
    base_size x base_size space and are scaled to (img_width, img_height)
    when ``scale_coords`` is True; ``padding`` grows each scaled box before
    it is clamped to the image bounds. Non-string input yields [].
    """
    import re
    extractions = []
    if not isinstance(result, str):
        return extractions
    # Scale factors from the model's square coordinate space to the target image.
    if scale_coords:
        scale_x = img_width / base_size
        scale_y = img_height / base_size
        print(f"Image dimensions: {img_width}x{img_height}, base_size: {base_size}")
        print(f"Coordinate space: {base_size}x{base_size}, scale_x: {scale_x:.2f}, scale_y: {scale_y:.2f}")
    else:
        scale_x = scale_y = 1.0
        print(f"Using coordinates as-is (no scaling) for image: {img_width}x{img_height}")
    pattern = r'<\|ref\|>(.*?)<\|/ref\|><\|det\|>\[\[([\d, ]+)\]\]<\|/det\|>'
    matches = list(re.finditer(pattern, result))
    for i, match in enumerate(matches):
        ref_type = match.group(1)  # text, table, image, ...
        bbox_str = match.group(2)  # "x1, y1, x2, y2"
        try:
            coords = [int(part.strip()) for part in bbox_str.split(',')]
            if len(coords) == 4:
                x1, y1, x2, y2 = coords
                left = int(x1 * scale_x)
                top = int(y1 * scale_y)
                right = int(x2 * scale_x)
                bottom = int(y2 * scale_y)
                # Grow the box before clamping, so padded boxes stay in bounds.
                if padding > 0:
                    pre_pad = (left, top, right, bottom)
                    left -= padding
                    top -= padding
                    right += padding
                    bottom += padding
                    if i == 0:
                        print(f" Padding applied: {padding}px around boxes (e.g., box 0: {pre_pad[0]},{pre_pad[1]},{pre_pad[2]},{pre_pad[3]} -> {left},{top},{right},{bottom})")
                left = max(0, min(left, img_width))
                top = max(0, min(top, img_height))
                right = max(0, min(right, img_width))
                bottom = max(0, min(bottom, img_height))
                bbox = {
                    "x1": left,
                    "y1": top,
                    "x2": right,
                    "y2": bottom,
                    "width": right - left,
                    "height": bottom - top,
                    "original_coords": {"x1": x1, "y1": y1, "x2": x2, "y2": y2}  # pre-scale values, for debugging
                }
            else:
                bbox = {"x1": 0, "y1": 0, "x2": 0, "y2": 0, "width": 0, "height": 0}
        except Exception as e:
            print(f"Error parsing bounding box: {e} for bounding box: {bbox_str} for type {ref_type}")
            bbox = {"x1": 0, "y1": 0, "x2": 0, "y2": 0, "width": 0, "height": 0}
        # The element's content runs from the end of this tag to the next tag.
        content_start = match.end()
        content_end = matches[i + 1].start() if i + 1 < len(matches) else len(result)
        content = result[content_start:content_end].strip()
        if content and content not in ['\n', '\n\n', '**']:
            extractions.append({
                "type": ref_type,
                "text": content,
                "bbox": bbox
            })
    return extractions
async def extract_simple(request: ImageRequest):
    """
    Simplified endpoint that returns the raw DeepSeek-OCR output
    for inspection and format understanding.

    Fixes vs. previous revision (mirrors extract_image):
    - invalid base64 now yields 400 instead of a generic 500;
    - the mkdtemp() results directory is removed in the finally block
      (it leaked one directory per request);
    - deliberate HTTPExceptions are re-raised, not rewrapped as 500.

    Raises:
        HTTPException: 503 while loading, 400 for bad input, 500 otherwise.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is still loading")
    try:
        if not request.image:
            raise HTTPException(status_code=400, detail="No image provided")
        try:
            image_bytes = base64.b64decode(request.image)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid base64 image: {e}")
        import tempfile
        import shutil
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            tmp_file.write(image_bytes)
            tmp_file_path = tmp_file.name
        output_path = None  # tracked so the finally block can remove it
        try:
            output_path = tempfile.mkdtemp()
            # Capture stdout to get the raw model output with grounding tags
            import sys
            from io import StringIO
            old_stdout = sys.stdout
            sys.stdout = captured_output = StringIO()
            try:
                result = model.infer(
                    tokenizer,
                    prompt=request.prompt,
                    image_file=tmp_file_path,
                    output_path=output_path,
                    base_size=request.base_size,
                    image_size=request.image_size,
                    crop_mode=request.crop_mode,
                    save_results=True,
                    test_compress=request.test_compress
                )
            finally:
                sys.stdout = old_stdout
            raw_model_output = captured_output.getvalue()
            # Get image dimensions
            image = Image.open(io.BytesIO(image_bytes))
            img_width, img_height = image.size
            # Read the result from saved files
            result_text, result_image_with_boxes, image_patches = read_saved_results(output_path, tmp_file_path)
            # Use raw model output if it contains grounding tags
            if '<|ref|>' in raw_model_output and '<|det|>' in raw_model_output:
                result_text = raw_model_output
            # Crop from the annotated image when available (its dimensions match
            # the detected box coordinate space).
            image_for_cropping = image_bytes
            if result_image_with_boxes:
                try:
                    annotated_image_bytes = base64.b64decode(result_image_with_boxes)
                    test_img = Image.open(io.BytesIO(annotated_image_bytes))
                    img_for_crop_width, img_for_crop_height = test_img.size
                    print(f"Using annotated image for cropping: {img_for_crop_width}x{img_for_crop_height}")
                    image_for_cropping = annotated_image_bytes
                    # Re-parse with annotated image dimensions and 200px padding
                    extractions = parse_deepseek_result(
                        result_text,
                        img_for_crop_width,
                        img_for_crop_height,
                        request.base_size,
                        scale_coords=True,
                        padding=200
                    )
                except Exception as e:
                    print(f"Could not use annotated image: {e}")
                    extractions = parse_deepseek_result(result_text, img_width, img_height, request.base_size, padding=200)
            else:
                extractions = parse_deepseek_result(result_text, img_width, img_height, request.base_size, padding=200)
            patches_by_type = extract_patches_by_type(image_for_cropping, extractions)
            response = {
                "result_type": str(type(result)),
                "result": result_text[:5000] if result_text else "No results found",
                "full_result": result_text,
                "output_path": output_path,  # informational only; directory is removed in the finally below
                "num_extractions": len(extractions),
                "num_tables": len(patches_by_type["table"]),
                "num_texts": len(patches_by_type["text"]),
                "num_images": len(patches_by_type["image"])
            }
            # Add images if available
            if result_image_with_boxes:
                response["result_image_with_boxes"] = result_image_with_boxes
            if image_patches:
                response["model_image_patches"] = image_patches
                response["num_model_patches"] = len(image_patches)
            # Add patches by type
            response["table_patches"] = patches_by_type["table"]
            response["text_patches"] = patches_by_type["text"]
            response["image_patches_extracted"] = patches_by_type["image"]
            response["other_patches"] = patches_by_type["other"]
            return response
        finally:
            if os.path.exists(tmp_file_path):
                os.unlink(tmp_file_path)
            # Remove the mkdtemp() results directory (previously leaked).
            if output_path and os.path.isdir(output_path):
                shutil.rmtree(output_path, ignore_errors=True)
    except HTTPException:
        # Propagate deliberate HTTP errors (400/503) instead of rewrapping as 500.
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error: {error_details}")
        raise HTTPException(status_code=500, detail=str(e))
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=7860) |