""" Image preprocessing for image-input SPARK on real-world uploads. Real user uploads (paper-figure crops, software screenshots, photos of lab monitors) live in a much wider image distribution than the rendered training PNGs. This module produces a cleaned grayscale 224x224 PIL image that looks closer to the training distribution before it enters the image-mode CNN. Three stages, all PIL-in / PIL-out: 1. crop_to_plot_region -- OCR-based detection of the inner plot bounding box; crops out browser chrome, paper captions, side panels. 2. remove_gridlines_and_background -- adaptive threshold + morphological line detection to suppress thin gridlines and normalize the background to white. 3. prepare_for_image_mode -- orchestrator: crop -> clean -> resize. All heavy CV deps (`cv2`, `easyocr`) are imported lazily so the module loads cleanly in environments that lack them; in that case the relevant function is a no-op and `meta['was_*']` reports False. """ from __future__ import annotations from typing import Dict, Optional, Tuple import numpy as np from PIL import Image, ImageOps # -------------------------------------------------------------------------- # Plot-region cropping (OCR-based) # -------------------------------------------------------------------------- def _detect_label_positions(image_array: np.ndarray): """Run OCR and return raw (cx, cy, val) tuples for every numeric label. Mirrors the OCR pass in `digitizer.auto_detect_axis_bounds` but exposes the per-label pixel positions, which we need to locate the inner plot bounding box (right of y-labels, above x-labels). Returns ([], None) if easyocr is unavailable or finds <4 numeric labels. """ try: import easyocr except ImportError: return [], None import re if image_array.ndim == 3 and image_array.shape[2] == 4: image_array = image_array[:, :, :3] H, W = image_array.shape[:2] reader = easyocr.Reader(["en"], gpu=False, verbose=False) try: results = reader.readtext(image_array, detail=1) except Exception: return [], None _NUM_RE = re.compile(r"^[−\-–~]?\d+\.?\d*(?:[eE][+\-]?\d+)?$") detections = [] for bbox, text, conf in results: cleaned = (text.strip().replace(" ", "") .replace("−", "-").replace("–", "-").replace("~", "-")) if not _NUM_RE.match(cleaned): continue try: float(cleaned) except ValueError: continue if conf < 0.2: continue cx = float(np.mean([p[0] for p in bbox])) cy = float(np.mean([p[1] for p in bbox])) detections.append((cx, cy, float(cleaned.replace("-", "-")))) if len(detections) < 4: return [], None return detections, (H, W) def _plot_bbox_from_detections(detections, hw, margin_frac: float = 0.02): """Compute inner-plot bounding box (left, top, right, bottom) in pixels from raw OCR label detections. Heuristic: - y-axis labels live in the left third of the image -> plot_left = max cx among y-labels + margin - x-axis labels live in the bottom third of the image -> plot_bottom = min cy among x-labels - margin - plot_right roughly = max cx among x-labels + margin (fallback to W) - plot_top roughly = min cy among y-labels - margin (fallback to 0) Returns (left, top, right, bottom) ints, or None if heuristic fails. """ H, W = hw margin = int(margin_frac * max(H, W)) y_label_cxs = [cx for cx, cy, _ in detections if cx < W * 0.30] y_label_cys = [cy for cx, cy, _ in detections if cx < W * 0.30] x_label_cxs = [cx for cx, cy, _ in detections if cy > H * 0.65] x_label_cys = [cy for cx, cy, _ in detections if cy > H * 0.65] if not y_label_cxs or not x_label_cys: return None plot_left = int(max(y_label_cxs) + margin) plot_bottom = int(min(x_label_cys) - margin) plot_right = int(max(x_label_cxs) + margin) if x_label_cxs else W plot_top = int(min(y_label_cys) - margin) if y_label_cys else 0 plot_left = max(0, min(plot_left, W - 1)) plot_right = max(plot_left + 1, min(plot_right, W)) plot_top = max(0, min(plot_top, H - 1)) plot_bottom = max(plot_top + 1, min(plot_bottom, H)) if plot_right - plot_left < 32 or plot_bottom - plot_top < 32: return None return (plot_left, plot_top, plot_right, plot_bottom) def crop_to_plot_region(pil_image: Image.Image, margin_frac: float = 0.02, ) -> Tuple[Image.Image, Optional[Tuple[int, int, int, int]]]: """Detect the inner plot bbox via OCR and crop to it. Args: pil_image: input PIL image (any mode). margin_frac: small padding around the detected plot region as a fraction of max(H, W). Returns: (cropped_pil, bbox) where bbox is (left, top, right, bottom) ints or None if OCR-based detection failed (in which case cropped_pil == pil_image). """ arr = np.asarray(pil_image.convert("RGB")) dets, hw = _detect_label_positions(arr) if not dets or hw is None: return pil_image, None bbox = _plot_bbox_from_detections(dets, hw, margin_frac=margin_frac) if bbox is None: return pil_image, None cropped = pil_image.crop(bbox) return cropped, bbox # -------------------------------------------------------------------------- # Background normalization + gridline removal (CV2-based) # -------------------------------------------------------------------------- def _ensure_grayscale(pil_image: Image.Image) -> np.ndarray: """Return uint8 grayscale numpy array from any PIL image.""" if pil_image.mode != "L": pil_image = pil_image.convert("L") return np.asarray(pil_image, dtype=np.uint8) def remove_gridlines_and_background( pil_image: Image.Image, background_stretch: bool = True, remove_gridlines: bool = True, grid_min_length_frac: float = 0.30, soft_threshold: int = 245, ) -> Tuple[Image.Image, Dict[str, object]]: """Normalize background to white and (optionally) remove thin gridlines. Pipeline: 1. Convert to grayscale. 2. (background_stretch) Linearly stretch the gray histogram so the brightest pixel is 255 (cancels colored / off-white backgrounds). 3. (remove_gridlines) Adaptive-threshold to a binary mask of dark pixels (curve + axes + text + gridlines), then morphological opening with very long horizontal `(1, K)` and vertical `(K, 1)` kernels finds long thin lines; we inpaint those regions on the grayscale image. The main curve survives because morphological opening with a 1xK kernel only keeps strictly straight horizontal runs of >=K dark pixels; a curving line breaks the connectivity. 4. (soft_threshold) Push pixels >= `soft_threshold` to pure 255 to snap any residual near-white background to clean white. Falls back to a pure-PIL background stretch if cv2 is unavailable. Returns: (cleaned_pil, meta) where meta has keys was_stretched, was_cleaned, n_horiz_gridlines, n_vert_gridlines. """ meta: Dict[str, object] = { "was_stretched": False, "was_cleaned": False, "n_horiz_gridlines": 0, "n_vert_gridlines": 0, } arr = _ensure_grayscale(pil_image) if background_stretch: if arr.max() > 0: scale = 255.0 / float(arr.max()) arr = np.clip(arr.astype(np.float32) * scale, 0, 255).astype(np.uint8) meta["was_stretched"] = True try: import cv2 except ImportError: if soft_threshold > 0: arr = np.where(arr >= soft_threshold, 255, arr).astype(np.uint8) return Image.fromarray(arr, mode="L"), meta if remove_gridlines: H, W = arr.shape binary = cv2.adaptiveThreshold( arr, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, blockSize=31, C=10, ) K_h = max(20, int(W * grid_min_length_frac)) K_v = max(20, int(H * grid_min_length_frac)) h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (K_h, 1)) h_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, h_kernel) v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, K_v)) v_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, v_kernel) meta["n_horiz_gridlines"] = int((h_lines.sum(axis=1) > 0).sum()) meta["n_vert_gridlines"] = int((v_lines.sum(axis=0) > 0).sum()) line_mask = cv2.bitwise_or(h_lines, v_lines) if line_mask.sum() > 0: line_mask = cv2.dilate(line_mask, np.ones((2, 2), np.uint8)) arr = cv2.inpaint(arr, line_mask, 3, cv2.INPAINT_TELEA) meta["was_cleaned"] = True if soft_threshold > 0: arr = np.where(arr >= soft_threshold, 255, arr).astype(np.uint8) return Image.fromarray(arr, mode="L"), meta # -------------------------------------------------------------------------- # Orchestrator # -------------------------------------------------------------------------- def prepare_for_image_mode( pil_image: Image.Image, do_crop: bool = True, do_clean: bool = True, target_size: int = 224, ) -> Tuple[Image.Image, Dict[str, object]]: """Full preprocessing pipeline for image-mode SPARK. Steps (any can be skipped): crop_to_plot_region -> remove_gridlines_and_background -> resize. Args: pil_image: any-mode PIL.Image. do_crop: run OCR-based plot-region cropping. do_clean: run background normalization + gridline removal. target_size: output square edge length. Returns: (preprocessed_pil_L, meta) where meta is a flat dict suitable for showing in the UI: was_cropped: bool crop_bbox: (l, t, r, b) or None was_stretched: bool was_cleaned: bool n_horiz_gridlines: int n_vert_gridlines: int target_size: int """ meta: Dict[str, object] = { "was_cropped": False, "crop_bbox": None, "was_stretched": False, "was_cleaned": False, "n_horiz_gridlines": 0, "n_vert_gridlines": 0, "target_size": target_size, } img = pil_image if do_crop: cropped, bbox = crop_to_plot_region(img) if bbox is not None: img = cropped meta["was_cropped"] = True meta["crop_bbox"] = list(bbox) if do_clean: cleaned, clean_meta = remove_gridlines_and_background(img) img = cleaned meta["was_stretched"] = clean_meta["was_stretched"] meta["was_cleaned"] = clean_meta["was_cleaned"] meta["n_horiz_gridlines"] = clean_meta["n_horiz_gridlines"] meta["n_vert_gridlines"] = clean_meta["n_vert_gridlines"] else: if img.mode != "L": img = img.convert("L") if img.size != (target_size, target_size): img = img.resize((target_size, target_size), Image.BILINEAR) return img, meta __all__ = [ "crop_to_plot_region", "remove_gridlines_and_background", "prepare_for_image_mode", ]