trace / image_preprocessing.py
bingyan user
Rebrand TRACE -> SPARK
8619a66
"""
Image preprocessing for image-input SPARK on real-world uploads.
Real user uploads (paper-figure crops, software screenshots, photos of
lab monitors) live in a much wider image distribution than the rendered
training PNGs. This module produces a cleaned grayscale 224x224 PIL image
that looks closer to the training distribution before it enters the
image-mode CNN.
Three stages, all PIL-in / PIL-out:
1. crop_to_plot_region -- OCR-based detection of the inner plot
bounding box; crops out browser chrome,
paper captions, side panels.
2. remove_gridlines_and_background
-- adaptive threshold + morphological line
detection to suppress thin gridlines and
normalize the background to white.
3. prepare_for_image_mode
-- orchestrator: crop -> clean -> resize.
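All heavy CV deps (`cv2`, `easyocr`) are imported lazily so the module
loads cleanly in environments that lack them. Without `easyocr`, cropping
is skipped (`was_cropped` is False); without `cv2`, gridline removal is
skipped (`was_cleaned` is False), but the background stretch and the
near-white snapping still run.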
"""
from __future__ import annotations

import functools
import re
from typing import Dict, Optional, Tuple

import numpy as np
from PIL import Image, ImageOps
# --------------------------------------------------------------------------
# Plot-region cropping (OCR-based)
# --------------------------------------------------------------------------
def _detect_label_positions(image_array: np.ndarray):
"""Run OCR and return raw (cx, cy, val) tuples for every numeric label.
Mirrors the OCR pass in `digitizer.auto_detect_axis_bounds` but exposes
the per-label pixel positions, which we need to locate the inner plot
bounding box (right of y-labels, above x-labels).
Returns ([], None) if easyocr is unavailable or finds <4 numeric labels.
"""
try:
import easyocr
except ImportError:
return [], None
import re
if image_array.ndim == 3 and image_array.shape[2] == 4:
image_array = image_array[:, :, :3]
H, W = image_array.shape[:2]
reader = easyocr.Reader(["en"], gpu=False, verbose=False)
try:
results = reader.readtext(image_array, detail=1)
except Exception:
return [], None
_NUM_RE = re.compile(r"^[−\-–~]?\d+\.?\d*(?:[eE][+\-]?\d+)?$")
detections = []
for bbox, text, conf in results:
cleaned = (text.strip().replace(" ", "")
.replace("−", "-").replace("–", "-").replace("~", "-"))
if not _NUM_RE.match(cleaned):
continue
try:
float(cleaned)
except ValueError:
continue
if conf < 0.2:
continue
cx = float(np.mean([p[0] for p in bbox]))
cy = float(np.mean([p[1] for p in bbox]))
detections.append((cx, cy, float(cleaned.replace("-", "-"))))
if len(detections) < 4:
return [], None
return detections, (H, W)
def _plot_bbox_from_detections(detections, hw, margin_frac: float = 0.02):
"""Compute inner-plot bounding box (left, top, right, bottom) in pixels
from raw OCR label detections.
Heuristic:
- y-axis labels live in the left third of the image
-> plot_left = max cx among y-labels + margin
- x-axis labels live in the bottom third of the image
-> plot_bottom = min cy among x-labels - margin
- plot_right roughly = max cx among x-labels + margin (fallback to W)
- plot_top roughly = min cy among y-labels - margin (fallback to 0)
Returns (left, top, right, bottom) ints, or None if heuristic fails.
"""
H, W = hw
margin = int(margin_frac * max(H, W))
y_label_cxs = [cx for cx, cy, _ in detections if cx < W * 0.30]
y_label_cys = [cy for cx, cy, _ in detections if cx < W * 0.30]
x_label_cxs = [cx for cx, cy, _ in detections if cy > H * 0.65]
x_label_cys = [cy for cx, cy, _ in detections if cy > H * 0.65]
if not y_label_cxs or not x_label_cys:
return None
plot_left = int(max(y_label_cxs) + margin)
plot_bottom = int(min(x_label_cys) - margin)
plot_right = int(max(x_label_cxs) + margin) if x_label_cxs else W
plot_top = int(min(y_label_cys) - margin) if y_label_cys else 0
plot_left = max(0, min(plot_left, W - 1))
plot_right = max(plot_left + 1, min(plot_right, W))
plot_top = max(0, min(plot_top, H - 1))
plot_bottom = max(plot_top + 1, min(plot_bottom, H))
if plot_right - plot_left < 32 or plot_bottom - plot_top < 32:
return None
return (plot_left, plot_top, plot_right, plot_bottom)
def crop_to_plot_region(pil_image: Image.Image,
                        margin_frac: float = 0.02,
                        ) -> Tuple[Image.Image, Optional[Tuple[int, int, int, int]]]:
    """Locate the inner plot area from OCR'd tick labels and crop to it.

    Args:
        pil_image: input PIL image in any mode; it is converted to RGB only
            for the OCR pass, and the crop is taken from the original image.
        margin_frac: padding around the detected plot region, expressed as
            a fraction of max(H, W).

    Returns:
        (cropped_pil, bbox). bbox is (left, top, right, bottom) ints when
        detection succeeded. Otherwise bbox is None and the input image is
        returned unchanged.
    """
    rgb_pixels = np.asarray(pil_image.convert("RGB"))
    label_dets, size_hw = _detect_label_positions(rgb_pixels)

    region = None
    if label_dets and size_hw is not None:
        region = _plot_bbox_from_detections(label_dets, size_hw,
                                            margin_frac=margin_frac)

    if region is None:
        return pil_image, None
    return pil_image.crop(region), region
# --------------------------------------------------------------------------
# Background normalization + gridline removal (CV2-based)
# --------------------------------------------------------------------------
def _ensure_grayscale(pil_image: Image.Image) -> np.ndarray:
    """Convert a PIL image of any mode to a uint8 grayscale numpy array.

    Images already in mode "L" are used as-is; anything else goes through
    PIL's `convert("L")`.
    """
    gray_image = pil_image if pil_image.mode == "L" else pil_image.convert("L")
    return np.asarray(gray_image, dtype=np.uint8)
def _count_line_bands(hits: np.ndarray) -> int:
    """Count maximal runs of True in a 1-D boolean array.

    A line drawn several pixels thick occupies several adjacent rows (or
    columns) of the line mask. Counting runs rather than rows reports each
    such line once.
    """
    hits = np.asarray(hits, dtype=bool).ravel()
    if hits.size == 0:
        return 0
    # A run starts at index 0 if hits[0] is set, plus at every False->True step.
    return int(hits[0]) + int(np.count_nonzero(hits[1:] & ~hits[:-1]))


def remove_gridlines_and_background(
    pil_image: Image.Image,
    background_stretch: bool = True,
    remove_gridlines: bool = True,
    grid_min_length_frac: float = 0.30,
    soft_threshold: int = 245,
) -> Tuple[Image.Image, Dict[str, object]]:
    """Normalize background to white and (optionally) remove long straight lines.

    Pipeline:
      1. Convert to grayscale.
      2. (background_stretch) Multiply all gray levels by 255 / max so the
         brightest pixel becomes 255. This lifts dim / off-white backgrounds,
         provided the background is the brightest thing in the image.
      3. (remove_gridlines, requires cv2) Adaptive mean threshold
         (31 px block, C=10) gives a binary mask of locally dark pixels.
         A morphological opening with a K_h-wide x 1-tall rectangle, and
         separately a 1-wide x K_v-tall one, keeps only straight horizontal
         or vertical runs of at least K dark pixels. Here
         K_h = max(20, grid_min_length_frac * W) and
         K_v = max(20, grid_min_length_frac * H). The union of both masks is
         dilated by 2x2 and inpainted (Telea, radius 3).
         Sloped or curved strokes do not survive the opening. Any straight
         run of >= K pixels is removed, however: that includes axis lines, a
         plot frame, and flat (plateau / step) segments of the data curve
         itself, not only gridlines.
      4. (soft_threshold > 0) Push pixels >= `soft_threshold` to pure 255 to
         snap residual near-white background to clean white.

    If cv2 is not installed, step 3 is skipped; steps 1, 2 and 4 still run.

    Returns:
        (cleaned_pil, meta). cleaned_pil is a mode-"L" image. meta has:
          was_stretched: step 2 rescaled the image (the image had a nonzero
            pixel);
          was_cleaned: step 3 found and inpainted at least one line;
          n_horiz_gridlines / n_vert_gridlines: number of distinct
            horizontal / vertical line bands detected (adjacent mask rows or
            columns are merged into one band).
    """
    meta: Dict[str, object] = {
        "was_stretched": False,
        "was_cleaned": False,
        "n_horiz_gridlines": 0,
        "n_vert_gridlines": 0,
    }
    arr = _ensure_grayscale(pil_image)

    # `arr.size` guard: `max()` raises on a zero-size image.
    if background_stretch and arr.size and arr.max() > 0:
        scale = 255.0 / float(arr.max())
        arr = np.clip(arr.astype(np.float32) * scale, 0, 255).astype(np.uint8)
        meta["was_stretched"] = True

    try:
        import cv2
    except ImportError:
        cv2 = None  # line removal unavailable; stretch + snap still apply

    if cv2 is not None and remove_gridlines and arr.size:
        H, W = arr.shape
        binary = cv2.adaptiveThreshold(
            arr, 255,
            cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV,
            blockSize=31, C=10,
        )
        K_h = max(20, int(W * grid_min_length_frac))
        K_v = max(20, int(H * grid_min_length_frac))
        # cv2 kernel sizes are (width, height).
        h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (K_h, 1))
        h_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, h_kernel)
        v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, K_v))
        v_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, v_kernel)

        meta["n_horiz_gridlines"] = _count_line_bands(h_lines.any(axis=1))
        meta["n_vert_gridlines"] = _count_line_bands(v_lines.any(axis=0))

        line_mask = cv2.bitwise_or(h_lines, v_lines)
        if line_mask.any():
            # Slight dilation so anti-aliased line edges are inpainted too.
            line_mask = cv2.dilate(line_mask, np.ones((2, 2), np.uint8))
            arr = cv2.inpaint(arr, line_mask, 3, cv2.INPAINT_TELEA)
            meta["was_cleaned"] = True

    if soft_threshold > 0:
        arr = np.where(arr >= soft_threshold, 255, arr).astype(np.uint8)
    # arr is a 2-D uint8 array, which Pillow maps to mode "L"; the explicit
    # `mode=` argument of fromarray is deprecated in recent Pillow.
    return Image.fromarray(arr), meta
# --------------------------------------------------------------------------
# Orchestrator
# --------------------------------------------------------------------------
def _flatten_transparency(pil_image: Image.Image) -> Image.Image:
    """Composite an image that carries transparency onto opaque white.

    Pillow's `convert("RGB")` / `convert("L")` drop the alpha band outright,
    so fully transparent pixels take whatever colour they happen to store
    (often black in exported / screenshot PNGs). That would turn the
    background dark, which both the OCR crop and the white-background
    normalization downstream assume is light. Images with no alpha band and
    no `transparency` key are returned unchanged.
    """
    has_alpha_band = "A" in pil_image.getbands()
    has_tRNS = (pil_image.mode in ("P", "L", "RGB")
                and "transparency" in pil_image.info)
    if not (has_alpha_band or has_tRNS):
        return pil_image
    rgba = pil_image.convert("RGBA")
    white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(white, rgba).convert("RGB")


def prepare_for_image_mode(
    pil_image: Image.Image,
    do_crop: bool = True,
    do_clean: bool = True,
    target_size: int = 224,
) -> Tuple[Image.Image, Dict[str, object]]:
    """Full preprocessing pipeline for image-mode SPARK.

    Steps (crop and clean can be skipped):
        flatten transparency onto white -> crop_to_plot_region
        -> remove_gridlines_and_background -> grayscale -> resize.

    Args:
        pil_image: any-mode PIL.Image.
        do_crop: run OCR-based plot-region cropping.
        do_clean: run background normalization + gridline removal.
        target_size: output square edge length. The image is resized
            (bilinear) to target_size x target_size without preserving its
            aspect ratio.

    Returns:
        (preprocessed_pil_L, meta) where meta is a flat dict suitable for
        showing in the UI:
            was_cropped: bool
            crop_bbox: [l, t, r, b] list or None
            was_stretched: bool
            was_cleaned: bool
            n_horiz_gridlines: int
            n_vert_gridlines: int
            target_size: int
    """
    meta: Dict[str, object] = {
        "was_cropped": False,
        "crop_bbox": None,
        "was_stretched": False,
        "was_cleaned": False,
        "n_horiz_gridlines": 0,
        "n_vert_gridlines": 0,
        "target_size": target_size,
    }
    img = _flatten_transparency(pil_image)

    if do_crop:
        cropped, bbox = crop_to_plot_region(img)
        if bbox is not None:
            img = cropped
            meta["was_cropped"] = True
            meta["crop_bbox"] = list(bbox)

    if do_clean:
        img, clean_meta = remove_gridlines_and_background(img)
        for key in ("was_stretched", "was_cleaned",
                    "n_horiz_gridlines", "n_vert_gridlines"):
            meta[key] = clean_meta[key]
    elif img.mode != "L":
        img = img.convert("L")

    if img.size != (target_size, target_size):
        img = img.resize((target_size, target_size), Image.BILINEAR)
    return img, meta
# Public API: the two pipeline stages plus the orchestrator. The
# underscore-prefixed helpers in this module are internal.
__all__ = [
    "crop_to_plot_region",
    "remove_gridlines_and_background",
    "prepare_for_image_mode",
]