| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field |
| from typing import List, Dict, Any, Optional |
| from PIL import Image, ImageDraw |
| import io |
| import base64 |
| import torch |
| from transformers import AutoModel, AutoProcessor |
| import numpy as np |
| import logging |
| import time |
| import gc |
|
|
| |
# Route log records at INFO and above to the root handler so startup and
# per-request messages are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="DeepSeek OCR API", version="1.0.0")

# Browser access from any origin, with any method and any header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True makes
# Starlette's CORSMiddleware reflect the caller's Origin for credentialed
# requests, so any site can make credentialed calls. Nothing visible here uses
# cookies/auth; consider allow_credentials=False. Confirm against the pinned
# Starlette version.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Filled in by load_model() at startup; endpoints treat `model is None` as
# "not ready yet". `processor` may legitimately stay None (direct-model fallback).
model: Optional[Any] = None
processor: Optional[Any] = None
|
|
class BoxRegion(BaseModel):
    """A region to OCR, with corners given as fractions of the image size."""

    # Caller-chosen identifier, echoed back with the box's result.
    id: int
    # Corner coordinates, each constrained to [0, 1]: x as a fraction of the
    # image width, y as a fraction of the image height.
    # NOTE(review): the model itself does not require x1 < x2 or y1 < y2.
    x1: float = Field(default=..., ge=0, le=1)
    y1: float = Field(default=..., ge=0, le=1)
    x2: float = Field(default=..., ge=0, le=1)
    y2: float = Field(default=..., ge=0, le=1)
|
|
class OCRRequest(BaseModel):
    """Request body for /ocr: one image plus the regions to read from it."""

    # Base64 payload; a "...base64," prefix (e.g. a data URL header) is accepted.
    image: str = Field(default=..., description="Base64 encoded image")
    boxes: List[BoxRegion] = Field(default=..., description="List of bounding boxes to process")
    # When true, the whole image is OCR'd in addition to the individual boxes.
    include_full_image: bool = Field(default=False, description="Whether to process the full image as well")
|
|
class BoxResult(BaseModel):
    """OCR output for one requested box."""

    # Same id as the BoxRegion this result belongs to.
    id: int
    # Recognised text; an empty string when nothing was read or OCR failed.
    text: str
    # The box's normalised coordinates, copied unchanged from the request.
    x1: float
    y1: float
    x2: float
    y2: float
|
|
class OCRResponse(BaseModel):
    """Response body for /ocr."""

    # One entry per requested box, in request order.
    results: List[BoxResult]
    # Text for the whole image; left as None unless include_full_image was set.
    full_image_text: Optional[str] = Field(default=None)
    # Wall-clock handling time in seconds.
    processing_time: float
|
|
@app.on_event("startup")
async def load_model():
    """Load the DeepSeek OCR model, and its processor if one is available.

    Populates the module-level ``model`` and ``processor`` globals. The model is
    loaded on CPU in float32 and switched to eval mode. If the processor cannot
    be loaded, ``processor`` stays ``None`` and the endpoints call the model
    directly instead.

    Raises:
        Exception: any error while loading the model is logged with its
            traceback and re-raised to the framework's startup handling.
    """
    # NOTE(review): on_event is deprecated in recent FastAPI releases in favour
    # of lifespan handlers; consider migrating.
    global model, processor
    try:
        logger.info("Loading DeepSeek OCR model...")

        # trust_remote_code executes the modelling code shipped in the Hub repo.
        model = AutoModel.from_pretrained(
            "deepseek-ai/DeepSeek-OCR-2",
            trust_remote_code=True,
            torch_dtype=torch.float32,
            device_map="cpu",
            low_cpu_mem_usage=True,
        )
        model.eval()

        # The processor is optional. Catch Exception (not a bare except) so
        # that KeyboardInterrupt/SystemExit still propagate, and log the
        # reason so a misconfiguration is not silently hidden.
        try:
            processor = AutoProcessor.from_pretrained(
                "deepseek-ai/DeepSeek-OCR-2",
                trust_remote_code=True,
            )
        except Exception as exc:
            processor = None
            logger.warning("Processor not available, using model directly (%s)", exc)

        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.exception("Error loading model: %s", e)
        raise
|
|
def decode_base64_image(base64_string: str) -> Image.Image:
    """Turn a base64 string into an RGB PIL image.

    If the string contains ``"base64,"`` (e.g. a ``data:image/png;base64,``
    prefix), only the part after that marker is decoded.

    Raises:
        HTTPException: 400 for any failure while decoding or opening the image.
    """
    marker = "base64,"
    try:
        payload = base64_string.split(marker)[1] if marker in base64_string else base64_string
        raw_bytes = base64.b64decode(payload)
        decoded = Image.open(io.BytesIO(raw_bytes))
        return decoded.convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}")
|
|
def crop_and_ocr(image: Image.Image, box: BoxRegion) -> str:
    """Crop the region described by *box* from *image* and OCR it.

    Box coordinates are fractions of the image size. Corners may be given in
    either order. The pixel crop is clamped so that it always starts inside
    the image and is at least 1x1 pixel.

    Args:
        image: source image (the endpoints pass the RGB image produced by
            decode_base64_image).
        box: region to read, in normalised coordinates.

    Returns:
        The recognised text. Returns ``""`` when nothing was recognised or
        when OCR raised. Failures are logged rather than raised, so one bad
        box does not abort a multi-box request.
    """
    try:
        img_width, img_height = image.size

        # Accept corners in either order instead of degrading an inverted box
        # to a 1-pixel sliver.
        x_lo, x_hi = sorted((box.x1, box.x2))
        y_lo, y_hi = sorted((box.y1, box.y2))

        # Convert to pixel coordinates. The top-left corner is kept strictly
        # inside the image; previously a box starting at 1.0 produced a crop
        # lying entirely past the edge. Right/bottom always exceed left/top by
        # at least one pixel.
        left = max(0, min(int(x_lo * img_width), img_width - 1))
        top = max(0, min(int(y_lo * img_height), img_height - 1))
        right = max(left + 1, min(int(x_hi * img_width), img_width))
        bottom = max(top + 1, min(int(y_hi * img_height), img_height))

        cropped = image.crop((left, top, right, bottom))

        with torch.no_grad():
            if processor is not None:
                inputs = processor(images=cropped, return_tensors="pt")
                result = model.generate(**inputs)
                text = processor.decode(result[0], skip_special_tokens=True)
            else:
                # NOTE(review): assumes the remote model code accepts a PIL
                # image directly and returns a string; confirm against the
                # model's published usage.
                result = model(cropped)
                text = result.strip() if result else ""

        return text or ""
    except Exception as e:
        logger.exception("Error processing box %s: %s", box.id, e)
        return ""
|
|
def cleanup_memory():
    """Run the garbage collector and, when CUDA is present, empty its cache."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
|
|
@app.get("/")
async def root():
    """Return service metadata and whether the OCR model has been loaded."""
    loaded = model is not None
    info = {
        "message": "DeepSeek OCR API",
        "status": "active",
        "model": "deepseek-ai/DeepSeek-OCR-2",
    }
    info["model_loaded"] = loaded
    return info
|
|
@app.get("/health")
async def health_check():
    """Report status "healthy" unconditionally, plus the model-loaded flag."""
    return dict(status="healthy", model_loaded=model is not None)
|
|
@app.post("/ocr", response_model=OCRResponse)
async def process_ocr(request: OCRRequest):
    """OCR every requested box and, optionally, the whole image.

    Boxes are processed in request order. A box whose OCR fails yields an
    empty ``text`` rather than failing the whole request.

    Returns:
        OCRResponse containing:
        - one BoxResult per input box;
        - ``full_image_text``, which is ``None`` unless ``include_full_image``
          is set;
        - the elapsed time in seconds, rounded to two decimals.

    Raises:
        HTTPException: 503 if the model has not been loaded, 400 if the image
            payload cannot be decoded, 500 for any other processing error.
    """
    # NOTE(review): the inference below is synchronous inside an async
    # handler, so it blocks the event loop (and every other request,
    # including /health) until it finishes. A plain `def` endpoint or
    # run_in_threadpool would avoid that but would allow concurrent inference.
    start_time = time.perf_counter()  # monotonic, unlike time.time()

    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    try:
        image = decode_base64_image(request.image)

        results = []
        for box in request.boxes:
            text = crop_and_ocr(image, box)
            results.append(
                BoxResult(id=box.id, text=text, x1=box.x1, y1=box.y1, x2=box.x2, y2=box.y2)
            )

        full_image_text = None
        if request.include_full_image:
            with torch.no_grad():
                if processor is not None:
                    inputs = processor(images=image, return_tensors="pt")
                    result = model.generate(**inputs)
                    full_image_text = processor.decode(result[0], skip_special_tokens=True)
                else:
                    # Same direct-call fallback as crop_and_ocr: a falsy
                    # result maps to "".
                    result = model(image)
                    full_image_text = result.strip() if result else ""

        # Measured before memory cleanup runs (in `finally`).
        processing_time = time.perf_counter() - start_time

        return OCRResponse(
            results=results,
            full_image_text=full_image_text,
            processing_time=round(processing_time, 2),
        )
    except HTTPException:
        # Keep deliberate status codes (e.g. 400 for an undecodable image)
        # instead of masking them as a 500 below.
        raise
    except Exception as e:
        logger.exception("Processing error: %s", e)
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}") from e
    finally:
        cleanup_memory()
|
|
@app.post("/ocr/single")
async def process_single_box(request: dict):
    """OCR a single box.

    The body must contain an ``image`` key (base64 string, as accepted by
    ``decode_base64_image``) and a ``box`` key (the fields of ``BoxRegion``).

    Returns:
        dict with the box ``id``, the recognised ``text`` and the box's
        ``x1``/``y1``/``x2``/``y2``.

    Raises:
        HTTPException:
            - 503 if the model has not been loaded;
            - 400 if a key is missing, the box is invalid, or the image
              cannot be decoded;
            - 500 for any other processing error.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # Check the request shape up front so malformed input is reported as a
    # client error, not a 500. pydantic's ValidationError subclasses
    # ValueError, so an invalid box is caught here as well.
    try:
        image_b64 = request["image"]
        box = BoxRegion(**request["box"])
    except (KeyError, TypeError, ValueError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid request: {e}") from e

    try:
        image = decode_base64_image(image_b64)
        text = crop_and_ocr(image, box)
    except HTTPException:
        # e.g. the 400 from decode_base64_image; do not turn it into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e
    finally:
        cleanup_memory()

    return {
        "id": box.id,
        "text": text,
        "x1": box.x1,
        "y1": box.y1,
        "x2": box.x2,
        "y2": box.y2,
    }
|
|
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file runs as a script.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")