Spaces:
Running on Zero
Running on Zero
File size: 4,398 Bytes
7f9dfed | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 | from __future__ import annotations
import csv
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
from datasets.field_notes import FieldNote, FieldNoteStore
@dataclass(frozen=True)
class OCRPrediction:
    """One local OCR prediction that may need human correction.

    Attributes:
        source_path: Location of the input the text was predicted from.
        text: The predicted text.
        confidence: Confidence value parsed from the row.
        page: Optional page label; empty string when the row has none.
    """

    source_path: str
    text: str
    confidence: float
    page: str = ""

    @classmethod
    def from_row(cls, row: dict[str, Any]) -> OCRPrediction:
        """Build a prediction from a loosely-keyed row (CSV record or JSON object).

        Each field is read from the first key in its alias list whose value is
        not ``None``. A missing confidence falls back to ``"0"`` before parsing.
        """
        source_path = _first_present(row, ["source_path", "image_path", "file_path", "path"])
        text = _first_present(row, ["text", "prediction", "ocr_text", "response"])
        confidence_raw = _first_present(row, ["confidence", "score", "probability"], "0")
        return cls(
            source_path=source_path,
            text=text,
            confidence=_parse_confidence(confidence_raw),
            # Use the shared helper so a present-but-null page (JSON null, or a
            # short CSV row) becomes "" rather than the truthy string "None".
            page=_first_present(row, ["page"]),
        )

    def to_dict(self) -> dict[str, object]:
        """Return the prediction's fields as a plain dictionary."""
        return asdict(self)
def load_ocr_predictions(path: str | Path) -> list[OCRPrediction]:
    """Load OCR predictions from local CSV, JSONL, or NDJSON files.

    Args:
        path: File to read; the format is chosen from its suffix
            (case-insensitive).

    Returns:
        One ``OCRPrediction`` per CSV record or per non-blank JSON line.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the suffix is not ``.csv``, ``.jsonl`` or ``.ndjson``.
    """
    source = Path(path)
    if not source.exists():
        raise FileNotFoundError(f"OCR prediction file not found: {source}")

    suffix = source.suffix.lower()
    # "utf-8-sig" reads plain UTF-8 unchanged but drops a leading byte-order
    # mark. Without it the BOM sticks to the first CSV header name (so that
    # column is never matched) or makes json.loads reject the first line.
    if suffix == ".csv":
        with source.open(newline="", encoding="utf-8-sig") as f:
            return [OCRPrediction.from_row(row) for row in csv.DictReader(f)]
    if suffix in {".jsonl", ".ndjson"}:
        # Iterate the file instead of calling str.splitlines(): splitlines also
        # breaks on U+0085, U+2028 and U+2029, which may appear unescaped
        # inside a JSON string and would cut a valid record in two.
        with source.open(encoding="utf-8-sig") as f:
            return [
                OCRPrediction.from_row(json.loads(line))
                for line in f
                if line.strip()
            ]
    raise ValueError("OCR predictions must be a .csv, .jsonl, or .ndjson file.")
def uncertain_predictions(
    predictions: list[OCRPrediction],
    confidence_threshold: float = 0.8,
) -> list[OCRPrediction]:
    """Select the predictions that should be reviewed.

    A prediction is kept when its confidence is at or below
    ``confidence_threshold``, or when its text is empty or whitespace-only.
    Input order is preserved.
    """
    flagged = []
    for candidate in predictions:
        low_confidence = candidate.confidence <= confidence_threshold
        # Only inspect the text when the confidence alone did not flag the row.
        if low_confidence or not candidate.text.strip():
            flagged.append(candidate)
    return flagged
def import_uncertain_predictions(
    store: FieldNoteStore,
    predictions: list[OCRPrediction],
    model_id: str,
    confidence_threshold: float = 0.8,
    tags: str = "ocr,uncertain",
) -> int:
    """Save one review note per uncertain prediction and return how many were saved.

    Predictions are filtered with ``uncertain_predictions``. Each saved note
    carries the predicted text as its response, an empty correction, and
    ``use_for_training=False``.
    """
    flagged = uncertain_predictions(predictions, confidence_threshold)
    for item in flagged:
        # Mention the page in the prompt only when the prediction has one.
        if item.page:
            page_suffix = f" page {item.page}"
        else:
            page_suffix = ""
        store.save(
            FieldNote.create(
                model_id=model_id,
                prompt=f"Review OCR text for {item.source_path}{page_suffix}.",
                response=item.text,
                correction="",
                tags=tags,
                image_path=item.source_path,
                use_for_training=False,
            )
        )
    return len(flagged)
def export_corrected_ocr_notes(
    store: FieldNoteStore,
    output_path: str | Path = "data/ocr_corrections.jsonl",
) -> Path:
    """Write the store's corrected OCR notes to a JSONL file and return its path.

    Notes are fetched with ``store.list_notes(corrected_only=True, tag="ocr")``.
    Missing parent directories are created and an existing file is overwritten.
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    corrected = store.list_notes(corrected_only=True, tag="ocr")

    def as_json_line(note) -> str:
        # One JSON object per line; non-ASCII text is written unescaped.
        record = {
            "source_path": note.image_path,
            "predicted_text": note.response,
            "corrected_text": note.correction,
            "model_id": note.model_id,
            "created_at": note.created_at,
            "tags": note.tags,
        }
        return json.dumps(record, ensure_ascii=False) + "\n"

    with destination.open("w", encoding="utf-8") as handle:
        handle.writelines(as_json_line(note) for note in corrected)
    return destination
def ocr_import_summary(path: str | Path, confidence_threshold: float = 0.8) -> dict[str, object]:
    """Summarise a prediction file: row counts plus up to five uncertain samples.

    The file is loaded with ``load_ocr_predictions`` and filtered with
    ``uncertain_predictions``.
    """
    loaded = load_ocr_predictions(path)
    flagged = uncertain_predictions(loaded, confidence_threshold)

    summary: dict[str, object] = {}
    summary["source"] = str(path)
    summary["rows"] = len(loaded)
    summary["uncertain_rows"] = len(flagged)
    summary["confidence_threshold"] = confidence_threshold
    # Only the first five uncertain rows are included as a preview.
    summary["sample"] = [item.to_dict() for item in flagged[:5]]
    return summary
def _first_present(row: dict[str, Any], names: list[str], default: str = "") -> str:
for name in names:
value = row.get(name)
if value is not None:
return str(value)
return default
def _parse_confidence(value: str) -> float:
try:
return float(value)
except ValueError:
return 0.0
|