RapidOCR / app.py
KroZenDev's picture
Update app.py
7eee37a verified
Raw
History Blame Contribute Delete
9.84 kB
import os
import io
import time
import asyncio
import numpy as np
from PIL import Image
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import HTMLResponse
from starlette.concurrency import run_in_threadpool
from huggingface_hub import hf_hub_download
# Cap native thread pools at 2 (the comment in the original notes that
# HF Spaces "CPU Basic" provides only 2 vCPUs and that ONNX Runtime is the target).
# NOTE(review): numpy is already imported above this point; libraries that
# read these variables at import time may not pick them up — confirm.
os.environ.update(OMP_NUM_THREADS="2", MKL_NUM_THREADS="2")

app = FastAPI()

# Downloaded model files live in a "models" folder next to this file.
MODEL_DIR = os.path.join(os.path.dirname(__file__), "models")

# Limit the real concurrency of OCR calls to match the vCPU budget.
# With 2 vCPUs, running 2+ OCR jobs at once splits the cores between them
# and can be SLOWER than a strict queue — hence 1 permit, not 2.
ocr_semaphore = asyncio.Semaphore(1)
def download_file(repo_id, filename, local_dir):
    """Fetch ``filename`` from the HF Hub repo ``repo_id`` unless it is cached.

    The expected location is ``local_dir/filename``. If something already
    exists at that path it is returned without touching the network.

    Returns:
        The local path on success (or cache hit), ``None`` if the download
        raised any exception (the error is printed, not re-raised).
    """
    target = os.path.join(local_dir, filename)
    if not os.path.exists(target):
        try:
            print(f"Downloading {repo_id}/{filename}...")
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
            print(f" Saved to {target}")
        except Exception as err:
            # Best-effort: callers treat None as "not available".
            print(f" Failed: {err}")
            target = None
    return target
def ensure_models():
    """Make sure the detection, recognition and dictionary files are present.

    Downloads (in this order) into ``MODEL_DIR``:
      * det  — the mobile PP-OCRv5 detector (per the original note: light,
        ~4.8 MB instead of the ~88 MB server variant). Detection is
        language-agnostic — it only locates text regions — so the mobile
        model is used regardless of language.
      * rec  — the "eslav" recognizer, kept because it is trained for
        Cyrillic + Latin.
      * dict — the character dictionary matching that recognizer.

    Returns:
        ``(det_path, rec_path, dict_path)``; any element may be ``None`` if
        its download failed.
    """
    os.makedirs(MODEL_DIR, exist_ok=True)
    wanted = (
        ("ilaylow/PP_OCRv5_mobile_onnx", "ppocrv5_det.onnx"),
        ("monkt/paddleocr-onnx", "languages/eslav/rec.onnx"),
        ("monkt/paddleocr-onnx", "languages/eslav/dict.txt"),
    )
    detector, recognizer, charset = (
        download_file(repo, name, MODEL_DIR) for repo, name in wanted
    )
    return detector, recognizer, charset
det_path, rec_path, dict_path = ensure_models()

from rapidocr_onnxruntime import RapidOCR


def _find_cls_model(search_root):
    """Return the first ``*.onnx`` file with "cls" in its name under *search_root*.

    Returns ``None`` when nothing matches (``os.walk`` yields nothing for a
    missing directory, so a non-existent root is also ``None``).
    """
    for folder, _subdirs, names in os.walk(search_root):
        for name in names:
            if "cls" in name and name.endswith(".onnx"):
                return os.path.join(folder, name)
    return None


def _patch_for_speed(engine):
    """Monkey-patch detection/recognition parameters of *engine* for speed.

    Every step is guarded so that an engine with a different internal layout
    is simply left unpatched instead of raising.
    """
    detector = getattr(engine, "text_detector", None)

    # NOTE(review): some rapidocr releases may expose ``preprocess_op`` as a
    # single object rather than a list of operators — confirm. A
    # non-iterable value is skipped here rather than raising TypeError.
    try:
        pre_ops = list(getattr(detector, "preprocess_op", None))
    except TypeError:
        pre_ops = []
    for op in pre_ops:
        if op.__class__.__name__ == "DetResizeForTest":
            op.limit_side_len = 480
            op.limit_type = "max"
            print(" Patched: limit_side_len=480, limit_type=max")

    post_op = getattr(detector, "postprocess_op", None)
    if post_op is not None:
        post_op.thresh = 0.5
        post_op.use_dilation = False
        post_op.score_mode = "fast"
        # Fewer candidates: a screenshot rarely has more than 10-15 text
        # blocks (per the original author's note).
        if hasattr(post_op, "max_candidates"):
            post_op.max_candidates = 100
        print(" Patched: thresh=0.5, use_dilation=False, score_mode=fast, max_candidates=100")

    if hasattr(engine, "text_recognizer"):
        engine.text_recognizer.rec_batch_num = 1
        print(" Patched: rec_batch_num=1")


def _write_legacy_config(config_file, det_model, cls_model, rec_model, keys_file):
    """Write the full config.yaml used by the config-file based constructor.

    The sections are nested (Global / Det / Cls / Rec) so that each section's
    ``model_path`` and other keys stay scoped to that section.
    """
    with open(config_file, "w", encoding="utf-8") as fh:
        fh.write(f"""
Global:
    text_score: 0.5
    use_angle_cls: false
    print_verbose: false
    min_height: 30
    width_height_ratio: 8
Det:
    module_name: ch_ppocr_v3_det
    class_name: TextDetector
    model_path: {det_model}
    use_cuda: false
    pre_process:
        DetResizeForTest:
            limit_side_len: 480
            limit_type: max
        NormalizeImage:
            std: [0.229, 0.224, 0.225]
            mean: [0.485, 0.456, 0.406]
            scale: 1./255.
            order: hwc
        ToCHWImage:
        KeepKeys:
            keep_keys: ['image', 'shape']
    post_process:
        thresh: 0.5
        box_thresh: 0.5
        max_candidates: 100
        unclip_ratio: 1.6
        use_dilation: false
        score_mode: fast
Cls:
    module_name: ch_ppocr_v2_cls
    class_name: TextClassifier
    model_path: {cls_model}
    cls_img_shape: [3, 48, 192]
    cls_batch_num: 1
    cls_thresh: 0.9
    label_list: [0, 180]
Rec:
    module_name: ch_ppocr_v2_rec
    class_name: TextRecognizer
    model_path: {rec_model}
    rec_img_shape: [3, 48, 320]
    rec_batch_num: 1
    keys_path: {keys_file}
""")


# === Attempt 1: keyword arguments (newer version, 1.3.x per the original note) ===
# Only the constructor call is inside the ``try``: a TypeError here means the
# installed version does not accept these kwargs. Patching happens in the
# ``else`` branch so an unrelated TypeError cannot trigger the fallback.
try:
    ocr = RapidOCR(
        det_model_path=det_path,
        rec_model_path=rec_path,
        rec_keys_path=dict_path,
        cls=False,
    )
except TypeError:
    # === Attempt 2: older version, needs a full config.yaml ===
    print("kwargs not supported, using full config.yaml")

    # The config has a Cls section, so try to obtain a classifier model:
    # first from the Hub (several known locations), ...
    cls_path = None
    for repo, path in [
        ("RapidAI/RapidOCR", "onnx/PP-OCRv4/cls/ch_ppocr_mobile_v2.0_cls_infer.onnx"),
        ("RapidAI/RapidOCR", "onnx/PP-OCRv3/cls/ch_ppocr_mobile_v2.0_cls_infer.onnx"),
        ("RapidAI/RapidOCR", "resources/models/ch_ppocr_mobile_v2.0_cls_infer.onnx"),
    ]:
        cls_path = download_file(repo, path, MODEL_DIR)
        if cls_path:
            break

    # ... then from the installed package after a default construction, ...
    if not cls_path:
        print("Downloading standard models via RapidOCR()...")
        RapidOCR()  # constructed only for its side effects; not kept alive
        import rapidocr_onnxruntime
        cls_path = _find_cls_model(os.path.dirname(rapidocr_onnxruntime.__file__))

    # ... and finally from the user cache directory.
    if not cls_path:
        cls_path = _find_cls_model(os.path.expanduser("~/.cache/rapidocr_onnxruntime"))

    config_path = os.path.join(os.path.dirname(__file__), "config.yaml")
    # ``use_angle_cls`` is false in the config, so when no classifier was
    # found the detector path is written as a placeholder for Cls.model_path.
    _write_legacy_config(config_path, det_path, cls_path or det_path, rec_path, dict_path)
    ocr = RapidOCR(config_path=config_path)
    print("✓ RapidOCR initialized with config.yaml")
else:
    print("✓ RapidOCR initialized with kwargs")
    _patch_for_speed(ocr)
# Warm-up: run one inference on a blank image so ONNX Runtime is warmed up
# before the first real request arrives.
print("Warming up OCR...")
warmup_img = np.zeros((480, 480, 3), dtype=np.uint8)
try:
    ocr(warmup_img)
except Exception as exc:
    # Warm-up is best-effort: startup must not fail because of it, but the
    # reason is reported instead of being silently discarded.
    print(f"Warm-up failed (continuing): {exc}")
print("Ready!")
# Static single-page UI returned by the root route: a file input plus a
# script that POSTs the chosen file as multipart field "file" to /predict
# and prints the JSON response (pretty-printed) into the #result element.
HTML_FORM = """<!DOCTYPE html>
<html>
<head>
<title>RapidOCR — Russian + English</title>
<meta charset="utf-8">
<style>
body { font-family: sans-serif; max-width: 600px; margin: 50px auto; padding: 20px; background: #1a1a2e; color: #eee; }
h1 { color: #e94560; }
input[type=file] { margin: 20px 0; display: block; }
button { padding: 12px 24px; cursor: pointer; background: #e94560; color: white; border: none; border-radius: 6px; font-size: 16px; }
button:hover { background: #c73e54; }
#result { margin-top: 20px; padding: 15px; background: #16213e; border-radius: 8px; white-space: pre-wrap; font-family: monospace; min-height: 60px; }
</style>
</head>
<body>
<h1>📝 RapidOCR — Russian + English</h1>
<p>Upload an image with text (screenshots work great)</p>
<form id="uploadForm">
<input type="file" id="file" name="file" accept="image/*" required>
<button type="submit">Recognize</button>
</form>
<div id="result">Result will appear here...</div>
<script>
document.getElementById('uploadForm').onsubmit = async (e) => {
e.preventDefault();
const resultDiv = document.getElementById('result');
resultDiv.textContent = 'Processing...';
const formData = new FormData();
formData.append('file', document.getElementById('file').files[0]);
try {
const res = await fetch('/predict', {method: 'POST', body: formData});
const data = await res.json();
resultDiv.textContent = JSON.stringify(data, null, 2);
} catch (err) {
resultDiv.textContent = 'Error: ' + err.message;
}
};
</script>
</body>
</html>"""
# GET / — serves the static upload page as an HTML response.
# (Documented with comments rather than a docstring so the generated
# OpenAPI description of the route stays exactly as it was.)
@app.get("/", response_class=HTMLResponse)
def read_root():
    return HTML_FORM
def _run_ocr_sync(img_array):
    """Run OCR synchronously on *img_array* and log how long it took.

    Intended to be executed in a worker thread (the caller dispatches it via
    ``run_in_threadpool``) so the event loop is not blocked.

    Args:
        img_array: image passed straight to the module-level ``ocr`` engine.

    Returns:
        The first element of the ``(results, elapse_list)`` pair returned by
        the engine; may be falsy when nothing was recognized.
    """
    # perf_counter is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments the way time.time() differences can.
    started = time.perf_counter()
    results, elapse_list = ocr(img_array)
    total = time.perf_counter() - started
    print(f"[TIMING] total={total:.2f}s breakdown(det,cls,rec)={elapse_list}")
    return results
def _decode_upload(contents):
    """Decode raw upload bytes into an RGB array, at most 1600 px per side."""
    image = Image.open(io.BytesIO(contents))
    if image.mode != "RGB":
        image = image.convert("RGB")
    if image.width > 1600 or image.height > 1600:
        # thumbnail() shrinks in place and keeps the aspect ratio.
        image.thumbnail((1600, 1600))
    return np.array(image)


# POST /predict — accepts one uploaded image (multipart field "file") and
# responds with {"text": "<recognized text joined by spaces>"}; the text is
# "" when nothing was recognized. Any failure is reported as
# {"error": "<message>"} in a normal response rather than raised.
# (Documented with comments rather than a docstring so the generated
# OpenAPI description of the route stays exactly as it was.)
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    try:
        contents = await file.read()
        # Decoding/resizing is CPU-bound; run it in the thread pool so the
        # event loop keeps serving other requests meanwhile.
        img_array = await run_in_threadpool(_decode_upload, contents)
        # Only one OCR inference at a time (see ocr_semaphore).
        async with ocr_semaphore:
            results = await run_in_threadpool(_run_ocr_sync, img_array)
        if not results:
            return {"text": ""}
        # Each result item carries its recognized text at index 1.
        text = " ".join(item[1] for item in results).strip()
        return {"text": text}
    except Exception as e:
        return {"error": str(e)}