"""
One-script setup to serve fine-tuned DeepSeek-OCR with LoRA adapters.
Downloads the base model + adapters, loads them, and starts a FastAPI server on port 8000.

Usage:
    pip install unsloth fastapi uvicorn python-multipart huggingface_hub peft pillow
    python serve.py

Environment variables:
    HF_TOKEN   - Hugging Face token (needed if the adapter repo is private)
    LORA_REPO  - HF repo with LoRA adapters (default: shubhamingale/deepseek-ocr2-3b-lora)
    PORT       - Server port (default: 8000)
    IMAGE_SIZE - Inference image size (default: 640)
    BASE_SIZE  - Inference base size (default: 1024)
"""

import os
import sys
import io
import time
import base64
import tempfile

# Configuration (see the module docstring for the environment-variable overrides).
LORA_REPO = os.environ.get("LORA_REPO", "shubhamingale/deepseek-ocr2-3b-lora")
BASE_MODEL_REPO = "unsloth/DeepSeek-OCR"
BASE_MODEL_DIR = "./deepseek_ocr"
PORT = int(os.environ.get("PORT", "8000"))
IMAGE_SIZE = int(os.environ.get("IMAGE_SIZE", "640"))
BASE_SIZE = int(os.environ.get("BASE_SIZE", "1024"))
HF_TOKEN = os.environ.get("HF_TOKEN", None)
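
# Example launch with overrides (hypothetical values):
#   LORA_REPO=my-org/my-lora PORT=8080 python serve.py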

START_TIME = time.time()


def log(msg):
    """Print a timestamped log message (elapsed time since startup)."""
    elapsed = time.time() - START_TIME
    mins, secs = divmod(int(elapsed), 60)
    print(f" [{mins:02d}:{secs:02d}] {msg}", flush=True)


def download_models():
    """Download base model and LoRA adapters from Hugging Face."""
    from huggingface_hub import snapshot_download

    if not os.path.exists(BASE_MODEL_DIR):
        log(f"⬇️ Downloading base model: {BASE_MODEL_REPO} (~6.7GB)")
        log("   This may take 3-10 minutes depending on connection...")
        snapshot_download(BASE_MODEL_REPO, local_dir=BASE_MODEL_DIR, token=HF_TOKEN)
        log("✅ Base model downloaded.")
    else:
        log(f"✅ Base model already exists at {BASE_MODEL_DIR}")

    log(f"⬇️ Downloading LoRA adapters: {LORA_REPO} (~296MB)")
    lora_path = snapshot_download(LORA_REPO, token=HF_TOKEN)
    log(f"✅ LoRA adapters downloaded to: {lora_path}")
    return lora_path
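
# The same artifacts can also be fetched manually with recent versions of
# huggingface_hub (a sketch, not required by this script):
#   huggingface-cli download unsloth/DeepSeek-OCR --local-dir ./deepseek_ocr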


def load_model(lora_path: str):
    """Load base model + LoRA adapters and prepare for inference."""
    # Silence Unsloth's uninitialized-weights warning and force offline mode so the
    # locally downloaded copies are used instead of re-hitting the Hub.
    os.environ["UNSLOTH_WARN_UNINITIALIZED"] = "0"
    os.environ["HF_HUB_OFFLINE"] = "1"

    log("📦 Importing Unsloth and transformers...")
    from unsloth import FastVisionModel
    from transformers import AutoModel
    from peft import PeftModel
    log("✅ Imports done.")

    log("🔄 Loading base model into GPU... (this takes 1-2 minutes)")
    model, tokenizer = FastVisionModel.from_pretrained(
        BASE_MODEL_DIR,
        load_in_4bit=False,
        auto_model=AutoModel,
        trust_remote_code=True,
        unsloth_force_compile=True,
        use_gradient_checkpointing="unsloth",
    )
    log("✅ Base model loaded.")

    log("🔗 Applying LoRA adapters...")
    # The adapters stay as a separate PEFT wrapper; they could instead be merged into
    # the base weights (peft's merge_and_unload()) if a standalone model is preferred.
    model = PeftModel.from_pretrained(model, lora_path)
    log("✅ LoRA adapters applied.")

    log("⚡ Switching to inference mode...")
    FastVisionModel.for_inference(model)
    log("✅ Model ready for inference!")

    return model, tokenizer


def create_app(model, tokenizer):
    """Create FastAPI application with OCR endpoints."""
    from fastapi import FastAPI, UploadFile, File, Form, HTTPException
    from fastapi.responses import JSONResponse
    from PIL import Image

    app = FastAPI(
        title="DeepSeek-OCR Fine-tuned API",
        description="OCR API powered by fine-tuned DeepSeek-OCR with LoRA adapters",
        version="1.0.0",
    )

    @app.get("/health")
    async def health():
        return {"status": "ok", "model": BASE_MODEL_REPO, "lora": LORA_REPO}
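
    # Quick check once the server is up (sketch):
    #   curl http://localhost:8000/health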

    @app.post("/ocr")
    async def ocr(
        file: UploadFile = File(...),
        prompt: str = Form(default="<image>\nExtract all the information from given image"),
        image_size: int = Form(default=IMAGE_SIZE),
        base_size: int = Form(default=BASE_SIZE),
    ):
        """Run OCR on an uploaded image."""
        if not file.content_type or not file.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Persist the upload to a temporary file; model.infer expects a path on disk.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            content = await file.read()
            tmp.write(content)
            tmp_path = tmp.name

        try:
            output_dir = tempfile.mkdtemp()

            # model.infer prints its prediction to stdout, so capture stdout as a
            # fallback in case the call does not return the text directly.
            old_stdout = sys.stdout
            sys.stdout = captured = io.StringIO()

            res = model.infer(
                tokenizer,
                prompt=prompt,
                image_file=tmp_path,
                output_path=output_dir,
                image_size=image_size,
                base_size=base_size,
                crop_mode=True,
                save_results=False,
                test_compress=False,
            )

            sys.stdout = old_stdout
            printed_output = captured.getvalue()

            # Prefer the returned string; otherwise fall back to the captured stdout,
            # dropping known banner/log lines.
            if res and isinstance(res, str) and len(res.strip()) > 0:
                result_text = res.strip()
            else:
                lines = printed_output.strip().split("\n")
                pred_lines = [
                    line for line in lines
                    if not line.startswith("=")
                    and not line.startswith("BASE:")
                    and not line.startswith("PATCHES:")
                    and not line.startswith("The attention mask")
                ]
                result_text = "\n".join(pred_lines).strip()

            return JSONResponse(content={"text": result_text, "status": "success"})

        except Exception as e:
            sys.stdout = sys.__stdout__
            raise HTTPException(status_code=500, detail=str(e))

        finally:
            os.unlink(tmp_path)
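
    # Example multipart request with explicit parameters (sketch; values are illustrative):
    #   curl -X POST http://localhost:8000/ocr \
    #     -F "file=@page.jpg" \
    #     -F "image_size=640" -F "base_size=1024"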

    @app.post("/ocr/base64")
    async def ocr_base64(
        image_base64: str = Form(...),
        prompt: str = Form(default="<image>\nExtract all the information from given image"),
        image_size: int = Form(default=IMAGE_SIZE),
        base_size: int = Form(default=BASE_SIZE),
    ):
        """Run OCR on a base64-encoded image."""
        try:
            image_data = base64.b64decode(image_base64)
        except Exception:
            raise HTTPException(status_code=400, detail="Invalid base64 image data")

        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            tmp.write(image_data)
            tmp_path = tmp.name

        try:
            output_dir = tempfile.mkdtemp()

            # Same stdout-capture fallback as the /ocr endpoint above.
            old_stdout = sys.stdout
            sys.stdout = captured = io.StringIO()

            res = model.infer(
                tokenizer,
                prompt=prompt,
                image_file=tmp_path,
                output_path=output_dir,
                image_size=image_size,
                base_size=base_size,
                crop_mode=True,
                save_results=False,
                test_compress=False,
            )

            sys.stdout = old_stdout
            printed_output = captured.getvalue()

            if res and isinstance(res, str) and len(res.strip()) > 0:
                result_text = res.strip()
            else:
                lines = printed_output.strip().split("\n")
                pred_lines = [
                    line for line in lines
                    if not line.startswith("=")
                    and not line.startswith("BASE:")
                    and not line.startswith("PATCHES:")
                    and not line.startswith("The attention mask")
                ]
                result_text = "\n".join(pred_lines).strip()

            return JSONResponse(content={"text": result_text, "status": "success"})

        except Exception as e:
            sys.stdout = sys.__stdout__
            raise HTTPException(status_code=500, detail=str(e))

        finally:
            os.unlink(tmp_path)
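
    # Example base64 request (sketch; assumes the `requests` package):
    #   import base64, requests
    #   b64 = base64.b64encode(open("page.jpg", "rb").read()).decode()
    #   requests.post("http://localhost:8000/ocr/base64", data={"image_base64": b64}).json()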

    return app


def main():
    import uvicorn

    print("=" * 60)
    print(" DeepSeek-OCR Fine-tuned Server")
    print("=" * 60)
    print(f" Base model : {BASE_MODEL_REPO}")
    print(f" LoRA repo  : {LORA_REPO}")
    print(f" Port       : {PORT}")
    print("=" * 60)
    print()

    log("STEP 1/3: Downloading models...")
    lora_path = download_models()
    print()

    log("STEP 2/3: Loading model...")
    model, tokenizer = load_model(lora_path)
    print()

    log("STEP 3/3: Starting API server...")
    app = create_app(model, tokenizer)

    total = time.time() - START_TIME
    mins, secs = divmod(int(total), 60)

    print()
    print("=" * 60)
    print(f" 🚀 Server ready in {mins}m {secs}s")
    print(f" 🌐 http://0.0.0.0:{PORT}")
    print(f" 📖 http://0.0.0.0:{PORT}/docs (Swagger UI)")
    print()
    print(" Test it:")
    print(f' curl -X POST http://localhost:{PORT}/ocr -F "file=@image.jpg"')
    print("=" * 60)
    print()

    uvicorn.run(app, host="0.0.0.0", port=PORT)


if __name__ == "__main__":
    main()