# Hugging Face web-page residue (repo header): shubhamingale — "update", commit af96db0 (verified)
#!/usr/bin/env python3
"""
One-script setup to serve fine-tuned DeepSeek-OCR with LoRA adapters.
Downloads base model + adapters, loads them, and starts a FastAPI server on port 8000.
Usage:
pip install unsloth fastapi uvicorn python-multipart huggingface_hub peft pillow
python serve.py
Environment variables:
HF_TOKEN - Hugging Face token (if adapter repo is private)
LORA_REPO - HF repo with LoRA adapters (default: shubhamingale/deepseek-ocr2-3b-lora)
PORT - Server port (default: 8000)
IMAGE_SIZE - Inference image size (default: 640)
BASE_SIZE - Inference base size (default: 1024)
"""
import os
import sys
import io
import gc
import time
import base64
import tempfile
import subprocess
# ── Configuration ──────────────────────────────────────────────────
# All values below are overridable via environment variables (see module docstring).
LORA_REPO = os.environ.get("LORA_REPO", "shubhamingale/deepseek-ocr2-3b-lora")  # HF repo holding the LoRA adapters
BASE_MODEL_REPO = "unsloth/DeepSeek-OCR"  # base model repo on the Hugging Face Hub
BASE_MODEL_DIR = "./deepseek_ocr"  # local directory the base model snapshot is cached in
PORT = int(os.environ.get("PORT", "8000"))  # uvicorn listen port
IMAGE_SIZE = int(os.environ.get("IMAGE_SIZE", "640"))  # default inference image size
BASE_SIZE = int(os.environ.get("BASE_SIZE", "1024"))  # default inference base size
HF_TOKEN = os.environ.get("HF_TOKEN", None)  # optional token, needed only for private repos
# ──────────────────────────────────────────────────────────────────
def log(msg):
    """Emit *msg* prefixed with [MM:SS] elapsed since process startup."""
    total_seconds = int(time.time() - START_TIME)
    minutes = total_seconds // 60
    seconds = total_seconds % 60
    print(f" [{minutes:02d}:{seconds:02d}] {msg}", flush=True)


# Reference point for the elapsed-time prefix printed by log().
START_TIME = time.time()
def download_models():
    """Fetch the base model and the LoRA adapters from the Hugging Face Hub.

    The base model is cached in BASE_MODEL_DIR, so repeat runs skip that
    download entirely. Returns the local path of the LoRA adapter snapshot.
    """
    from huggingface_hub import snapshot_download

    if os.path.exists(BASE_MODEL_DIR):
        log(f"✅ Base model already exists at {BASE_MODEL_DIR}")
    else:
        log(f"⬇️ Downloading base model: {BASE_MODEL_REPO} (~6.7GB)")
        log(" This may take 3-10 minutes depending on connection...")
        snapshot_download(BASE_MODEL_REPO, local_dir=BASE_MODEL_DIR, token=HF_TOKEN)
        log("✅ Base model downloaded.")

    # Adapters are small; always resolve them (snapshot_download caches internally).
    log(f"⬇️ Downloading LoRA adapters: {LORA_REPO} (~296MB)")
    adapter_path = snapshot_download(LORA_REPO, token=HF_TOKEN)
    log(f"✅ LoRA adapters downloaded to: {adapter_path}")
    return adapter_path
def load_model(lora_path: str):
    """Load base model + LoRA adapters and prepare for inference.

    Args:
        lora_path: Local directory holding the downloaded LoRA adapter
            weights (the value returned by download_models()).

    Returns:
        (model, tokenizer) tuple, with the model switched to inference mode.
    """
    # Quiet Unsloth's uninitialized-weights warning.
    os.environ["UNSLOTH_WARN_UNINITIALIZED"] = "0"
    # Everything needed is already on disk; block further Hub network access.
    # NOTE: must be set before the imports below so the libraries pick it up.
    os.environ["HF_HUB_OFFLINE"] = "1"
    log("📦 Importing Unsloth and transformers...")
    from unsloth import FastVisionModel
    from transformers import AutoModel
    from peft import PeftModel
    log("✅ Imports done.")
    log("🔄 Loading base model into GPU... (this takes 1-2 minutes)")
    model, tokenizer = FastVisionModel.from_pretrained(
        BASE_MODEL_DIR,
        load_in_4bit=False,  # full precision — NOTE(review): assumes sufficient VRAM
        auto_model=AutoModel,  # DeepSeek-OCR loads via AutoModel, not AutoModelForCausalLM
        trust_remote_code=True,  # the repo ships custom modeling code
        unsloth_force_compile=True,
        use_gradient_checkpointing="unsloth",
    )
    log("✅ Base model loaded.")
    log("🔗 Applying LoRA adapters...")
    # Wrap the base model with the fine-tuned adapter weights.
    model = PeftModel.from_pretrained(model, lora_path)
    log("✅ LoRA adapters applied.")
    log("⚡ Switching to inference mode...")
    FastVisionModel.for_inference(model)
    log("✅ Model ready for inference!")
    return model, tokenizer
def create_app(model, tokenizer):
    """Create the FastAPI application exposing OCR endpoints.

    Args:
        model: The LoRA-wrapped DeepSeek-OCR model (inference mode).
        tokenizer: The matching tokenizer.

    Returns:
        A FastAPI app with /health, /ocr (multipart upload), and
        /ocr/base64 (form-encoded base64) endpoints.
    """
    import shutil
    from contextlib import redirect_stdout

    from fastapi import FastAPI, UploadFile, File, Form, HTTPException
    from fastapi.responses import JSONResponse

    def _strip_banner_lines(printed_output: str) -> str:
        """Drop model.infer()'s progress/banner lines, keeping the prediction text."""
        lines = printed_output.strip().split("\n")
        pred_lines = [
            l for l in lines
            if not l.startswith("=")
            and not l.startswith("BASE:")
            and not l.startswith("PATCHES:")
            and not l.startswith("The attention mask")
        ]
        return "\n".join(pred_lines).strip()

    def _run_ocr(image_path: str, prompt: str, image_size: int, base_size: int) -> str:
        """Run inference on the image at *image_path* and return the extracted text.

        model.infer() prints its prediction to stdout, so stdout is captured;
        the return value is preferred when it is a non-empty string.
        """
        output_dir = tempfile.mkdtemp()
        captured = io.StringIO()
        try:
            # redirect_stdout restores sys.stdout even if infer() raises
            # (the old manual swap leaked the redirect on error paths).
            with redirect_stdout(captured):
                res = model.infer(
                    tokenizer,
                    prompt=prompt,
                    image_file=image_path,
                    output_path=output_dir,
                    image_size=image_size,
                    base_size=base_size,
                    crop_mode=True,
                    save_results=False,
                    test_compress=False,
                )
        finally:
            # Fix: the per-request scratch directory was previously never removed.
            shutil.rmtree(output_dir, ignore_errors=True)
        if res and isinstance(res, str) and len(res.strip()) > 0:
            return res.strip()
        return _strip_banner_lines(captured.getvalue())

    def _ocr_bytes(image_bytes: bytes, prompt: str, image_size: int, base_size: int):
        """Persist *image_bytes* to a temp file, run OCR, and clean up."""
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            tmp.write(image_bytes)
            tmp_path = tmp.name
        try:
            result_text = _run_ocr(tmp_path, prompt, image_size, base_size)
            return JSONResponse(content={"text": result_text, "status": "success"})
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
        finally:
            os.unlink(tmp_path)

    app = FastAPI(
        title="DeepSeek-OCR Fine-tuned API",
        description="OCR API powered by fine-tuned DeepSeek-OCR with LoRA adapters",
        version="1.0.0",
    )

    @app.get("/health")
    async def health():
        """Liveness probe; reports which model/adapters are loaded."""
        return {"status": "ok", "model": BASE_MODEL_REPO, "lora": LORA_REPO}

    @app.post("/ocr")
    async def ocr(
        file: UploadFile = File(...),
        prompt: str = Form(default="<image>\nExtract all the information from given image"),
        image_size: int = Form(default=IMAGE_SIZE),
        base_size: int = Form(default=BASE_SIZE),
    ):
        """Run OCR on an uploaded image."""
        if not file.content_type or not file.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="File must be an image")
        content = await file.read()
        return _ocr_bytes(content, prompt, image_size, base_size)

    @app.post("/ocr/base64")
    async def ocr_base64(
        image_base64: str = Form(...),
        prompt: str = Form(default="<image>\nExtract all the information from given image"),
        image_size: int = Form(default=IMAGE_SIZE),
        base_size: int = Form(default=BASE_SIZE),
    ):
        """Run OCR on a base64-encoded image."""
        try:
            image_data = base64.b64decode(image_base64)
        except Exception:
            raise HTTPException(status_code=400, detail="Invalid base64 image data")
        return _ocr_bytes(image_data, prompt, image_size, base_size)

    return app
def main():
    """Entry point: download weights, load the model, and serve the API."""
    import uvicorn

    banner = "=" * 60
    print(banner)
    print(" DeepSeek-OCR Fine-tuned Server")
    print(banner)
    print(f" Base model : {BASE_MODEL_REPO}")
    print(f" LoRA repo : {LORA_REPO}")
    print(f" Port : {PORT}")
    print(banner)
    print()

    # Step 1: fetch weights from the Hub (cached after the first run).
    log("STEP 1/3: Downloading models...")
    lora_path = download_models()
    print()

    # Step 2: load the base model and apply the LoRA adapters.
    log("STEP 2/3: Loading model...")
    model, tokenizer = load_model(lora_path)
    print()

    # Step 3: assemble the FastAPI app, then hand off to uvicorn.
    log("STEP 3/3: Starting API server...")
    app = create_app(model, tokenizer)

    mins, secs = divmod(int(time.time() - START_TIME), 60)
    print()
    print(banner)
    print(f" 🚀 Server ready in {mins}m {secs}s")
    print(f" 🌐 http://0.0.0.0:{PORT}")
    print(f" 📖 http://0.0.0.0:{PORT}/docs (Swagger UI)")
    print()
    print(f" Test it:")
    print(f' curl -X POST http://localhost:{PORT}/ocr -F "file=@image.jpg"')
    print(banner)
    print()
    uvicorn.run(app, host="0.0.0.0", port=PORT)


if __name__ == "__main__":
    main()