"""
Image preprocessing for image-input SPARK on real-world uploads.
Real user uploads (paper-figure crops, software screenshots, photos of
lab monitors) live in a much wider image distribution than the rendered
training PNGs. This module produces a cleaned grayscale 224x224 PIL image
that looks closer to the training distribution before it enters the
image-mode CNN.
Three stages, all PIL-in / PIL-out:
    1. crop_to_plot_region   -- OCR-based detection of the inner plot
                                bounding box; crops out browser chrome,
                                paper captions, side panels.
    2. remove_gridlines_and_background
                             -- adaptive threshold + morphological line
                                detection to suppress thin gridlines and
                                normalize the background to white.
    3. prepare_for_image_mode
                             -- orchestrator: crop -> clean -> resize.
All heavy CV deps (`cv2`, `easyocr`) are imported lazily so the module
loads cleanly in environments that lack them; in that case the relevant
function is a no-op and `meta['was_*']` reports False.
"""
from __future__ import annotations
from typing import Dict, Optional, Tuple
import numpy as np
from PIL import Image, ImageOps
# --------------------------------------------------------------------------
# Plot-region cropping (OCR-based)
# --------------------------------------------------------------------------
def _detect_label_positions(image_array: np.ndarray):
    """Run OCR and return raw (cx, cy, val) tuples for every numeric label.

    Mirrors the OCR pass in `digitizer.auto_detect_axis_bounds` but exposes
    the per-label pixel positions, which we need to locate the inner plot
    bounding box (right of y-labels, above x-labels).

    Args:
        image_array: image pixels as a numpy array. A 4-channel array has
            its last channel dropped before OCR.

    Returns:
        (detections, (H, W)) where detections is a list of
        (center_x, center_y, value) float tuples, one per OCR result that
        parses as a number with confidence >= 0.2, and (H, W) are the first
        two dimensions of the array.
        ([], None) if easyocr is unavailable, if creating the reader or
        running OCR raises, or if fewer than 4 numeric labels are found.
    """
    try:
        import easyocr
    except ImportError:
        return [], None
    import re

    if image_array.ndim == 3 and image_array.shape[2] == 4:
        image_array = image_array[:, :, :3]
    H, W = image_array.shape[:2]

    # Best-effort OCR: a failure while constructing the reader is treated
    # exactly like a failure while reading, so the caller simply skips the
    # crop instead of crashing.
    try:
        reader = easyocr.Reader(["en"], gpu=False, verbose=False)
        results = reader.readtext(image_array, detail=1)
    except Exception:
        return [], None

    # Optional sign (ASCII hyphen, U+2212, en dash, or a tilde that OCR may
    # produce for a minus), digits, optional fraction, optional exponent.
    num_re = re.compile(r"^[−\-–~]?\d+\.?\d*(?:[eE][+\-]?\d+)?$")

    detections = []
    for bbox, text, conf in results:
        if conf < 0.2:
            continue
        # Normalize every minus look-alike to an ASCII hyphen so float()
        # can parse the text.
        cleaned = (text.strip().replace(" ", "")
                   .replace("−", "-").replace("–", "-").replace("~", "-"))
        if not num_re.match(cleaned):
            continue
        try:
            value = float(cleaned)
        except ValueError:
            continue
        # Label position = centroid of the OCR bounding-box corner points.
        cx = float(np.mean([p[0] for p in bbox]))
        cy = float(np.mean([p[1] for p in bbox]))
        detections.append((cx, cy, value))

    if len(detections) < 4:
        return [], None
    return detections, (H, W)
def _plot_bbox_from_detections(detections, hw, margin_frac: float = 0.02):
    """Compute inner-plot bounding box (left, top, right, bottom) in pixels
    from raw OCR label detections.

    Heuristic:
      - y-axis labels live in the left band of the image (cx < 0.30 * W)
            -> plot_left = max cx among y-labels + margin
            -> plot_top  = min cy among y-labels - margin
      - x-axis labels live in the bottom band of the image (cy > 0.65 * H)
            -> plot_bottom = min cy among x-labels - margin
            -> plot_right  = max cx among x-labels + margin
      - A label in the bottom-left corner falls in both bands and is
        ambiguous. It is left out of both groups so that a low y-tick
        cannot pull plot_bottom up and a leftmost x-tick cannot push
        plot_left right. If that leaves a group empty, the full band is
        used for that group instead.

    Args:
        detections: iterable of (cx, cy, value) tuples in pixel coordinates.
        hw: (H, W) of the image the detections came from.
        margin_frac: padding as a fraction of max(H, W).

    Returns:
        (left, top, right, bottom) ints, or None if either band contains no
        label or the resulting box is narrower/shorter than 32 pixels.
    """
    H, W = hw
    margin = int(margin_frac * max(H, W))
    left_limit = W * 0.30
    bottom_limit = H * 0.65

    left_band = [(cx, cy) for cx, cy, _ in detections if cx < left_limit]
    bottom_band = [(cx, cy) for cx, cy, _ in detections if cy > bottom_limit]
    if not left_band or not bottom_band:
        return None

    # Drop corner labels that sit in both bands; fall back to the whole band
    # when every label of a group is in the corner.
    y_labels = [(cx, cy) for cx, cy in left_band
                if not cy > bottom_limit] or left_band
    x_labels = [(cx, cy) for cx, cy in bottom_band
                if not cx < left_limit] or bottom_band

    plot_left = int(max(cx for cx, _ in y_labels) + margin)
    plot_top = int(min(cy for _, cy in y_labels) - margin)
    plot_right = int(max(cx for cx, _ in x_labels) + margin)
    plot_bottom = int(min(cy for _, cy in x_labels) - margin)

    # Clamp to the image and keep the box at least one pixel wide/tall so
    # the size check below works on well-ordered coordinates.
    plot_left = max(0, min(plot_left, W - 1))
    plot_right = max(plot_left + 1, min(plot_right, W))
    plot_top = max(0, min(plot_top, H - 1))
    plot_bottom = max(plot_top + 1, min(plot_bottom, H))

    if plot_right - plot_left < 32 or plot_bottom - plot_top < 32:
        return None
    return (plot_left, plot_top, plot_right, plot_bottom)
def crop_to_plot_region(pil_image: Image.Image,
                        margin_frac: float = 0.02,
                        ) -> Tuple[Image.Image, Optional[Tuple[int, int, int, int]]]:
    """Crop a PIL image to the inner plot area located from OCR'd labels.

    Args:
        pil_image: input PIL image (any mode); it is converted to RGB only
            for the OCR pass, the crop is taken from the image as given.
        margin_frac: padding around the detected plot region, expressed as
            a fraction of max(H, W).

    Returns:
        (cropped_pil, bbox) with bbox = (left, top, right, bottom) ints.
        When label detection or the bounding-box heuristic yields nothing,
        returns (pil_image, None) with the input image untouched.
    """
    rgb_pixels = np.asarray(pil_image.convert("RGB"))
    labels, shape_hw = _detect_label_positions(rgb_pixels)

    region = None
    if labels and shape_hw is not None:
        region = _plot_bbox_from_detections(
            labels, shape_hw, margin_frac=margin_frac)

    if region is None:
        return pil_image, None
    return pil_image.crop(region), region
# --------------------------------------------------------------------------
# Background normalization + gridline removal (CV2-based)
# --------------------------------------------------------------------------
def _ensure_grayscale(pil_image: Image.Image) -> np.ndarray:
    """Convert a PIL image to mode "L" if needed and return it as a uint8 array."""
    gray = pil_image if pil_image.mode == "L" else pil_image.convert("L")
    return np.asarray(gray, dtype=np.uint8)
def remove_gridlines_and_background(
    pil_image: Image.Image,
    background_stretch: bool = True,
    remove_gridlines: bool = True,
    grid_min_length_frac: float = 0.30,
    soft_threshold: int = 245,
) -> Tuple[Image.Image, Dict[str, object]]:
    """Normalize background to white and (optionally) remove thin gridlines.

    Pipeline:
      1. Convert to grayscale.
      2. (background_stretch) Linearly scale gray values so the brightest
         pixel becomes 255 (cancels colored / off-white backgrounds).
      3. (remove_gridlines, needs cv2) Adaptive-threshold to a binary mask
         of dark pixels (curve + axes + text + gridlines), then
         morphological opening with a long 1-pixel-high horizontal kernel
         and a long 1-pixel-wide vertical kernel finds long straight lines;
         those regions are inpainted on the grayscale image. The main curve
         survives because the opening only keeps strictly straight runs of
         at least K dark pixels; a curving line breaks the connectivity.
      4. (soft_threshold > 0) Push pixels >= `soft_threshold` to pure 255
         to snap any residual near-white background to clean white.

    If cv2 cannot be imported, step 3 is skipped and the other steps still
    run.

    Args:
        pil_image: input PIL image (any mode).
        background_stretch: enable step 2.
        remove_gridlines: enable step 3.
        grid_min_length_frac: minimum line length, as a fraction of the
            image width (horizontal) or height (vertical), with a floor of
            20 pixels.
        soft_threshold: gray level at or above which pixels become 255;
            values <= 0 disable step 4.

    Returns:
        (cleaned_pil, meta) where cleaned_pil is a mode-"L" image and meta
        has keys was_stretched, was_cleaned, n_horiz_gridlines,
        n_vert_gridlines. The two counts are the number of distinct
        detected lines (groups of adjacent rows / columns containing line
        pixels), so a line several pixels thick counts once.
    """
    meta: Dict[str, object] = {
        "was_stretched": False,
        "was_cleaned": False,
        "n_horiz_gridlines": 0,
        "n_vert_gridlines": 0,
    }
    arr = _ensure_grayscale(pil_image)

    if background_stretch:
        brightest = int(arr.max())
        # An all-black image has nothing to stretch (and would divide by 0).
        if brightest > 0:
            scale = 255.0 / float(brightest)
            arr = np.clip(arr.astype(np.float32) * scale, 0, 255).astype(np.uint8)
            meta["was_stretched"] = True

    # cv2 is optional: without it only the gridline step is skipped.
    try:
        import cv2
    except ImportError:
        cv2 = None

    if cv2 is not None and remove_gridlines:
        H, W = arr.shape
        binary = cv2.adaptiveThreshold(
            arr, 255,
            cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV,
            blockSize=31, C=10,
        )
        K_h = max(20, int(W * grid_min_length_frac))
        K_v = max(20, int(H * grid_min_length_frac))
        # cv2 kernel sizes are (width, height): (K_h, 1) is a horizontal bar.
        h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (K_h, 1))
        h_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, h_kernel)
        v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, K_v))
        v_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, v_kernel)
        # NOTE(review): long straight axis lines also pass this opening and
        # are therefore counted and inpainted too -- confirm this is intended.
        meta["n_horiz_gridlines"] = _count_line_runs(h_lines.any(axis=1))
        meta["n_vert_gridlines"] = _count_line_runs(v_lines.any(axis=0))
        line_mask = cv2.bitwise_or(h_lines, v_lines)
        if line_mask.any():
            # Widen the mask slightly so the inpaint also covers line edges.
            line_mask = cv2.dilate(line_mask, np.ones((2, 2), np.uint8))
            arr = cv2.inpaint(arr, line_mask, 3, cv2.INPAINT_TELEA)
            meta["was_cleaned"] = True

    if soft_threshold > 0:
        arr = np.where(arr >= soft_threshold, 255, arr).astype(np.uint8)

    # arr is a 2-D uint8 array here, so Pillow infers mode "L" on its own.
    # NOTE(review): newer Pillow releases warn about the explicit `mode`
    # argument of fromarray -- confirm against the pinned Pillow version.
    return Image.fromarray(arr), meta


def _count_line_runs(occupied: np.ndarray) -> int:
    """Count maximal runs of consecutive True values in a 1-D boolean array."""
    if occupied.size == 0:
        return 0
    # A run starts at index 0 if that entry is True, and wherever a True
    # entry follows a False one.
    run_starts = occupied[1:] & ~occupied[:-1]
    return int(occupied[0]) + int(np.count_nonzero(run_starts))
# --------------------------------------------------------------------------
# Orchestrator
# --------------------------------------------------------------------------
def prepare_for_image_mode(
    pil_image: Image.Image,
    do_crop: bool = True,
    do_clean: bool = True,
    target_size: int = 224,
) -> Tuple[Image.Image, Dict[str, object]]:
    """Run the whole image-mode SPARK preprocessing chain on one image.

    Order of operations (the first two are optional):
        crop_to_plot_region -> remove_gridlines_and_background -> resize.

    Args:
        pil_image: any-mode PIL.Image.
        do_crop: run OCR-based plot-region cropping.
        do_clean: run background normalization + gridline removal; when
            False the image is only converted to mode "L".
        target_size: edge length of the square output.

    Returns:
        (preprocessed_pil_L, meta) where meta is a flat dict suitable for
        showing in the UI:
            was_cropped:        bool
            crop_bbox:          [l, t, r, b] list, or None when not cropped
            was_stretched:      bool
            was_cleaned:        bool
            n_horiz_gridlines:  int
            n_vert_gridlines:   int
            target_size:        int
    """
    report: Dict[str, object] = {
        "was_cropped": False,
        "crop_bbox": None,
        "was_stretched": False,
        "was_cleaned": False,
        "n_horiz_gridlines": 0,
        "n_vert_gridlines": 0,
        "target_size": target_size,
    }
    current = pil_image

    # Stage 1: optional crop; a failed detection leaves the image as is.
    if do_crop:
        candidate, region = crop_to_plot_region(current)
        if region is not None:
            current = candidate
            report["was_cropped"] = True
            report["crop_bbox"] = list(region)

    # Stage 2: optional cleaning; otherwise just make sure the mode is "L".
    if do_clean:
        current, cleaning_info = remove_gridlines_and_background(current)
        for key in ("was_stretched", "was_cleaned",
                    "n_horiz_gridlines", "n_vert_gridlines"):
            report[key] = cleaning_info[key]
    elif current.mode != "L":
        current = current.convert("L")

    # Stage 3: resize to the square target unless it already matches.
    wanted = (target_size, target_size)
    if current.size != wanted:
        current = current.resize(wanted, Image.BILINEAR)
    return current, report
# Names exported by a star-import of this module; the underscore-prefixed
# helpers defined above are not listed.
__all__ = [
    "crop_to_plot_region",
    "remove_gridlines_and_background",
    "prepare_for_image_mode",
]
|