Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import csv | |
| import json | |
| from dataclasses import asdict, dataclass | |
| from pathlib import Path | |
| from typing import Any | |
| from datasets.field_notes import FieldNote, FieldNoteStore | |
@dataclass
class OCRPrediction:
    """One local OCR prediction that may need human correction.

    Attributes:
        source_path: Location of the source the text was read from, taken
            from a ``source_path``, ``image_path``, ``file_path`` or
            ``path`` column.
        text: The OCR output, taken from a ``text``, ``prediction``,
            ``ocr_text`` or ``response`` column.
        confidence: Score parsed by ``_parse_confidence``.
        page: Page identifier as a string; empty when the row has none.
    """

    source_path: str
    text: str
    confidence: float
    page: str = ""

    @classmethod
    def from_row(cls, row: dict[str, Any]) -> OCRPrediction:
        """Build a prediction from a loosely-structured CSV or JSON row.

        Each field is read from the first accepted column name whose value
        is not ``None``. A missing confidence is parsed from ``"0"``.

        Args:
            row: Mapping of column name to raw value.

        Returns:
            A new ``OCRPrediction`` populated from *row*.
        """
        source_path = _first_present(row, ["source_path", "image_path", "file_path", "path"])
        text = _first_present(row, ["text", "prediction", "ocr_text", "response"])
        confidence_raw = _first_present(row, ["confidence", "score", "probability"], "0")
        return cls(
            source_path=source_path,
            text=text,
            confidence=_parse_confidence(confidence_raw),
            # Go through _first_present so a null page (JSON null, or a CSV
            # row shorter than the header) becomes "" instead of "None".
            page=_first_present(row, ["page"]),
        )

    def to_dict(self) -> dict[str, object]:
        """Return the prediction's fields as a plain dictionary."""
        return asdict(self)
def load_ocr_predictions(path: str | Path) -> list[OCRPrediction]:
    """Load OCR predictions from local CSV, JSONL, or NDJSON files.

    Args:
        path: File to read. The format is chosen from the file extension,
            compared case-insensitively.

    Returns:
        One ``OCRPrediction`` per CSV data row or per non-blank JSON line,
        in file order.

    Raises:
        FileNotFoundError: If *path* does not exist.
        ValueError: If the extension is not ``.csv``, ``.jsonl`` or
            ``.ndjson``. Malformed JSON lines propagate
            ``json.JSONDecodeError``, which is a ``ValueError`` subclass.
    """
    source = Path(path)
    if not source.exists():
        raise FileNotFoundError(f"OCR prediction file not found: {source}")

    suffix = source.suffix.lower()
    if suffix == ".csv":
        # "utf-8-sig" reads plain UTF-8 as well, but also drops a leading
        # BOM that would otherwise be glued onto the first header name.
        with source.open(newline="", encoding="utf-8-sig") as f:
            return [OCRPrediction.from_row(row) for row in csv.DictReader(f)]
    if suffix in {".jsonl", ".ndjson"}:
        # Iterate the file instead of calling str.splitlines(): splitlines()
        # also breaks on U+2028, U+2029 and U+0085, which may appear
        # unescaped inside a JSON string and would cut a record in half.
        with source.open(encoding="utf-8-sig") as f:
            return [
                OCRPrediction.from_row(json.loads(line))
                for line in f
                if line.strip()
            ]
    raise ValueError("OCR predictions must be a .csv, .jsonl, or .ndjson file.")
def uncertain_predictions(
    predictions: list[OCRPrediction],
    confidence_threshold: float = 0.8,
) -> list[OCRPrediction]:
    """Return the predictions that need human review, in input order.

    A prediction is kept when its confidence is at or below
    *confidence_threshold*, or when its text is empty or whitespace-only.
    """
    flagged: list[OCRPrediction] = []
    for candidate in predictions:
        low_confidence = candidate.confidence <= confidence_threshold
        if low_confidence or not candidate.text.strip():
            flagged.append(candidate)
    return flagged
def import_uncertain_predictions(
    store: FieldNoteStore,
    predictions: list[OCRPrediction],
    model_id: str,
    confidence_threshold: float = 0.8,
    tags: str = "ocr,uncertain",
) -> int:
    """Save one review note per uncertain prediction and return the count.

    Predictions are filtered with ``uncertain_predictions``. Each survivor
    becomes a ``FieldNote`` with the OCR text as the response, an empty
    correction, and ``use_for_training=False``, and is saved to *store*.

    Args:
        store: Destination whose ``save`` method receives each note.
        predictions: Candidate OCR predictions.
        model_id: Identifier recorded on every created note.
        confidence_threshold: Cut-off passed to ``uncertain_predictions``.
        tags: Tag string recorded on every created note.

    Returns:
        The number of notes saved.
    """
    flagged = uncertain_predictions(predictions, confidence_threshold)
    for item in flagged:
        # Mention the page in the prompt only when the prediction has one.
        page_suffix = f" page {item.page}" if item.page else ""
        store.save(
            FieldNote.create(
                model_id=model_id,
                prompt=f"Review OCR text for {item.source_path}{page_suffix}.",
                response=item.text,
                correction="",
                tags=tags,
                image_path=item.source_path,
                use_for_training=False,
            )
        )
    return len(flagged)
def export_corrected_ocr_notes(
    store: FieldNoteStore,
    output_path: str | Path = "data/ocr_corrections.jsonl",
) -> Path:
    """Write the store's corrected OCR notes to a JSONL file.

    Notes come from ``store.list_notes(corrected_only=True, tag="ocr")``.
    Each is written as one JSON object per line, with non-ASCII characters
    kept as-is. Missing parent directories are created and an existing
    file is overwritten.

    Args:
        store: Note store to read from.
        output_path: Destination file.

    Returns:
        The destination as a ``Path``.
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    corrected = store.list_notes(corrected_only=True, tag="ocr")

    def _as_line(note) -> str:
        # Key order here is the key order of each exported JSON object.
        record = {
            "source_path": note.image_path,
            "predicted_text": note.response,
            "corrected_text": note.correction,
            "model_id": note.model_id,
            "created_at": note.created_at,
            "tags": note.tags,
        }
        return json.dumps(record, ensure_ascii=False) + "\n"

    with destination.open("w", encoding="utf-8") as sink:
        sink.writelines(_as_line(note) for note in corrected)
    return destination
def ocr_import_summary(path: str | Path, confidence_threshold: float = 0.8) -> dict[str, object]:
    """Summarise how many rows of an OCR prediction file need review.

    Args:
        path: Prediction file, loaded with ``load_ocr_predictions``.
        confidence_threshold: Cut-off passed to ``uncertain_predictions``.

    Returns:
        A dict with the source path as a string, the total row count, the
        uncertain row count, the threshold used, and up to five uncertain
        predictions as plain dicts under ``"sample"``.
    """
    loaded = load_ocr_predictions(path)
    flagged = uncertain_predictions(loaded, confidence_threshold)
    # Only the first five uncertain rows are included as a preview.
    preview = [item.to_dict() for item in flagged[:5]]
    summary: dict[str, object] = {
        "source": str(path),
        "rows": len(loaded),
        "uncertain_rows": len(flagged),
        "confidence_threshold": confidence_threshold,
        "sample": preview,
    }
    return summary
def _first_present(row: dict[str, Any], names: list[str], default: str = "") -> str:
    """Return the first non-``None`` value found under *names*, as a string.

    Keys are tried in the order given. An empty string counts as present.
    *default* is returned when every key is missing or maps to ``None``.
    """
    candidates = (row.get(key) for key in names)
    return next((str(found) for found in candidates if found is not None), default)
def _parse_confidence(value: str) -> float:
    """Parse a raw confidence value, falling back to ``0.0`` when unusable.

    Args:
        value: Raw confidence, normally a string taken from a CSV or JSON
            row.

    Returns:
        The parsed float, or ``0.0`` when *value* cannot be converted or
        converts to NaN or infinity.
    """
    import math

    try:
        parsed = float(value)
    except (TypeError, ValueError):
        # TypeError covers non-string junk such as None.
        return 0.0
    # "nan" and "inf" parse successfully, but NaN compares False against any
    # threshold and +inf sits above every threshold, so both would pass as
    # confident. Treat them like any other unparseable score.
    if not math.isfinite(parsed):
        return 0.0
    return parsed