# Source: mnemo-ocr-core/src/red_element_detector.py
# Last commit 7fb79e4 by MABobrov: "Deploy updated core backend pipeline"
"""Red element detector for mnemonic schema images.
Detects red circles (UID badges) and red rectangles (value cells),
reads UID numbers from circles, and associates them with nearby cells.
All other elements are treated as static shapes.
"""
from __future__ import annotations
import logging
from typing import Any
import cv2
import numpy as np
logger = logging.getLogger(__name__)
_MIN_CIRCLE_SIZE = 14
_MAX_CIRCLE_SIZE = 50
_MIN_CIRCULARITY = 0.55
_CIRCLE_TYPICAL_DIM = 32
def _detect_red_mask(image_bgr: np.ndarray) -> np.ndarray:
    """Return a binary mask marking the red pixels of a BGR image.

    Red sits at both ends of OpenCV's hue scale, so two hue bands
    (0-12 and 168-180, each with S >= 100 and V >= 80) are thresholded
    separately and OR-ed together.  A deliberately light closing
    (2x2 kernel, one pass) then seals tiny gaps without fusing
    neighbouring table cells into one blob.
    """
    hsv = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2HSV)
    low_hue_band = cv2.inRange(hsv, np.array([0, 100, 80]), np.array([12, 255, 255]))
    high_hue_band = cv2.inRange(hsv, np.array([168, 100, 80]), np.array([180, 255, 255]))
    combined = cv2.bitwise_or(low_hue_band, high_hue_band)

    closing_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    closed = cv2.morphologyEx(combined, cv2.MORPH_CLOSE, closing_kernel, iterations=1)
    return closed
def _find_red_circles(
    red_mask: np.ndarray,
) -> list[tuple[int, int, int, int]]:
    """Locate red UID badges in *red_mask* and return their ``(x, y, w, h)`` boxes.

    Only external contours are examined, and those with an area below
    50 px^2 are ignored.  A contour is accepted as a single badge when its
    circularity (4*pi*area / perimeter^2) exceeds 0.5, its aspect ratio
    w/h lies strictly between 0.5 and 1.8, and both sides are strictly
    between 12 px and ``_MAX_CIRCLE_SIZE``.

    Otherwise a wide contour (w > 45, w/h > 1.3, height strictly between
    12 px and ``_MAX_CIRCLE_SIZE``) is assumed to be several touching
    badges and is cut into ``max(2, round(w / 30))`` equal-width slices.
    """
    contours, _ = cv2.findContours(red_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    found: list[tuple[int, int, int, int]] = []

    for contour in contours:
        area = cv2.contourArea(contour)
        if area < 50:
            continue
        bx, by, bw, bh = cv2.boundingRect(contour)

        perim = cv2.arcLength(contour, True)
        circ = 4 * 3.14159 * area / (perim * perim) if perim > 0 else 0
        ratio = bw / max(bh, 1)
        height_ok = 12 < bh < _MAX_CIRCLE_SIZE

        looks_round = (
            circ > 0.5
            and 0.5 < ratio < 1.8
            and 12 < bw < _MAX_CIRCLE_SIZE
            and height_ok
        )
        if looks_round:
            found.append((bx, by, bw, bh))
            continue

        if height_ok and bw > 45 and ratio > 1.3:
            # Merged cluster: slice it horizontally into equal pieces.
            pieces = max(2, round(bw / 30))
            width_each = bw / pieces
            found.extend(
                (int(bx + k * width_each), by, int(width_each), bh)
                for k in range(pieces)
            )

    return found
def _ocr_circle_uid(
image_bgr: np.ndarray,
box: tuple[int, int, int, int],
) -> tuple[str, float]:
"""Read the UID number (black digits) inside a red circle."""
from src.ocr_utils_demo import get_title_reader
x, y, w, h = box
pad = max(2, int(w * 0.12))
crop = image_bgr[max(0, y - pad):y + h + pad, max(0, x - pad):x + w + pad]
if crop.size == 0:
return "", 0.0
gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(gray, 90, 255, cv2.THRESH_BINARY_INV)
enlarged = cv2.resize(thresh, None, fx=4, fy=4, interpolation=cv2.INTER_CUBIC)
inv = 255 - enlarged
reader = get_title_reader()
attempts = [
{"detail": 1, "paragraph": False, "allowlist": "0123456789"},
{"detail": 1, "paragraph": False},
]
for opts in attempts:
try:
results = reader.readtext(inv, **opts)
except TypeError:
safe = {k: v for k, v in opts.items() if k != "allowlist"}
results = reader.readtext(inv, **safe)
except Exception:
continue
for item in results or []:
if not isinstance(item, (list, tuple)) or len(item) < 3:
continue
txt = str(item[1] or "").strip()
digits = "".join(c for c in txt if c.isdigit())
if digits and len(digits) <= 4:
conf = float(item[2]) if item[2] else 0.0
return digits.lstrip("0") or digits, conf
return "", 0.0
def _find_red_rects(
    red_mask: np.ndarray,
    circles: list[tuple[int, int, int, int]],
) -> list[tuple[int, int, int, int]]:
    """Return ``(x, y, w, h)`` boxes of red outlines that look like value cells.

    An external contour of *red_mask* is kept when all of these hold:

    * its bounding box is at least 20x10 px and at most 400x200 px;
    * it is not a small round blob (circularity > 0.55 with both sides
      under 50 px);
    * its aspect ratio w/h lies within [0.5, 12];
    * its box does not nearly coincide with any box in *circles* (x and y
      differ by less than 6 px, w and h by less than 8 px).
    """
    contours, _ = cv2.findContours(red_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    def _coincides_with_circle(bx: int, by: int, bw: int, bh: int) -> bool:
        for cx, cy, cw, ch in circles:
            if abs(bx - cx) < 6 and abs(by - cy) < 6 and abs(bw - cw) < 8 and abs(bh - ch) < 8:
                return True
        return False

    cells: list[tuple[int, int, int, int]] = []
    for contour in contours:
        bx, by, bw, bh = cv2.boundingRect(contour)
        if bw < 20 or bh < 10 or bw > 400 or bh > 200:
            continue

        area = cv2.contourArea(contour)
        perim = cv2.arcLength(contour, True)
        circ = 4 * 3.14159 * area / (perim * perim) if perim > 0 else 0
        if circ > 0.55 and bw < 50 and bh < 50:
            continue  # small round blob: treated as a circle, not a cell

        ratio = bw / max(bh, 1)
        if not 0.5 <= ratio <= 12:
            continue

        if _coincides_with_circle(bx, by, bw, bh):
            continue

        cells.append((bx, by, bw, bh))
    return cells
def _ocr_red_digits_in_rect(
image_bgr: np.ndarray,
red_mask: np.ndarray,
box: tuple[int, int, int, int],
) -> tuple[str, float]:
"""Read RED UID digits inside a red rectangle.
UID digits are small red numbers in the top-right corner of the cell.
We crop the top-right quadrant, erode the border, and OCR the remaining
red pixels (which are the UID digits).
"""
from src.ocr_utils_demo import get_title_reader
x, y, w, h = box
img_h, img_w = image_bgr.shape[:2]
# Search area: top-right corner of the rect + small area above/right
search_x0 = max(0, x + int(w * 0.4))
search_y0 = max(0, y - max(4, int(h * 0.3)))
search_x1 = min(img_w, x + w + max(6, int(w * 0.2)))
search_y1 = min(img_h, y + int(h * 0.5))
if search_x1 <= search_x0 or search_y1 <= search_y0:
return "", 0.0
crop_red = red_mask[search_y0:search_y1, search_x0:search_x1].copy()
if crop_red.size == 0 or not np.count_nonzero(crop_red):
return "", 0.0
# Erode to remove the thick border lines, leaving only small digit strokes
erode_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
eroded = cv2.erode(crop_red, erode_kernel, iterations=1)
if not np.count_nonzero(eroded):
eroded = crop_red
enlarged = cv2.resize(eroded, None, fx=6, fy=6, interpolation=cv2.INTER_NEAREST)
enlarged = cv2.dilate(enlarged, np.ones((3, 3), dtype=np.uint8), iterations=1)
on_white = np.full_like(enlarged, 255)
on_white[enlarged > 0] = 0
reader = get_title_reader()
attempts = [
{"detail": 1, "paragraph": False, "allowlist": "0123456789"},
{"detail": 1, "paragraph": False},
]
for opts in attempts:
try:
results = reader.readtext(on_white, **opts)
except TypeError:
safe = {k: v for k, v in opts.items() if k != "allowlist"}
results = reader.readtext(on_white, **safe)
except Exception:
continue
for item in results or []:
if not isinstance(item, (list, tuple)) or len(item) < 3:
continue
txt = str(item[1] or "").strip()
digits = "".join(c for c in txt if c.isdigit())
if digits and len(digits) <= 4:
conf = float(item[2]) if item[2] else 0.0
return digits.lstrip("0") or digits, conf
return "", 0.0
def _bounds_to_points(x: int, y: int, w: int, h: int) -> list[dict[str, float]]:
return [
{"x": float(x), "y": float(y)},
{"x": float(x + w), "y": float(y)},
{"x": float(x + w), "y": float(y + h)},
{"x": float(x), "y": float(y + h)},
]
def _center(box: tuple[int, int, int, int]) -> tuple[float, float]:
return box[0] + box[2] / 2.0, box[1] + box[3] / 2.0
def _distance(a: tuple[float, float], b: tuple[float, float]) -> float:
return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5
def detect_red_elements(
    image_bgr: np.ndarray,
    *,
    vision_api_url: str = "",
    vision_api_key: str = "",
    vision_model: str = "",
) -> dict[str, Any]:
    """Detect red circles (UID badges) and read the UID inside each one.

    If ``vision_api_url`` is set and at least one circle was found, all
    circles are OCR'd in a single call to
    ``src.vision_circle_ocr.batch_ocr_circles``. If that call raises, a
    warning is logged and every circle is read with EasyOCR instead
    (``_ocr_circle_uid``). Red rectangles are not detected here.

    Args:
        image_bgr: Source image in BGR channel order.
        vision_api_url: Vision LLM endpoint; empty disables the vision path.
        vision_api_key: API key passed through to the vision call.
        vision_model: Model name passed through to the vision call.

    Returns:
        A dict with these keys:

        - ``circles``: one dict per detected circle, holding ``box``
          ``(x, y, w, h)``, ``uid`` (str, ``""`` when unread),
          ``confidence``, ``center`` and ``type`` (``"circle"``).
        - ``circles_with_uid``: the circles whose ``uid`` is non-empty.
        - ``red_mask``: the binary red mask the circles were found in.
    """
    red_mask = _detect_red_mask(image_bgr)
    circle_boxes = _find_red_circles(red_mask)

    # Try the Vision LLM batch OCR first; any failure falls back to EasyOCR.
    vision_uids: list[str] | None = None
    if vision_api_url and circle_boxes:
        try:
            from src.vision_circle_ocr import batch_ocr_circles

            vision_uids = batch_ocr_circles(
                image_bgr, circle_boxes,
                api_url=vision_api_url,
                api_key=vision_api_key,
                model=vision_model,
            )
        except Exception as exc:
            logger.warning("Vision batch OCR failed, falling back to EasyOCR: %s", exc)

    circles: list[dict[str, Any]] = []
    for i, box in enumerate(circle_boxes):
        if vision_uids is not None:
            raw_uid = vision_uids[i] if i < len(vision_uids) else ""
            # Normalise to a stripped str so both OCR paths share the
            # ""-when-unread contract; None or padded entries would
            # otherwise be stored verbatim.
            uid = "" if raw_uid is None else str(raw_uid).strip()
            # The vision result is used as a plain list of UIDs, with no
            # per-circle score, so a fixed confidence is assigned.
            conf = 0.9 if uid else 0.0
        else:
            uid, conf = _ocr_circle_uid(image_bgr, box)
        circles.append({
            "box": box,
            "uid": uid,
            "confidence": conf,
            "center": _center(box),
            "type": "circle",
        })

    circles_with_uid = [c for c in circles if c["uid"]]
    logger.info(
        "Red circles: %d total, %d with UID (vision=%s)",
        len(circles), len(circles_with_uid),
        "yes" if vision_uids is not None else "no",
    )
    return {
        "circles": circles,
        "circles_with_uid": circles_with_uid,
        "red_mask": red_mask,
    }
def build_overlays_from_red_elements(
    circles_with_uid: list[dict[str, Any]],
    rects_with_uid: list[dict[str, Any]],
    *,
    uid_lookup: dict[str, dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
    """Convert detected red elements (circles + rects) that carry UIDs into overlays.

    - Circles: UID badge pointing to a nearby parameter widget.
    - Rects: value cell with red UID digits; the rect itself is the widget area.

    Elements are visited in descending ``confidence`` order, and only the
    first element seen for each UID produces an overlay. The overlays are
    returned sorted by numeric UID. Non-numeric UIDs sort as 0, keeping
    their confidence order because the sort is stable.

    Args:
        circles_with_uid: Circle dicts from ``detect_red_elements``, with
            keys ``box``, ``uid``, ``confidence`` and ``type``.
        rects_with_uid: Rect dicts with the same keys. NOTE(review):
            ``detect_red_elements`` in this module returns circles only, so
            rect dicts must be produced by the caller.
        uid_lookup: Excel lookup mapping a uid string to a row dict. The
            ``widget``, ``om_path``, ``parameter`` and ``equipment`` fields
            are read.

    Returns:
        One overlay dict per unique, non-empty UID.
    """
    lookup = uid_lookup or {}
    overlays: list[dict[str, Any]] = []
    seen_uids: set[str] = set()

    all_elements = sorted(
        [*circles_with_uid, *rects_with_uid],
        key=lambda c: float(c.get("confidence", 0)),
        reverse=True,
    )
    for elem in all_elements:
        raw_uid = elem.get("uid")
        # Guard against None: str(None) would create a bogus "None" UID.
        uid = "" if raw_uid is None else str(raw_uid).strip()
        if not uid or uid in seen_uids:
            continue
        seen_uids.add(uid)

        bx, by, bw, bh = elem["box"]
        elem_type = elem.get("type", "circle")
        row = lookup.get(uid, {})
        widget = str(row.get("widget") or "").strip()
        om_path = str(row.get("om_path") or "").strip()
        parameter = str(row.get("parameter") or "").strip()
        equipment = str(row.get("equipment") or "").strip()
        overlays.append({
            "id": f"uid-{uid}",
            "kind": "cell",
            "category": "value",
            "label": parameter[:40] if parameter else f"#{uid}",
            "confidence": max(0.5, min(0.99, float(elem.get("confidence", 0.8)))),
            "bounds": {"x": bx, "y": by, "width": bw, "height": bh},
            "points": _bounds_to_points(bx, by, bw, bh),
            "meta": {
                "source": f"red_{elem_type}_detector",
                "elementType": elem_type,
                "ocrConfidence": round(float(elem.get("confidence", 0)), 3),
                "equipment": equipment,
                "parameter": parameter,
            },
            "bindingUid": uid,
            "note": parameter[:60] if parameter else "",
            "contextPath": om_path,
            "widgetNameOverride": widget,
            "staticShapeNameOverride": "",
        })

    def _uid_order(overlay: dict[str, Any]) -> int:
        binding = overlay.get("bindingUid") or ""
        # isdecimal(), not isdigit(): it accepts exactly the characters int()
        # can parse. For example "²".isdigit() is True, but int("²") raises.
        return int(binding) if binding.isdecimal() else 0

    overlays.sort(key=_uid_order)
    return overlays