| """ |
| Image preprocessing for image-input SPARK on real-world uploads. |
| |
| Real user uploads (paper-figure crops, software screenshots, photos of |
| lab monitors) live in a much wider image distribution than the rendered |
| training PNGs. This module produces a cleaned grayscale 224x224 PIL image |
| that looks closer to the training distribution before it enters the |
| image-mode CNN. |
| |
| Three stages, all PIL-in / PIL-out: |
| |
| 1. crop_to_plot_region -- OCR-based detection of the inner plot |
| bounding box; crops out browser chrome, |
| paper captions, side panels. |
| 2. remove_gridlines_and_background |
| -- adaptive threshold + morphological line |
| detection to suppress thin gridlines and |
| normalize the background to white. |
| 3. prepare_for_image_mode |
| -- orchestrator: crop -> clean -> resize. |
| |
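All heavy CV deps (`cv2`, `easyocr`) are imported lazily so the module
loads cleanly in environments that lack them. Without `easyocr`, cropping
is skipped and `meta['was_cropped']` is False; without `cv2`, cleaning
falls back to a numpy-only background stretch plus near-white snap, and
`meta['was_cleaned']` is False.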
| """ |
|
|
from __future__ import annotations

import functools
import re
from typing import Dict, Optional, Tuple

import numpy as np
from PIL import Image, ImageOps
|
|
|
|
| |
| |
| |
|
|
| def _detect_label_positions(image_array: np.ndarray): |
| """Run OCR and return raw (cx, cy, val) tuples for every numeric label. |
| |
| Mirrors the OCR pass in `digitizer.auto_detect_axis_bounds` but exposes |
| the per-label pixel positions, which we need to locate the inner plot |
| bounding box (right of y-labels, above x-labels). |
| |
| Returns ([], None) if easyocr is unavailable or finds <4 numeric labels. |
| """ |
| try: |
| import easyocr |
| except ImportError: |
| return [], None |
| import re |
|
|
| if image_array.ndim == 3 and image_array.shape[2] == 4: |
| image_array = image_array[:, :, :3] |
|
|
| H, W = image_array.shape[:2] |
| reader = easyocr.Reader(["en"], gpu=False, verbose=False) |
| try: |
| results = reader.readtext(image_array, detail=1) |
| except Exception: |
| return [], None |
|
|
| _NUM_RE = re.compile(r"^[−\-–~]?\d+\.?\d*(?:[eE][+\-]?\d+)?$") |
| detections = [] |
| for bbox, text, conf in results: |
| cleaned = (text.strip().replace(" ", "") |
| .replace("−", "-").replace("–", "-").replace("~", "-")) |
| if not _NUM_RE.match(cleaned): |
| continue |
| try: |
| float(cleaned) |
| except ValueError: |
| continue |
| if conf < 0.2: |
| continue |
| cx = float(np.mean([p[0] for p in bbox])) |
| cy = float(np.mean([p[1] for p in bbox])) |
| detections.append((cx, cy, float(cleaned.replace("-", "-")))) |
|
|
| if len(detections) < 4: |
| return [], None |
| return detections, (H, W) |
|
|
|
|
| def _plot_bbox_from_detections(detections, hw, margin_frac: float = 0.02): |
| """Compute inner-plot bounding box (left, top, right, bottom) in pixels |
| from raw OCR label detections. |
| |
| Heuristic: |
| - y-axis labels live in the left third of the image |
| -> plot_left = max cx among y-labels + margin |
| - x-axis labels live in the bottom third of the image |
| -> plot_bottom = min cy among x-labels - margin |
| - plot_right roughly = max cx among x-labels + margin (fallback to W) |
| - plot_top roughly = min cy among y-labels - margin (fallback to 0) |
| |
| Returns (left, top, right, bottom) ints, or None if heuristic fails. |
| """ |
| H, W = hw |
| margin = int(margin_frac * max(H, W)) |
|
|
| y_label_cxs = [cx for cx, cy, _ in detections if cx < W * 0.30] |
| y_label_cys = [cy for cx, cy, _ in detections if cx < W * 0.30] |
| x_label_cxs = [cx for cx, cy, _ in detections if cy > H * 0.65] |
| x_label_cys = [cy for cx, cy, _ in detections if cy > H * 0.65] |
|
|
| if not y_label_cxs or not x_label_cys: |
| return None |
|
|
| plot_left = int(max(y_label_cxs) + margin) |
| plot_bottom = int(min(x_label_cys) - margin) |
| plot_right = int(max(x_label_cxs) + margin) if x_label_cxs else W |
| plot_top = int(min(y_label_cys) - margin) if y_label_cys else 0 |
|
|
| plot_left = max(0, min(plot_left, W - 1)) |
| plot_right = max(plot_left + 1, min(plot_right, W)) |
| plot_top = max(0, min(plot_top, H - 1)) |
| plot_bottom = max(plot_top + 1, min(plot_bottom, H)) |
|
|
| if plot_right - plot_left < 32 or plot_bottom - plot_top < 32: |
| return None |
| return (plot_left, plot_top, plot_right, plot_bottom) |
|
|
|
|
def crop_to_plot_region(pil_image: Image.Image,
                        margin_frac: float = 0.02,
                        ) -> Tuple[Image.Image, Optional[Tuple[int, int, int, int]]]:
    """Detect the inner plot bbox via OCR and crop to it.

    Args:
        pil_image: input PIL image (any mode).
        margin_frac: small padding around the detected plot region as a
            fraction of max(H, W).

    Returns:
        (cropped_pil, bbox) where bbox is (left, top, right, bottom) ints,
        or None if OCR-based detection failed (in which case cropped_pil is
        pil_image itself). The crop is taken from the original image, so its
        mode is preserved.
    """
    ocr_source = pil_image
    if ocr_source.mode in ("RGBA", "LA", "PA") or "transparency" in ocr_source.info:
        # convert("RGB") simply drops alpha, so transparent areas would keep
        # whatever colour they store (possibly black, which hides dark tick
        # labels from OCR). Flatten onto opaque white first.
        rgba = ocr_source.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        ocr_source = Image.alpha_composite(white, rgba)

    arr = np.asarray(ocr_source.convert("RGB"))
    dets, hw = _detect_label_positions(arr)
    if not dets or hw is None:
        return pil_image, None
    bbox = _plot_bbox_from_detections(dets, hw, margin_frac=margin_frac)
    if bbox is None:
        return pil_image, None
    return pil_image.crop(bbox), bbox
|
|
|
|
| |
| |
| |
|
|
def _ensure_grayscale(pil_image: Image.Image) -> np.ndarray:
    """Return a uint8 (H, W) grayscale numpy array from any PIL image.

    Images with transparency (an alpha channel, or a ``transparency`` entry
    in ``info``) are first composited onto opaque white. A plain
    ``convert("L")`` drops alpha, so transparent background regions would
    keep whatever colour they store (possibly black) instead of the white
    background the preprocessing normalizes towards. Fully opaque images
    are unaffected.
    """
    if pil_image.mode in ("RGBA", "LA", "PA") or "transparency" in pil_image.info:
        rgba = pil_image.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        pil_image = Image.alpha_composite(white, rgba)
    if pil_image.mode != "L":
        pil_image = pil_image.convert("L")
    return np.asarray(pil_image, dtype=np.uint8)
|
|
|
|
| def remove_gridlines_and_background( |
| pil_image: Image.Image, |
| background_stretch: bool = True, |
| remove_gridlines: bool = True, |
| grid_min_length_frac: float = 0.30, |
| soft_threshold: int = 245, |
| ) -> Tuple[Image.Image, Dict[str, object]]: |
| """Normalize background to white and (optionally) remove thin gridlines. |
| |
| Pipeline: |
| 1. Convert to grayscale. |
| 2. (background_stretch) Linearly stretch the gray histogram so the |
| brightest pixel is 255 (cancels colored / off-white backgrounds). |
| 3. (remove_gridlines) Adaptive-threshold to a binary mask of dark |
| pixels (curve + axes + text + gridlines), then morphological |
| opening with very long horizontal `(1, K)` and vertical `(K, 1)` |
| kernels finds long thin lines; we inpaint those regions on the |
| grayscale image. The main curve survives because morphological |
| opening with a 1xK kernel only keeps strictly straight horizontal |
| runs of >=K dark pixels; a curving line breaks the connectivity. |
| 4. (soft_threshold) Push pixels >= `soft_threshold` to pure 255 to |
| snap any residual near-white background to clean white. |
| |
| Falls back to a pure-PIL background stretch if cv2 is unavailable. |
| |
| Returns: |
| (cleaned_pil, meta) where meta has keys was_stretched, |
| was_cleaned, n_horiz_gridlines, n_vert_gridlines. |
| """ |
| meta: Dict[str, object] = { |
| "was_stretched": False, |
| "was_cleaned": False, |
| "n_horiz_gridlines": 0, |
| "n_vert_gridlines": 0, |
| } |
| arr = _ensure_grayscale(pil_image) |
|
|
| if background_stretch: |
| if arr.max() > 0: |
| scale = 255.0 / float(arr.max()) |
| arr = np.clip(arr.astype(np.float32) * scale, 0, 255).astype(np.uint8) |
| meta["was_stretched"] = True |
|
|
| try: |
| import cv2 |
| except ImportError: |
| if soft_threshold > 0: |
| arr = np.where(arr >= soft_threshold, 255, arr).astype(np.uint8) |
| return Image.fromarray(arr, mode="L"), meta |
|
|
| if remove_gridlines: |
| H, W = arr.shape |
| binary = cv2.adaptiveThreshold( |
| arr, 255, |
| cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, |
| blockSize=31, C=10, |
| ) |
| K_h = max(20, int(W * grid_min_length_frac)) |
| K_v = max(20, int(H * grid_min_length_frac)) |
|
|
| h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (K_h, 1)) |
| h_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, h_kernel) |
| v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, K_v)) |
| v_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, v_kernel) |
|
|
| meta["n_horiz_gridlines"] = int((h_lines.sum(axis=1) > 0).sum()) |
| meta["n_vert_gridlines"] = int((v_lines.sum(axis=0) > 0).sum()) |
|
|
| line_mask = cv2.bitwise_or(h_lines, v_lines) |
| if line_mask.sum() > 0: |
| line_mask = cv2.dilate(line_mask, np.ones((2, 2), np.uint8)) |
| arr = cv2.inpaint(arr, line_mask, 3, cv2.INPAINT_TELEA) |
| meta["was_cleaned"] = True |
|
|
| if soft_threshold > 0: |
| arr = np.where(arr >= soft_threshold, 255, arr).astype(np.uint8) |
|
|
| return Image.fromarray(arr, mode="L"), meta |
|
|
|
|
| |
| |
| |
|
|
def prepare_for_image_mode(
    pil_image: Image.Image,
    do_crop: bool = True,
    do_clean: bool = True,
    target_size: int = 224,
    *,
    apply_exif_orientation: bool = True,
) -> Tuple[Image.Image, Dict[str, object]]:
    """Full preprocessing pipeline for image-mode SPARK.

    Steps (any can be skipped):
      EXIF orientation fix -> crop_to_plot_region
        -> remove_gridlines_and_background -> resize.

    Args:
        pil_image: any-mode PIL.Image.
        do_crop: run OCR-based plot-region cropping.
        do_clean: run background normalization + gridline removal.
        target_size: output square edge length in pixels (must be >= 1).
        apply_exif_orientation: rotate/flip according to the EXIF
            Orientation tag first. Photos may store unrotated pixels and
            record the intended orientation only in EXIF.

    Returns:
        (preprocessed_pil_L, meta) where meta is a flat dict suitable for
        showing in the UI:
          was_cropped: bool
          crop_bbox: [l, t, r, b] or None, in the coordinates of the
            orientation-corrected image
          was_stretched: bool
          was_cleaned: bool
          n_horiz_gridlines: int
          n_vert_gridlines: int
          target_size: int

    Raises:
        ValueError: if target_size < 1.
    """
    if target_size < 1:
        raise ValueError(f"target_size must be >= 1, got {target_size!r}")

    meta: Dict[str, object] = {
        "was_cropped": False,
        "crop_bbox": None,
        "was_stretched": False,
        "was_cleaned": False,
        "n_horiz_gridlines": 0,
        "n_vert_gridlines": 0,
        "target_size": target_size,
    }

    img = pil_image
    if apply_exif_orientation:
        # Returns a transposed copy; images without an orientation tag come
        # back unchanged in content.
        img = ImageOps.exif_transpose(img)

    if do_crop:
        cropped, bbox = crop_to_plot_region(img)
        if bbox is not None:
            img = cropped
            meta["was_cropped"] = True
            meta["crop_bbox"] = list(bbox)

    if do_clean:
        img, clean_meta = remove_gridlines_and_background(img)
        for key in ("was_stretched", "was_cleaned",
                    "n_horiz_gridlines", "n_vert_gridlines"):
            meta[key] = clean_meta[key]
    elif img.mode != "L":
        img = img.convert("L")

    if img.size != (target_size, target_size):
        img = img.resize((target_size, target_size), Image.BILINEAR)

    return img, meta
|
|
|
|
# Public entry points; the underscore-prefixed helpers stay module-private.
__all__ = [
    "crop_to_plot_region",              # stage 1: OCR-guided plot crop
    "remove_gridlines_and_background",  # stage 2: background / line cleanup
    "prepare_for_image_mode",           # orchestrator: crop -> clean -> resize
]
|
|