Spaces:
Paused
Paused
| """Vision LLM OCR for red circle UID reading. | |
| Instead of using EasyOCR for each circle individually, | |
| batches multiple circle crops into a single vision API call | |
| for faster and more accurate UID reading. | |
| """ | |
| from __future__ import annotations | |
| import base64 | |
| import json | |
| import logging | |
| import re | |
| import urllib.error | |
| import urllib.request | |
| from typing import Any | |
| import cv2 | |
| import numpy as np | |
| logger = logging.getLogger(__name__) | |
| _BATCH_SIZE = 30 # circles per API call | |
def _crop_circle(image_bgr: np.ndarray, box: tuple[int, int, int, int], pad: int = 5) -> np.ndarray:
    """Return the region of ``image_bgr`` covered by ``box`` grown by ``pad`` pixels.

    Args:
        image_bgr: Source image; the first two axes are indexed as (row, column).
        box: ``(x, y, w, h)`` rectangle in pixel coordinates.
        pad: Margin added on every side of the box before clipping.

    Returns:
        A slice of ``image_bgr`` clipped to the image bounds. The slice is
        empty (zero width and/or height) when the padded box does not
        intersect the image at all.
    """
    x, y, w, h = box
    img_h, img_w = image_bgr.shape[:2]

    # Clamp the start into [0, dim] and force end >= start.  Without the lower
    # clamp on the end coordinate, a box lying entirely left of / above the
    # image yields a negative slice stop, which NumPy interprets as "count
    # from the end" and returns almost the whole image instead of nothing.
    x0 = min(max(0, x - pad), img_w)
    y0 = min(max(0, y - pad), img_h)
    x1 = max(x0, min(img_w, x + w + pad))
    y1 = max(y0, min(img_h, y + h + pad))
    return image_bgr[y0:y1, x0:x1]
def _build_grid_image(crops: list[np.ndarray], cols: int = 6) -> tuple[np.ndarray, list[tuple[int, int]]]:
    """Lay the crops out on a white canvas, each one captioned ``#<1-based index>``.

    Every cell is sized to the largest crop plus a 4 px horizontal margin and
    a 20 px vertical margin; the top 16 px of each cell is kept free for the
    caption.  Cells are filled row by row, ``cols`` per row.

    Returns:
        ``(grid, positions)`` where ``positions[i]`` is the ``(x, y)`` of the
        top-left corner of crop ``i`` inside ``grid``.  For an empty input a
        1x1 black placeholder image and an empty list are returned.
    """
    if not crops:
        return np.zeros((1, 1, 3), dtype=np.uint8), []

    widest = max(crop.shape[1] for crop in crops)
    tallest = max(crop.shape[0] for crop in crops)
    cell_w = widest + 4
    cell_h = tallest + 20  # crop height plus room for the caption band

    n_rows = (len(crops) + cols - 1) // cols
    canvas = np.full((n_rows * cell_h, cols * cell_w, 3), 255, dtype=np.uint8)

    origins: list[tuple[int, int]] = []
    for number, crop in enumerate(crops, start=1):
        row, col = divmod(number - 1, cols)
        left = col * cell_w + 2
        top = row * cell_h + 16  # skip the caption band
        height, width = crop.shape[:2]
        canvas[top:top + height, left:left + width] = crop
        # Caption sits just above the crop, inside the reserved band.
        cv2.putText(
            canvas,
            f"#{number}",
            (left, top - 4),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.35,
            (0, 0, 0),
            1,
        )
        origins.append((left, top))

    return canvas, origins
def _encode_image(image: np.ndarray) -> str:
    """Encode ``image`` as PNG and return the bytes as an ASCII base64 string.

    Raises:
        ValueError: If the PNG encoder reports failure.
    """
    ok, png = cv2.imencode(".png", image)
    if ok:
        encoded = base64.b64encode(png.tobytes())
        return encoded.decode("ascii")
    raise ValueError("Failed to encode image")
# Greedy on purpose: captures from the first "[" to the last "]" in the reply,
# so any prose the model adds around the array is ignored.
_JSON_ARRAY_RE = re.compile(r"\[.*\]", re.DOTALL)
_MAX_UID_DIGITS = 4  # longer digit runs are treated as misreads
_REQUEST_TIMEOUT_S = 60


def _build_uid_prompt(n: int) -> str:
    """Return the instruction asking the model for a JSON array of ``n`` UIDs."""
    return (
        f"На изображении {n} красных кружков с цифрами внутри, пронумерованных #1..#{n}.\n"
        f"Для каждого кружка прочитай чёрное число внутри.\n"
        f"Верни ТОЛЬКО JSON массив из {n} строк, например: [\"42\", \"7\", \"103\", ...]\n"
        f"Если число не читается — пустая строка \"\".\n"
        f"Верни ровно {n} элементов."
    )


def _normalize_uid(value: Any) -> str:
    """Reduce one array element to a UID string, or ``""`` if it is not usable."""
    # A JSON number such as 42.0 must not become "420" after digit filtering.
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    digits = "".join(ch for ch in str(value) if ch.isdigit())
    if not digits or len(digits) > _MAX_UID_DIGITS:
        return ""
    # Drop leading zeros, but keep an all-zero reading as it was returned.
    return digits.lstrip("0") or digits


def _request_completion(endpoint: str, payload: dict[str, Any], headers: dict[str, str]) -> str:
    """POST ``payload`` as JSON and return the first choice's message text."""
    body = json.dumps(payload).encode("utf-8")
    req = urllib.request.Request(endpoint, data=body, headers=headers, method="POST")
    with urllib.request.urlopen(req, timeout=_REQUEST_TIMEOUT_S) as resp:
        raw = json.loads(resp.read().decode("utf-8"))
    # Tolerate a missing or empty "choices" / "message" instead of raising an
    # opaque IndexError; an empty reply is reported by the caller.
    choices = raw.get("choices") or [{}]
    message = choices[0].get("message") or {}
    content = message.get("content")
    return "" if content is None else str(content)


def _extract_json_array(content: str) -> list[Any]:
    """Parse the JSON array embedded in ``content``.

    Raises:
        ValueError: If no array is present or it is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    match = _JSON_ARRAY_RE.search(content)
    if match is None:
        raise ValueError("no JSON array found in model response")
    return json.loads(match.group())


def batch_ocr_circles(
    image_bgr: np.ndarray,
    circle_boxes: list[tuple[int, int, int, int]],
    *,
    api_url: str,
    api_key: str = "",
    model: str = "",
) -> list[str]:
    """Read UID numbers from red circles using a vision LLM.

    The circles are cropped, arranged ``_BATCH_SIZE`` at a time into labelled
    grid images, and each grid is sent to ``<api_url>/chat/completions``.

    Args:
        image_bgr: Image the boxes refer to.
        circle_boxes: ``(x, y, w, h)`` box of every circle.
        api_url: Base URL of the API; when empty no request is made.
        api_key: Optional bearer token sent in the ``Authorization`` header.
        model: Model name placed in the request payload.

    Returns:
        One string per entry of ``circle_boxes`` (same order): the digits read
        for that circle, or ``""`` when it could not be read.  The function is
        best-effort: any failure while preparing, sending or parsing a batch
        is logged and leaves that batch's entries empty without affecting the
        other batches.
    """
    results: list[str] = [""] * len(circle_boxes)
    if not circle_boxes or not api_url:
        return results

    # Request parts that do not depend on the batch are built once.
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    endpoint = api_url.rstrip("/") + "/chat/completions"

    for batch_start in range(0, len(circle_boxes), _BATCH_SIZE):
        batch_boxes = circle_boxes[batch_start:batch_start + _BATCH_SIZE]
        n = len(batch_boxes)
        batch_end = batch_start + n

        try:
            crops = [_crop_circle(image_bgr, box) for box in batch_boxes]
            grid, _ = _build_grid_image(crops)
            b64 = _encode_image(grid)
            payload = {
                "model": model,
                "messages": [{
                    "role": "user",
                    "content": [
                        {"type": "text", "text": _build_uid_prompt(n)},
                        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
                    ],
                }],
                "max_tokens": 2048,
                "temperature": 0,
            }
            content = _request_completion(endpoint, payload, headers)
            values = _extract_json_array(content)
        except Exception as exc:  # deliberate best-effort boundary: skip this batch only
            logger.warning("Vision batch OCR failed for batch %d-%d: %s",
                           batch_start + 1, batch_end, exc)
            continue

        if len(values) != n:
            # Entries are matched by position, so a wrong count may mean the
            # readings are shifted relative to the circles.
            logger.warning("Vision batch OCR: expected %d values, got %d (batch %d-%d)",
                           n, len(values), batch_start + 1, batch_end)

        read = 0
        for offset, value in enumerate(values[:n]):
            uid = _normalize_uid(value)
            if uid:
                results[batch_start + offset] = uid
                read += 1

        logger.info("Vision batch OCR: %d/%d circles read (batch %d-%d)",
                    read, n, batch_start + 1, batch_end)

    return results