File size: 19,067 Bytes
d7f1cd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
"""
Document Readability Scorer
============================
A multi-signal pre-screening system for document validation pipelines.
Scores documents on readability before expensive OCR/LLM inference.

Signals extracted (all normalized to 0-1, higher = better):
  1. Sharpness    β€” Laplacian variance + FFT high-freq energy
  2. Contrast     β€” RMS contrast + Michelson contrast
  3. Noise level  β€” Estimated noise sigma (inverted: low noise = high score)
  4. Text presence β€” MSER-based text region coverage + edge density
  5. Brightness   β€” Penalizes over/under-exposed documents
  6. Entropy      β€” Shannon entropy (blank pages score low)
  7. Learned IQA  β€” CLIP-IQA or BRISQUE via pyiqa (optional, GPU-free)

The composite "readability_score" is a weighted sum of these signals.
Weights are fully configurable for calibration to your pipeline.

Usage:
    scorer = DocumentReadabilityScorer()
    result = scorer.score("document.png")
    print(result["readability_score"])   # float in [0, 1]
    print(result["ocr_recommended"])     # bool
    print(result["signals"])             # dict of all sub-scores
"""

from __future__ import annotations

import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union

import cv2
import numpy as np
from PIL import Image
from scipy import ndimage
from skimage.filters import sobel
from skimage.measure import shannon_entropy

# NOTE(review): importing this module installs a process-wide "ignore" filter
# for every UserWarning, which also silences warnings from unrelated libraries.
# Consider scoping it with warnings.catch_warnings() around the noisy calls.
warnings.simplefilter("ignore", category=UserWarning)


# ─── Configuration ───────────────────────────────────────────────────────────

@dataclass
class ScorerConfig:
    """Tunable weights, thresholds and normalization constants for the scorer.

    The seven ``w_*`` weights are combined linearly into the composite
    readability score, so they are expected to sum to 1.0; ``validate``
    enforces this with a tolerance of 0.01. Adjust them to calibrate the
    scorer for your specific document types.
    """

    # --- Signal weights (expected to sum to 1.0) ---
    w_sharpness: float = 0.30
    w_contrast: float = 0.15
    w_noise: float = 0.10
    w_text_presence: float = 0.15
    w_brightness: float = 0.05
    w_entropy: float = 0.10
    w_learned_iqa: float = 0.15

    # --- Decision threshold ---
    ocr_threshold: float = 0.45  # composite scores below this → skip OCR

    # --- Normalization constants (tune per your document distribution) ---
    laplacian_cap: float = 800.0     # Laplacian variance at which sharpness saturates at 1.0
    noise_cap: float = 15.0          # noise sigma at which the noise score reaches 0.0
    min_text_coverage: float = 0.01  # MSER coverage threshold for the has_text flag

    # --- Learned metric (set to None to disable) ---
    # Expected values: "clipiqa", "brisque", "niqe", "topiq_nr", or None.
    learned_metric: Optional[str] = "clipiqa"

    # Torch device string handed to pyiqa for the learned metric.
    device: str = "cpu"

    def validate(self) -> None:
        """Raise ``ValueError`` if the seven weights differ from 1.0 by more than 0.01."""
        weights = (
            self.w_sharpness,
            self.w_contrast,
            self.w_noise,
            self.w_text_presence,
            self.w_brightness,
            self.w_entropy,
            self.w_learned_iqa,
        )
        # Accumulate left to right, exactly like a chained `a + b + ...`.
        total = weights[0]
        for weight in weights[1:]:
            total = total + weight
        if abs(total - 1.0) > 0.01:
            raise ValueError(f"Weights must sum to 1.0, got {total:.3f}")


# ─── Signal Extractors ──────────────────────────────────────────────────────

def _load_gray(image: Union[str, Path, np.ndarray, Image.Image]) -> np.ndarray:
    """Load image as grayscale numpy array."""
    if isinstance(image, (str, Path)):
        img = cv2.imread(str(image))
        if img is None:
            raise FileNotFoundError(f"Cannot read image: {image}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    elif isinstance(image, Image.Image):
        return np.array(image.convert("L"))
    elif isinstance(image, np.ndarray):
        if image.ndim == 3:
            return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        return image
    raise TypeError(f"Unsupported image type: {type(image)}")


def _load_color(image: Union[str, Path, np.ndarray, Image.Image]) -> np.ndarray:
    """Load image as BGR numpy array."""
    if isinstance(image, (str, Path)):
        img = cv2.imread(str(image))
        if img is None:
            raise FileNotFoundError(f"Cannot read image: {image}")
        return img
    elif isinstance(image, Image.Image):
        return cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
    elif isinstance(image, np.ndarray):
        return image
    raise TypeError(f"Unsupported image type: {type(image)}")


def sharpness_score(gray: np.ndarray, laplacian_cap: float = 800.0) -> dict:
    """
    Estimate sharpness from Laplacian variance and FFT high-frequency content.

    Two measurements are blended 70% / 30%:

    * Laplacian variance — variance of ``cv2.Laplacian`` (float64 output),
      divided by ``laplacian_cap`` and capped at 1.0. Rough heuristic ranges
      for document text: sharp 200-2000+, moderately blurry 50-200,
      very blurry < 50.
    * High-frequency ratio — share of the centred FFT magnitude spectrum
      that lies outside a disc whose radius is 5% of the smaller image side.
      Magnitudes are summed directly, not squared.

    Args:
        gray: 2-D grayscale image.
        laplacian_cap: Laplacian variance that maps to a full score of 1.0.

    Returns:
        dict with ``sharpness`` (clipped to [0, 1]), ``laplacian_variance``
        and ``high_freq_ratio``.
    """
    laplacian_variance = float(cv2.Laplacian(gray, cv2.CV_64F).var())
    lap_component = min(laplacian_variance / laplacian_cap, 1.0)

    rows, cols = gray.shape
    spectrum = np.abs(np.fft.fftshift(np.fft.fft2(gray.astype(np.float64))))

    # Disc of low frequencies around the DC term (centre after fftshift).
    centre_r, centre_c = rows // 2, cols // 2
    disc_radius = int(min(rows, cols) * 0.05)
    rr, cc = np.ogrid[:rows, :cols]
    in_disc = ((rr - centre_r) ** 2 + (cc - centre_c) ** 2) <= disc_radius ** 2
    high_freq_ratio = float(
        1.0 - spectrum[in_disc].sum() / (spectrum.sum() + 1e-10)
    )

    blended = 0.7 * lap_component + 0.3 * high_freq_ratio
    return {
        "sharpness": float(np.clip(blended, 0, 1)),
        "laplacian_variance": laplacian_variance,
        "high_freq_ratio": high_freq_ratio,
    }


def contrast_score(gray: np.ndarray) -> dict:
    """
    Measure global contrast with the RMS and Michelson metrics.

    * RMS contrast — pixel standard deviation divided by 255 (assumes 8-bit
      input). Values of 0.30 and above map to a full score.
    * Michelson contrast — (max - min) / (max + min), already within [0, 1]
      for non-negative pixel values.

    The two are blended 60% / 40%. Washed-out or very dark scans have low
    contrast. Black text on white paper typically lands around an RMS
    contrast of 0.2-0.5.

    Returns:
        dict with ``contrast`` (clipped to [0, 1]), ``rms_contrast`` and
        ``michelson_contrast``.
    """
    rms_contrast = float(gray.std() / 255.0)

    brightest = float(gray.max())
    darkest = float(gray.min())
    # 1e-10 guards against division by zero on an all-black image.
    michelson_contrast = (brightest - darkest) / (brightest + darkest + 1e-10)

    blended = 0.6 * min(rms_contrast / 0.30, 1.0) + 0.4 * michelson_contrast
    return {
        "contrast": float(np.clip(blended, 0, 1)),
        "rms_contrast": rms_contrast,
        "michelson_contrast": float(michelson_contrast),
    }


def noise_score(gray: np.ndarray, noise_cap: float = 15.0) -> dict:
    """
    Estimate additive noise with Immerkær's (1996) fast method, then invert it.

    The image is convolved with the 3x3 kernel
    ``[[1, -2, 1], [-2, 4, -2], [1, -2, 1]]`` (scipy's default border mode).
    Sigma is estimated as ``mean(|response|) * sqrt(pi / 2) / 6``. The mean
    is taken over all pixels, borders included.

    Rough heuristic ranges: clean documents sigma < 3, noisy scans 5-15,
    very noisy > 15. Sharp intensity steps (such as text strokes) can also
    respond to the kernel and raise the estimate.

    Args:
        gray: 2-D grayscale image.
        noise_cap: sigma at (and above) which the score bottoms out at 0.0.

    Returns:
        dict with ``noise`` (1.0 = no detectable noise) and ``noise_sigma``.
    """
    kernel = np.array([[1, -2, 1], [-2, 4, -2], [1, -2, 1]], dtype=np.float64)
    response = ndimage.convolve(gray.astype(np.float64), kernel)
    sigma = float(np.abs(response).mean() * np.sqrt(np.pi / 2) / 6.0)

    # Low noise → high score.
    inverted = 1.0 - min(sigma / noise_cap, 1.0)
    return {
        "noise": float(np.clip(inverted, 0, 1)),
        "noise_sigma": sigma,
    }


def text_presence_score(gray: np.ndarray, min_coverage: float = 0.01) -> dict:
    """
    Estimate how much text-like content the image contains.

    Two cues are averaged 50/50:

    * MSER coverage — MSER blobs with an area between 30 px and 5% of the
      image are each filled with their convex hull. Coverage is the filled
      fraction of the image, and 10% or more maps to a full score. If OpenCV
      raises during detection, coverage is treated as 0.
    * Edge density — mean Sobel magnitude of the image scaled by 1/255
      (assumes 8-bit input). 0.08 or more maps to a full score.

    ``has_text`` is True when coverage exceeds ``min_coverage`` or edge
    density exceeds 0.02.

    Returns:
        dict with ``text_presence`` (clipped to [0, 1]), ``text_coverage``,
        ``edge_density`` and ``has_text``.
    """
    detector = cv2.MSER_create()
    detector.setDelta(5)
    detector.setMinArea(30)
    detector.setMaxArea(int(gray.size * 0.05))
    detector.setMaxVariation(0.25)
    try:
        regions, _ = detector.detectRegions(gray)
    except cv2.error:
        regions = []

    text_coverage = 0.0
    if regions:
        hull_mask = np.zeros_like(gray)
        for points in regions:
            hull = cv2.convexHull(points.reshape(-1, 1, 2))
            cv2.fillPoly(hull_mask, [hull], 255)
        text_coverage = float(hull_mask.sum() / (255.0 * hull_mask.size))

    edge_density = float(sobel(gray.astype(np.float64) / 255.0).mean())

    coverage_component = min(text_coverage / 0.10, 1.0)
    edge_component = min(edge_density / 0.08, 1.0)
    blended = 0.5 * coverage_component + 0.5 * edge_component

    return {
        "text_presence": float(np.clip(blended, 0, 1)),
        "text_coverage": text_coverage,
        "edge_density": edge_density,
        "has_text": text_coverage > min_coverage or edge_density > 0.02,
    }


def brightness_score(gray: np.ndarray) -> dict:
    """
    Score exposure, penalizing under-exposed and blown-out pages.

    The mean intensity (8-bit scale) is mapped piecewise:

    * < 60 (very dark): rises linearly from 0.0 to 0.3.
    * 60-140 (dim): rises linearly from 0.3 to 0.8.
    * 140-250 (normal paper): gentle tent peaking at 1.0 for a mean of 200.
      It scores 0.80 at 140 and about 0.83 at 250. White paper with a mean
      of 240-250 is normal and still scores well.
    * > 250 (nearly blank white): continues from the score at 250 and drops
      0.2 per intensity level, floored at 0.4. The curve therefore stays
      continuous, and a nearly blank page never outscores a normal one.

    An exposure penalty, capped at 0.5, is then subtracted. It is the sum of:

    * 3x the fraction of crushed-black pixels (< 15);
    * 5x the fraction of saturated (== 255) pixels in excess of 95%.

    Documents naturally contain many white pixels, so saturation only
    counts once it covers more than 95% of the image.

    Returns:
        dict with ``brightness`` (clipped to [0, 1]), ``mean_brightness``,
        ``dark_pixel_frac`` and ``bright_pixel_frac``.
    """
    mean_brightness = float(gray.mean())

    # Fractions of truly problematic pixels.
    dark_frac = float((gray < 15).sum() / gray.size)          # crushed to black
    pure_white_frac = float((gray == 255).sum() / gray.size)  # fully saturated

    def _paper_curve(level: float) -> float:
        # 1.0 at a mean of 200, losing 0.2 for every 60 levels away from it.
        return 1.0 - abs(level - 200) / 60.0 * 0.2

    if mean_brightness < 60:
        bright_norm = mean_brightness / 60.0 * 0.3
    elif mean_brightness < 140:
        bright_norm = 0.3 + (mean_brightness - 60) / 80.0 * 0.5
    elif mean_brightness <= 250:
        bright_norm = _paper_curve(mean_brightness)
    else:
        # Nearly blank white page. Start from the score at 250 (not 1.0) so
        # brightness keeps decreasing past 250 instead of jumping upward.
        bright_norm = max(0.4, _paper_curve(250.0) - (mean_brightness - 250) / 5.0)

    # Penalize only mostly-crushed blacks or near-total saturation; a
    # pure_white_frac of 0.9 on a text document is fine (paper is white).
    exposure_penalty = min(dark_frac * 3 + max(0, pure_white_frac - 0.95) * 5, 0.5)
    bright_norm = max(0, bright_norm - exposure_penalty)

    return {
        "brightness": float(np.clip(bright_norm, 0, 1)),
        "mean_brightness": mean_brightness,
        "dark_pixel_frac": dark_frac,
        "bright_pixel_frac": pure_white_frac,
    }


def entropy_score(gray: np.ndarray) -> dict:
    """
    Score information content via the Shannon entropy of the pixel values.

    The entropy from ``skimage.measure.shannon_entropy`` (default base 2) is
    divided by 5.5 and capped at 1.0, so 5.5 bits or more scores 1.0.
    Rough heuristic ranges:

    * blank/uniform pages: ~0-2
    * simple documents: ~3-5
    * rich documents: ~5-7
    * complex photographs: ~7-8

    NOTE(review): a cleanly binarized page (only pure black and pure white)
    has at most 1 bit of entropy even when perfectly readable. Confirm this
    signal's weight suits your inputs.

    Returns:
        dict with ``entropy`` (normalized, in [0, 1]) and ``shannon_entropy``.
    """
    bits = float(shannon_entropy(gray))
    normalized = min(bits / 5.5, 1.0)
    return {
        "entropy": float(np.clip(normalized, 0, 1)),
        "shannon_entropy": bits,
    }


# ─── Learned IQA (optional) ─────────────────────────────────────────────────

# pyiqa metric objects keyed by f"{metric_name}_{device}", so each model is
# built once per process. Entries are never removed here.
_iqa_cache: dict = {}

def learned_iqa_score(
    image: Union[str, Path, np.ndarray, Image.Image],
    metric_name: str = "clipiqa",
    device: str = "cpu",
) -> dict:
    """
    Score *image* with a learned no-reference IQA metric from ``pyiqa``.

    Processing steps:

    * The metric is created with ``pyiqa.create_metric`` on first use and
      cached in ``_iqa_cache``.
    * The image is converted to RGB and downscaled so its longer side is at
      most 512 px.
    * It is passed to the metric as a float tensor in [0, 1] of shape
      (1, 3, H, W) on ``device``.

    The raw score is normalized to [0, 1], higher = better. If the metric
    reports ``lower_better``:

    * brisque: ``1 - raw / 100``
    * niqe: ``1 - raw / 20``
    * any other metric: ``1 - raw / 50``

    Otherwise (e.g. clipiqa, topiq_nr) the raw score is clipped to [0, 1].

    Args:
        image: File path, numpy array (2-D gray, or 3-D BGR), or PIL image.
        metric_name: pyiqa metric name, e.g. "clipiqa", "brisque", "niqe",
            "topiq_nr".
        device: torch device string the metric runs on.

    Returns:
        dict with ``learned_iqa`` (normalized), ``<metric_name>_raw`` and
        ``metric_name``.

    Raises:
        TypeError: for unsupported input types. Import and metric errors from
        torch/pyiqa propagate to the caller.
    """
    import torch
    import pyiqa

    cache_key = f"{metric_name}_{device}"
    if cache_key not in _iqa_cache:
        _iqa_cache[cache_key] = pyiqa.create_metric(metric_name, device=device)

    metric = _iqa_cache[cache_key]
    lower_better = metric.lower_better

    # Convert the input to an RGB PIL image.
    if isinstance(image, (str, Path)):
        # Close the file explicitly: Pillow keeps the handle open after
        # load() for multi-frame formats (e.g. multi-page TIFF scans), which
        # leaks descriptors when scoring large batches.
        with Image.open(str(image)) as src:
            pil_img = src.convert("RGB")
    elif isinstance(image, np.ndarray):
        if image.ndim == 2:
            pil_img = Image.fromarray(image).convert("RGB")
        else:
            # 3-D arrays are treated as BGR (OpenCV convention).
            pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    elif isinstance(image, Image.Image):
        pil_img = image.convert("RGB")
    else:
        raise TypeError(f"Unsupported type: {type(image)}")

    # Resize for speed (IQA doesn't need full resolution).
    max_dim = 512
    w, h = pil_img.size
    if max(w, h) > max_dim:
        scale = max_dim / max(w, h)
        # Clamp to 1 px so extremely elongated images don't yield an
        # invalid zero-sized resize target.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        pil_img = pil_img.resize(new_size, Image.LANCZOS)

    # HWC uint8 → (1, 3, H, W) float in [0, 1].
    img_tensor = torch.from_numpy(
        np.array(pil_img).transpose(2, 0, 1)
    ).float().unsqueeze(0) / 255.0
    img_tensor = img_tensor.to(device)

    with torch.no_grad():
        raw_score = float(metric(img_tensor).item())

    # Normalize to [0, 1], higher = better.
    if lower_better:
        if metric_name == "brisque":
            normalized = float(np.clip(1.0 - raw_score / 100.0, 0, 1))
        elif metric_name == "niqe":
            normalized = float(np.clip(1.0 - raw_score / 20.0, 0, 1))
        else:
            normalized = float(np.clip(1.0 - raw_score / 50.0, 0, 1))
    else:
        normalized = float(np.clip(raw_score, 0, 1))

    return {
        "learned_iqa": normalized,
        f"{metric_name}_raw": raw_score,
        "metric_name": metric_name,
    }


# ─── Main Scorer ─────────────────────────────────────────────────────────────

@dataclass
class ReadabilityResult:
    """Outcome of scoring a single document image."""

    readability_score: float  # composite score in [0, 1]
    ocr_recommended: bool     # whether the composite cleared the OCR threshold
    confidence_label: str     # "excellent" / "good" / "fair" / "poor" / "bad"
    signals: dict             # every sub-score plus the raw measurements behind it
    config: dict              # weights / threshold / metric used for this result

    def to_dict(self) -> dict:
        """Return the public fields as a plain dict, omitting ``config``.

        Values are not copied; ``signals`` is the same dict object.
        """
        exported = ("readability_score", "ocr_recommended", "confidence_label", "signals")
        return {name: getattr(self, name) for name in exported}


class DocumentReadabilityScorer:
    """
    Combine the classical image signals, plus an optional learned IQA metric,
    into one readability score and an OCR go/no-go recommendation.

    Example:
        scorer = DocumentReadabilityScorer()
        result = scorer.score("scan.png")
        if result.ocr_recommended:
            run_ocr(...)
        else:
            log_rejected(result.signals)

    Inputs are decoded with OpenCV / PIL, so pass raster images (PNG, JPEG,
    TIFF, ...) rather than PDFs.
    """

    def __init__(self, config: Optional[ScorerConfig] = None):
        """Store *config* (or a default ``ScorerConfig``) and validate it.

        Raises:
            ValueError: if the configured weights do not sum to ~1.0.
        """
        self.config = config if config else ScorerConfig()
        self.config.validate()

    def score(
        self,
        image: Union[str, Path, np.ndarray, Image.Image],
    ) -> ReadabilityResult:
        """
        Score a document image for readability.

        The grayscale version feeds the six classical signals. The original
        *image* goes to the learned metric when one is configured.

        NOTE(review): if the learned metric is disabled, or raises for any
        reason, a neutral 0.5 is used in its place. Its weight is not
        redistributed. On failure the exception text is kept under
        ``signals["error"]``.

        Args:
            image: File path, numpy array (BGR or gray), or PIL Image.

        Returns:
            ReadabilityResult containing:
              - the composite score, rounded to 4 decimals;
              - the OCR recommendation (unrounded composite >= ``ocr_threshold``);
              - a confidence label;
              - the merged sub-signals;
              - the configuration used.

        Raises:
            FileNotFoundError: if a path cannot be read.
            TypeError: for unsupported input types.
        """
        cfg = self.config
        gray = _load_gray(image)

        # Classical signals, all computed on the grayscale image.
        sharp = sharpness_score(gray, cfg.laplacian_cap)
        cont = contrast_score(gray)
        noi = noise_score(gray, cfg.noise_cap)
        text = text_presence_score(gray, cfg.min_text_coverage)
        bright = brightness_score(gray)
        ent = entropy_score(gray)

        # Optional learned IQA; neutral 0.5 when disabled or unavailable.
        if not cfg.learned_metric:
            iqa = {"learned_iqa": 0.5, "metric_name": "disabled"}
        else:
            try:
                iqa = learned_iqa_score(image, cfg.learned_metric, cfg.device)
            except Exception as exc:
                # Best-effort: keep scoring with the classical signals only.
                iqa = {"learned_iqa": 0.5, "error": str(exc), "metric_name": cfg.learned_metric}

        composite = (
            cfg.w_sharpness * sharp["sharpness"]
            + cfg.w_contrast * cont["contrast"]
            + cfg.w_noise * noi["noise"]
            + cfg.w_text_presence * text["text_presence"]
            + cfg.w_brightness * bright["brightness"]
            + cfg.w_entropy * ent["entropy"]
            + cfg.w_learned_iqa * iqa["learned_iqa"]
        )
        composite = float(np.clip(composite, 0, 1))

        # First band whose floor the composite reaches; "bad" otherwise.
        label = "bad"
        for floor, band in ((0.80, "excellent"), (0.60, "good"), (0.40, "fair"), (0.20, "poor")):
            if composite >= floor:
                label = band
                break

        # Later dicts win on key clashes, matching sequential dict.update().
        signals = {**sharp, **cont, **noi, **text, **bright, **ent, **iqa}

        used_config = {
            "weights": {
                "sharpness": cfg.w_sharpness,
                "contrast": cfg.w_contrast,
                "noise": cfg.w_noise,
                "text_presence": cfg.w_text_presence,
                "brightness": cfg.w_brightness,
                "entropy": cfg.w_entropy,
                "learned_iqa": cfg.w_learned_iqa,
            },
            "ocr_threshold": cfg.ocr_threshold,
            "learned_metric": cfg.learned_metric or "disabled",
        }

        return ReadabilityResult(
            readability_score=round(composite, 4),
            ocr_recommended=composite >= cfg.ocr_threshold,
            confidence_label=label,
            signals=signals,
            config=used_config,
        )


# ─── Batch processing helper ─────────────────────────────────────────────────

def score_batch(
    image_paths: list[Union[str, Path]],
    config: Optional[ScorerConfig] = None,
    sort_by_score: bool = True,
) -> list[dict]:
    """Score many documents with one shared scorer.

    On success, each entry is ``{"path": str(path), **result.to_dict()}``.

    A failure on an individual file does not abort the batch. That entry
    instead carries ``readability_score=0.0``, ``ocr_recommended=False``,
    ``confidence_label="error"`` and the exception text under ``"error"``,
    with no ``"signals"`` key. Errors raised while building the scorer
    (e.g. an invalid config) are not caught.

    Args:
        image_paths: Image paths to score.
        config: Optional scorer configuration; defaults are used when omitted.
        sort_by_score: When True (the default), sort entries by descending
            ``readability_score``. The sort is stable, so ties keep input order.

    Returns:
        One dict per input path.
    """
    scorer = DocumentReadabilityScorer(config)

    def _score_one(path):
        # Per-file best-effort: turn any failure into an "error" entry.
        try:
            outcome = scorer.score(path)
            return {"path": str(path), **outcome.to_dict()}
        except Exception as exc:
            return {
                "path": str(path),
                "readability_score": 0.0,
                "ocr_recommended": False,
                "confidence_label": "error",
                "error": str(exc),
            }

    entries = [_score_one(path) for path in image_paths]

    if sort_by_score:
        entries = sorted(entries, key=lambda entry: entry["readability_score"], reverse=True)

    return entries