# mnemo-ocr-core / src/vision_circle_ocr.py
# MABobrov's picture
# Deploy updated core backend pipeline
# 7fb79e4
"""Vision-LLM OCR for reading the UIDs printed inside red circles.
Rather than running EasyOCR on each circle individually, this module
batches many circle crops into a single vision API call for faster and
more accurate UID reading.
"""
from __future__ import annotations
import base64
import json
import logging
import re
import urllib.error
import urllib.request
from typing import Any
import cv2
import numpy as np
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
# How many circle crops are tiled into one grid image, i.e. sent in one vision API request.
_BATCH_SIZE: int = 30
def _crop_circle(image_bgr: np.ndarray, box: tuple[int, int, int, int], pad: int = 5) -> np.ndarray:
    """Return the part of *image_bgr* covered by *box*, grown by *pad* pixels per side.

    Args:
        image_bgr: Image to crop; only its first two axes are indexed.
        box: ``(x, y, w, h)`` of the circle's bounding box, in pixels.
        pad: Margin added on every side before clipping to the image.

    Returns:
        A slice (view) of ``image_bgr`` clipped to the image bounds. It is
        empty when the padded box lies entirely outside the image.
    """
    x, y, w, h = box
    height, width = image_bgr.shape[0], image_bgr.shape[1]

    def _clip(value: int, limit: int) -> int:
        return min(max(0, value), limit)

    # Clip BOTH ends into [0, limit]. Clamping only the start would let a
    # negative end coordinate (box left of / above the frame) wrap around via
    # NumPy negative indexing and return most of the image instead of nothing.
    y0, y1 = _clip(y - pad, height), _clip(y + h + pad, height)
    x0, x1 = _clip(x - pad, width), _clip(x + w + pad, width)
    return image_bgr[y0:y1, x0:x1]
def _build_grid_image(crops: list[np.ndarray], cols: int = 6) -> tuple[np.ndarray, list[tuple[int, int]]]:
    """Tile circle crops onto a white canvas, captioning each with ``#k`` (1-based).

    Every cell is sized to the largest crop plus margins: 2 px on the left and
    a 16 px band above the crop that holds the caption. Cells are filled row
    by row, ``cols`` per row.

    Returns:
        ``(grid, positions)`` where ``positions[i]`` is the ``(x, y)`` of the
        top-left pixel of ``crops[i]`` inside ``grid``. An empty input yields
        a 1x1 black image and no positions.
    """
    if not crops:
        return np.zeros((1, 1, 3), dtype=np.uint8), []

    cell_w = 4 + max(crop.shape[1] for crop in crops)
    cell_h = 20 + max(crop.shape[0] for crop in crops)  # 16 px caption band + 4 px slack
    n_rows = -(-len(crops) // cols)  # ceiling division
    grid = np.full((n_rows * cell_h, cols * cell_w, 3), 255, dtype=np.uint8)

    positions: list[tuple[int, int]] = []
    for index, crop in enumerate(crops):
        row, col = index // cols, index % cols
        left, top = col * cell_w + 2, row * cell_h + 16
        crop_h, crop_w = crop.shape[:2]
        grid[top:top + crop_h, left:left + crop_w] = crop
        # Caption baseline sits 4 px above the crop, inside the caption band.
        cv2.putText(grid, f"#{index + 1}", (left, top - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 0, 0), 1)
        positions.append((left, top))
    return grid, positions
def _encode_image(image: np.ndarray) -> str:
    """PNG-encode *image* with OpenCV and return the bytes as an ASCII base64 string.

    Raises:
        ValueError: If OpenCV reports that PNG encoding failed.
    """
    ok, png = cv2.imencode(".png", image)
    if ok:
        return base64.b64encode(png.tobytes()).decode("ascii")
    raise ValueError("Failed to encode image")
def _build_uid_prompt(n: int) -> str:
    """Return the (Russian) instruction asking the model for exactly *n* UID strings."""
    # English gloss: "The image has n red circles with digits inside, numbered
    # #1..#n. For each circle read the black number inside. Return ONLY a JSON
    # array of n strings, e.g. [...]. If a number is unreadable, use "".
    # Return exactly n elements."
    return (
        f"На изображении {n} красных кружков с цифрами внутри, пронумерованных #1..#{n}.\n"
        "Для каждого кружка прочитай чёрное число внутри.\n"
        f"Верни ТОЛЬКО JSON массив из {n} строк, например: [\"42\", \"7\", \"103\", ...]\n"
        "Если число не читается — пустая строка \"\".\n"
        f"Верни ровно {n} элементов."
    )


def _reply_text(raw: Any) -> str:
    """Return the first choice's message text from a decoded chat-completions body.

    Returns "" when there are no choices, no message, or the content is null.
    """
    choices = raw.get("choices") or [{}]
    message = choices[0].get("message") or {}
    content = message.get("content")
    if content is None:
        return ""
    if isinstance(content, list):
        # Content may come back as a list of typed parts (the same shape the
        # request uses); keep the text of the parts that carry any.
        return "".join(str(part.get("text") or "") for part in content if isinstance(part, dict))
    return str(content)


def _post_chat_completion(
    endpoint: str, payload: dict[str, Any], headers: dict[str, str], timeout: float
) -> str:
    """POST *payload* as JSON to *endpoint* and return the reply's message text."""
    request = urllib.request.Request(
        endpoint, data=json.dumps(payload).encode("utf-8"), headers=headers, method="POST"
    )
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        raw = json.loads(resp.read().decode("utf-8"))
    return _reply_text(raw)


def _extract_uid_array(content: str, expected_len: int) -> list[Any] | None:
    """Return the JSON array of UIDs found in a model reply, or None.

    The outermost ``[...]`` span is tried first; this covers a bare array or
    one wrapped in a Markdown code fence. If that span is not valid JSON (for
    example because the reply also mentions bracketed text such as ``[#1]``),
    each ``[`` is tried as the start of an array. The first list of exactly
    *expected_len* strings is accepted. The fallback is deliberately strict
    so that stray bracketed numbers are not mistaken for the answer.
    """
    span = re.search(r"\[.*\]", content, re.DOTALL)
    if span is None:
        return None
    try:
        value = json.loads(span.group())
    except ValueError:
        value = None
    if isinstance(value, list):
        return value

    decoder = json.JSONDecoder()
    for bracket in re.finditer(r"\[", content):
        try:
            candidate, _ = decoder.raw_decode(content, bracket.start())
        except ValueError:
            continue
        if (
            isinstance(candidate, list)
            and len(candidate) == expected_len
            and all(isinstance(item, str) for item in candidate)
        ):
            return candidate
    return None


def _normalize_uid(value: Any, max_digits: int) -> str:
    """Reduce one model-returned value to a canonical UID string, or "" if unusable.

    Strings and integers keep their decimal digits only. Integral floats
    (``42.0``) are treated as integers, so their ``.0`` is not read as a
    trailing zero. Leading zeros are dropped ("007" -> "7", "000" -> "0").
    Values with more than *max_digits* digits are rejected.
    """
    if isinstance(value, bool) or not isinstance(value, (str, int, float)):
        return ""
    if isinstance(value, float):
        if not value.is_integer():
            return ""
        value = int(value)
    # isdecimal (not isdigit) so that int() below accepts every kept character.
    digits = "".join(ch for ch in str(value) if ch.isdecimal())
    if not digits or len(digits) > max_digits:
        return ""
    return str(int(digits))


def batch_ocr_circles(
    image_bgr: np.ndarray,
    circle_boxes: list[tuple[int, int, int, int]],
    *,
    api_url: str,
    api_key: str = "",
    model: str = "",
    timeout: float = 60.0,
    batch_size: int | None = None,
    max_digits: int = 4,
) -> list[str]:
    """Read UID numbers from red circles using a vision LLM.

    Circles are cropped, tiled into labelled grid images (``batch_size`` per
    image), and each grid is sent in one chat-completions style request to
    ``{api_url}/chat/completions``.

    Args:
        image_bgr: Source image. Crops are pasted into a 3-channel uint8
            canvas, so it is expected to be HxWx3 (BGR, per the name).
        circle_boxes: ``(x, y, w, h)`` boxes, one per circle.
        api_url: Base URL of the API. Empty disables OCR entirely.
        api_key: Optional bearer token for the ``Authorization`` header.
        model: Model name sent in the request payload.
        timeout: Per-request timeout in seconds.
        batch_size: Circles per request. Defaults to the module's ``_BATCH_SIZE``.
        max_digits: Longest digit string accepted as a UID.

    Returns:
        One string per entry of ``circle_boxes``, in the same order. An entry
        is "" where the UID could not be read. All entries are "" when
        ``api_url`` is empty.

    Raises:
        ValueError: If ``batch_size`` is given and is less than 1, or if a
            grid image cannot be PNG-encoded.

    A failed request or an unparseable reply only blanks that batch. The
    problem is logged and the remaining batches are still processed.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    results: list[str] = [""] * len(circle_boxes)
    if not circle_boxes or not api_url:
        return results
    step = _BATCH_SIZE if batch_size is None else batch_size

    # The endpoint and headers are the same for every batch, so build them once.
    endpoint = api_url.rstrip("/") + "/chat/completions"
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    for batch_start in range(0, len(circle_boxes), step):
        batch_end = min(batch_start + step, len(circle_boxes))
        batch_crops = [_crop_circle(image_bgr, box) for box in circle_boxes[batch_start:batch_end]]
        n = len(batch_crops)
        # Image preparation stays outside the try below: a failure here is a
        # local bug, not a flaky API, and should surface to the caller.
        grid, _ = _build_grid_image(batch_crops)
        b64 = _encode_image(grid)
        payload = {
            "model": model,
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": _build_uid_prompt(n)},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
                ],
            }],
            "max_tokens": 2048,
            "temperature": 0,
        }
        try:
            content = _post_chat_completion(endpoint, payload, headers, timeout)
            uids = _extract_uid_array(content, n)
            if uids is None:
                logger.warning("Vision batch OCR: no JSON array in reply for batch %d-%d: %.200r",
                               batch_start + 1, batch_end, content)
                continue
            if len(uids) != n:
                # Values are mapped to circles by position. If the model dropped
                # or added an element, later entries in this batch may be misaligned.
                logger.warning("Vision batch OCR: expected %d values, got %d (batch %d-%d)",
                               n, len(uids), batch_start + 1, batch_end)
            for offset, uid_val in enumerate(uids[:n]):
                results[batch_start + offset] = _normalize_uid(uid_val, max_digits)
            logger.info("Vision batch OCR: %d/%d circles read (batch %d-%d)",
                        sum(map(bool, results[batch_start:batch_end])), n,
                        batch_start + 1, batch_end)
        except Exception as exc:  # best-effort boundary: blank only this batch
            logger.warning("Vision batch OCR failed for batch %d-%d: %s",
                           batch_start + 1, batch_end, exc)
    return results