# workbench/datasets/ocr.py
# Committed by GitHub Actions: "Initial ZeroGPU deployment with spaces shim" (7f9dfed), 4.4 kB.
# (Page controls captured from the repository listing: Raw, History, Blame, Contribute, Delete.)
from __future__ import annotations

import csv
import json
import math
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any

from datasets.field_notes import FieldNote, FieldNoteStore
@dataclass(frozen=True)
class OCRPrediction:
    """One local OCR prediction that may need human correction.

    Attributes:
        source_path: Location of the file the text was read from.
        text: The recognized text.
        confidence: Numeric confidence parsed from the input row.
        page: Optional page label; empty string when the row has none.
    """

    source_path: str
    text: str
    confidence: float
    page: str = ""

    @classmethod
    def from_row(cls, row: dict[str, Any]) -> OCRPrediction:
        """Build a prediction from a loosely-named CSV/JSON row.

        Each field is read from the first key in its alias list whose value
        is not ``None``. Missing text fields become ``""`` and a missing
        confidence is parsed from ``"0"``.

        Args:
            row: Mapping of column/key names to raw values.

        Returns:
            A new ``OCRPrediction``.
        """
        source_path = _first_present(row, ["source_path", "image_path", "file_path", "path"])
        text = _first_present(row, ["text", "prediction", "ocr_text", "response"])
        confidence_raw = _first_present(row, ["confidence", "score", "probability"], "0")
        # A null page (JSON ``null``, or the None that csv.DictReader fills in
        # for short rows) must become "" rather than the literal string
        # "None", which would otherwise be treated as a real page label.
        page = _first_present(row, ["page"])
        return cls(
            source_path=source_path,
            text=text,
            confidence=_parse_confidence(confidence_raw),
            page=page,
        )

    def to_dict(self) -> dict[str, object]:
        """Return the prediction's fields as a plain dictionary."""
        return asdict(self)
def load_ocr_predictions(path: str | Path) -> list[OCRPrediction]:
    """Load OCR predictions from local CSV, JSONL, or NDJSON files.

    The format is chosen from the file suffix (case-insensitive). CSV rows
    are read with ``csv.DictReader``; JSONL/NDJSON files are parsed one JSON
    document per non-blank line.

    Args:
        path: Location of the prediction file.

    Returns:
        One ``OCRPrediction`` per data row, in file order.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the suffix is not ``.csv``, ``.jsonl`` or ``.ndjson``
            (``json.JSONDecodeError`` from a malformed line is also a
            ``ValueError``).
    """
    source = Path(path)
    if not source.exists():
        raise FileNotFoundError(f"OCR prediction file not found: {source}")

    suffix = source.suffix.lower()
    if suffix == ".csv":
        # utf-8-sig drops a leading byte-order mark when present; otherwise
        # the first header would be read as "\ufeffsource_path" and never
        # match any alias. Files without a BOM decode exactly as before.
        with source.open(newline="", encoding="utf-8-sig") as f:
            return [OCRPrediction.from_row(row) for row in csv.DictReader(f)]
    if suffix in {".jsonl", ".ndjson"}:
        # Iterate the file handle instead of str.splitlines(): splitlines()
        # also breaks on U+2028/U+2029/U+0085, which may appear unescaped
        # inside a JSON string and would cut a valid record in half.
        with source.open(encoding="utf-8-sig") as f:
            return [
                OCRPrediction.from_row(json.loads(line))
                for line in f
                if line.strip()
            ]
    raise ValueError("OCR predictions must be a .csv, .jsonl, or .ndjson file.")
def uncertain_predictions(
    predictions: list[OCRPrediction],
    confidence_threshold: float = 0.8,
) -> list[OCRPrediction]:
    """Return the predictions that should be queued for human review.

    A prediction is kept when its confidence is not strictly above
    ``confidence_threshold`` or when its text is empty/whitespace-only.

    Args:
        predictions: Candidate predictions, in any order.
        confidence_threshold: Scores at or below this value count as uncertain.

    Returns:
        The uncertain predictions, in their original order.
    """
    # `not (c > t)` instead of `c <= t`: every comparison with NaN is False,
    # so a NaN confidence would fail `<=` and silently skip review. Negating
    # `>` keeps the same result for ordinary numbers and flags NaN as uncertain.
    return [
        prediction
        for prediction in predictions
        if not prediction.confidence > confidence_threshold
        or not prediction.text.strip()
    ]
def import_uncertain_predictions(
    store: FieldNoteStore,
    predictions: list[OCRPrediction],
    model_id: str,
    confidence_threshold: float = 0.8,
    tags: str = "ocr,uncertain",
) -> int:
    """Save one uncorrected field note per uncertain prediction.

    Predictions are filtered with ``uncertain_predictions``. Each survivor is
    turned into a ``FieldNote`` with an empty correction and
    ``use_for_training=False``, then passed to ``store.save``.

    Args:
        store: Destination that receives the created notes.
        predictions: Predictions to filter and import.
        model_id: Recorded on every created note.
        confidence_threshold: Forwarded to ``uncertain_predictions``.
        tags: Tag string recorded on every created note.

    Returns:
        The number of notes saved.
    """
    review_queue = uncertain_predictions(predictions, confidence_threshold)
    for item in review_queue:
        # Only mention the page when the prediction carries one.
        page_suffix = ""
        if item.page:
            page_suffix = f" page {item.page}"
        store.save(
            FieldNote.create(
                model_id=model_id,
                prompt=f"Review OCR text for {item.source_path}{page_suffix}.",
                response=item.text,
                correction="",
                tags=tags,
                image_path=item.source_path,
                use_for_training=False,
            )
        )
    return len(review_queue)
def export_corrected_ocr_notes(
    store: FieldNoteStore,
    output_path: str | Path = "data/ocr_corrections.jsonl",
) -> Path:
    """Write the store's corrected OCR notes to a JSONL file.

    Notes are fetched with ``store.list_notes(corrected_only=True, tag="ocr")``
    and written one JSON object per line, with non-ASCII characters kept
    as-is. Missing parent directories are created and an existing file is
    overwritten.

    Args:
        store: Source of the notes.
        output_path: Destination file.

    Returns:
        The destination as a ``Path``.
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    corrected = store.list_notes(corrected_only=True, tag="ocr")

    def as_line(note) -> str:
        # Serialize one note as a newline-terminated JSON record.
        record = {
            "source_path": note.image_path,
            "predicted_text": note.response,
            "corrected_text": note.correction,
            "model_id": note.model_id,
            "created_at": note.created_at,
            "tags": note.tags,
        }
        return json.dumps(record, ensure_ascii=False) + "\n"

    with destination.open("w", encoding="utf-8") as sink:
        sink.writelines(as_line(note) for note in corrected)
    return destination
def ocr_import_summary(path: str | Path, confidence_threshold: float = 0.8) -> dict[str, object]:
    """Summarize how many rows of a prediction file would need review.

    Args:
        path: Prediction file, loaded with ``load_ocr_predictions``.
        confidence_threshold: Forwarded to ``uncertain_predictions``.

    Returns:
        A dictionary with the source path as a string, the total row count,
        the uncertain row count, the threshold used, and up to five
        uncertain predictions as plain dictionaries.
    """
    loaded = load_ocr_predictions(path)
    flagged = uncertain_predictions(loaded, confidence_threshold)
    total_rows = len(loaded)
    flagged_rows = len(flagged)
    # Only the first five flagged rows are included as a preview.
    preview = [entry.to_dict() for entry in flagged[:5]]
    return {
        "source": str(path),
        "rows": total_rows,
        "uncertain_rows": flagged_rows,
        "confidence_threshold": confidence_threshold,
        "sample": preview,
    }
def _first_present(row: dict[str, Any], names: list[str], default: str = "") -> str:
    """Return the first non-``None`` value among ``names`` in ``row``, as a string.

    Keys are tried in the order given. A key that is absent or maps to
    ``None`` is skipped; any other value (including ``""`` and ``0``) counts
    as present. ``default`` is returned when no key qualifies.
    """
    looked_up = (row.get(key) for key in names)
    found = next((candidate for candidate in looked_up if candidate is not None), None)
    return default if found is None else str(found)
def _parse_confidence(value: str) -> float:
    """Convert a raw confidence value to a finite float, defaulting to 0.0.

    Args:
        value: Raw confidence, normally the string read from an input row.

    Returns:
        The parsed number, or ``0.0`` when ``value`` cannot be converted by
        ``float`` or converts to NaN or an infinity.
    """
    try:
        parsed = float(value)
    except (TypeError, ValueError):
        # TypeError covers non-string inputs such as None.
        return 0.0
    # float() accepts "nan" and "inf"; neither is a usable score, and NaN in
    # particular makes every later threshold comparison False. Treat both as
    # unparseable so they fall back to the lowest confidence.
    return parsed if math.isfinite(parsed) else 0.0