Spaces:
Running
Running
Claude
fix: résoudre les 64 erreurs ruff pré-existantes révélées par le lint actif
6362212 unverified | """Taxonomie des erreurs OCR — classification automatique (classes 1 à 9). | |
| Chaque erreur identifiée par l'alignement GT↔OCR est catégorisée selon | |
| la taxonomie Picarones : | |
| | Classe | Nom | Description | | |
| |--------|-------------------|----------------------------------------------------| | |
| | 1 | visual_confusion | Confusion morphologique (rn/m, l/1, O/0, u/n…) | | |
| | 2 | diacritic_error | Diacritique absent, incorrect ou ajouté | | |
| | 3 | case_error | Erreur de casse uniquement (A/a) | | |
| | 4 | ligature_error | Ligature non résolue ou mal résolue | | |
| | 5 | abbreviation_error| Abréviation médiévale non développée | | |
| | 6 | hapax | Mot introuvable dans tout lexique | | |
| | 7 | segmentation_error| Fusion ou fragmentation de tokens (mots/lignes) | | |
| | 8 | oov_character | Caractère hors-vocabulaire du moteur | | |
| | 9 | lacuna | Texte présent dans le GT absent de l'OCR | | |
| | 10 | over_normalization| Sur-normalisation LLM (voir pipelines/) | | |
| Note : la classe 10 est calculée par picarones/pipelines/over_normalization.py. | |
| """ | |
| from __future__ import annotations | |
| import difflib | |
| import unicodedata | |
| from dataclasses import dataclass, field | |
| # --------------------------------------------------------------------------- | |
| # Tables de référence pour la classification | |
| # --------------------------------------------------------------------------- | |
# Known OCR visual confusions between morphologically similar glyphs, as
# ordered (a, b) pairs. Multi-character members ("rn", "li", "cl", …) are
# allowed. Some pairs are listed in both directions; see VISUAL_CONFUSIONS.
_VISUAL_PAIRS: list[tuple[str, str]] = [
    # Lowercase
    ("r", "n"), ("rn", "m"), ("l", "1"), ("l", "i"), ("l", "|"),
    ("O", "0"), ("O", "o"), ("u", "n"), ("n", "u"), ("v", "u"),
    ("c", "e"), ("e", "c"), ("a", "o"), ("o", "a"),
    ("f", "ſ"), ("ſ", "f"), ("f", "t"),
    ("h", "li"), ("h", "lı"),
    ("m", "rn"), ("m", "in"),
    ("d", "cl"), ("d", "a"),
    ("q", "g"), ("p", "q"),
    # Uppercase/lowercase look-alikes (class 1, not class 3)
    ("I", "l"), ("I", "1"),
    # Digits
    ("1", "I"), ("1", "l"), ("0", "O"),
    # Punctuation
    (".", ","), (",", "."),
]

#: Known OCR visual confusions, keyed by the unordered pair of glyphs
#: (``frozenset({a, b})``) and mapped to an ``"a/b"`` label. Pairs listed in
#: both directions in ``_VISUAL_PAIRS`` (e.g. ("O", "0") and ("0", "O"))
#: collapse onto one key; the label of the last occurrence is kept.
#: Built with a comprehension so no loop variables leak into the module.
VISUAL_CONFUSIONS: dict[frozenset, str] = {
    frozenset(pair): "/".join(pair) for pair in _VISUAL_PAIRS
}
| #: Couples de ligatures pour la détection des erreurs de ligatures | |
| from picarones.core.char_scores import LIGATURE_TABLE, DIACRITIC_MAP # noqa: E402 | |
| # Caractères hors-ASCII présumés hors-vocabulaire (alphabet non latin de base) | |
| _LATIN_BASIC = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" | |
| " \t\n.,;:!?-_'\"«»()[]{}/@#%&*+=/\\|<>~^") | |
| # --------------------------------------------------------------------------- | |
| # Résultat structuré | |
| # --------------------------------------------------------------------------- | |
@dataclass
class TaxonomyResult:
    """Result of the taxonomic classification of the errors of one document.

    Attributes
    ----------
    counts:
        Number of errors per class, keyed by class name
        (``'visual_confusion'``, ``'diacritic_error'``, …).
    examples:
        Example errors per class. Each example is a dict with ``'gt'`` and
        ``'ocr'`` strings and, for some classes, an integer ``'position'``.
        The number kept per class is capped by the producer
        (``classify_errors`` keeps 5 by default).
    total_errors:
        Total number of classified errors.
    """

    counts: dict[str, int] = field(default_factory=dict)
    examples: dict[str, list[dict]] = field(default_factory=dict)
    total_errors: int = 0

    @property
    def class_distribution(self) -> dict[str, float]:
        """Relative share (0–1, rounded to 4 decimals) of each class.

        Returns an empty dict when ``total_errors`` is 0.
        """
        if not self.total_errors:
            return {}
        return {
            cls: round(cnt / self.total_errors, 4)
            for cls, cnt in self.counts.items()
        }

    def as_dict(self) -> dict:
        """Serialize to a plain dict.

        Only the first 3 examples of each class are included.
        """
        return {
            "counts": self.counts,
            "total_errors": self.total_errors,
            "class_distribution": self.class_distribution,
            "examples": {
                cls: exs[:3] for cls, exs in self.examples.items()
            },
        }

    @classmethod
    def from_dict(cls, data: dict) -> "TaxonomyResult":
        """Rebuild a result from a dict such as the one ``as_dict`` produces.

        Missing keys fall back to empty counts/examples and 0 errors.
        ``class_distribution`` is ignored since it is derived from the rest.
        """
        return cls(
            counts=data.get("counts", {}),
            examples=data.get("examples", {}),
            total_errors=data.get("total_errors", 0),
        )
#: Error class names, in taxonomy order (classes 1 to 9). Class 10,
#: over-normalization, is computed in picarones/pipelines/over_normalization.py.
ERROR_CLASSES = [
    "visual_confusion",    # class 1
    "diacritic_error",     # class 2
    "case_error",          # class 3
    "ligature_error",      # class 4
    "abbreviation_error",  # class 5
    "hapax",               # class 6
    "segmentation_error",  # class 7
    "oov_character",       # class 8
    "lacuna",              # class 9
]
| # --------------------------------------------------------------------------- | |
| # Classification principale | |
| # --------------------------------------------------------------------------- | |
def classify_errors(
    ground_truth: str,
    hypothesis: str,
    max_examples: int = 5,
) -> TaxonomyResult:
    """Automatically classify the OCR errors in a GT/OCR pair.

    Both texts are split on whitespace and aligned word by word with
    ``difflib.SequenceMatcher``:

    * GT words deleted by the OCR are lacunae (class 9);
    * words inserted by the OCR count as OOV characters (class 8) when they
      contain an out-of-vocabulary character, and are otherwise not counted;
    * replaced spans whose word counts differ are segmentation errors
      (class 7), one per missing or extra word;
    * replaced spans of equal length are compared pairwise, and each
      differing word pair is classified by ``_classify_word_error``.

    Parameters
    ----------
    ground_truth:
        Reference text (ground truth).
    hypothesis:
        Text produced by the OCR engine.
    max_examples:
        Maximum number of examples kept per class.

    Returns
    -------
    TaxonomyResult
        Per-class counts, examples and the total number of classified
        errors. Where an example carries ``position``, it is a GT word index.
    """
    counts: dict[str, int] = {cls: 0 for cls in ERROR_CLASSES}
    examples: dict[str, list[dict]] = {cls: [] for cls in ERROR_CLASSES}
    total = 0

    def record(error_class: str, example: dict, n: int = 1) -> None:
        # Add n errors of error_class to the counts and the total, and keep
        # the example while the class still has room.
        nonlocal total
        counts[error_class] += n
        total += n
        if len(examples[error_class]) < max_examples:
            examples[error_class].append(example)

    if not ground_truth and not hypothesis:
        return TaxonomyResult(counts=counts, examples=examples, total_errors=0)

    gt_words = ground_truth.split()
    hyp_words = hypothesis.split()
    word_matcher = difflib.SequenceMatcher(None, gt_words, hyp_words, autojunk=False)

    for tag, i1, i2, j1, j2 in word_matcher.get_opcodes():
        if tag == "delete":
            # GT words absent from the OCR: each example records the word's
            # own GT index rather than the start of the deleted span.
            for position, w in enumerate(gt_words[i1:i2], start=i1):
                record("lacuna", {"gt": w, "ocr": "", "position": position})
        elif tag == "insert":
            # Words added by the OCR: only those with an OOV character count.
            # The position is the GT word index where the insertion occurs.
            for w in hyp_words[j1:j2]:
                if _is_oov_word(w):
                    record("oov_character", {"gt": "", "ocr": w, "position": i1})
        elif tag == "replace":
            gt_seg = gt_words[i1:i2]
            hyp_seg = hyp_words[j1:j2]
            if len(hyp_seg) != len(gt_seg):
                # Merged (fewer OCR words) or split (more OCR words) tokens:
                # one segmentation error per missing/extra word.
                record(
                    "segmentation_error",
                    {
                        "gt": " ".join(gt_seg),
                        "ocr": " ".join(hyp_seg),
                        "position": i1,
                    },
                    abs(len(gt_seg) - len(hyp_seg)),
                )
            else:
                # Same number of words: classify each differing pair.
                for gt_w, hyp_w in zip(gt_seg, hyp_seg):
                    if gt_w != hyp_w:
                        _classify_word_error(
                            gt_w, hyp_w, counts, examples, max_examples
                        )
                        total += 1

    return TaxonomyResult(
        counts=counts,
        examples=examples,
        total_errors=total,
    )
| def _classify_word_error( | |
| gt_word: str, | |
| hyp_word: str, | |
| counts: dict[str, int], | |
| examples: dict[str, list[dict]], | |
| max_examples: int, | |
| ) -> None: | |
| """Classifie l'erreur entre deux mots non-identiques.""" | |
| # Classe 3 : erreur de casse seule | |
| if gt_word.casefold() == hyp_word.casefold() and gt_word != hyp_word: | |
| counts["case_error"] += 1 | |
| if len(examples["case_error"]) < max_examples: | |
| examples["case_error"].append({"gt": gt_word, "ocr": hyp_word}) | |
| return | |
| # Classe 4 : erreur de ligature | |
| gt_norm = unicodedata.normalize("NFC", gt_word) | |
| hyp_norm = unicodedata.normalize("NFC", hyp_word) | |
| if _is_ligature_error(gt_norm, hyp_norm): | |
| counts["ligature_error"] += 1 | |
| if len(examples["ligature_error"]) < max_examples: | |
| examples["ligature_error"].append({"gt": gt_word, "ocr": hyp_word}) | |
| return | |
| # Classe 5 : erreur d'abréviation (présence de ꝑ, ꝓ, ꝗ dans le GT) | |
| if _is_abbreviation_error(gt_norm, hyp_norm): | |
| counts["abbreviation_error"] += 1 | |
| if len(examples["abbreviation_error"]) < max_examples: | |
| examples["abbreviation_error"].append({"gt": gt_word, "ocr": hyp_word}) | |
| return | |
| # Classe 2 : erreur diacritique | |
| if _is_diacritic_error(gt_norm, hyp_norm): | |
| counts["diacritic_error"] += 1 | |
| if len(examples["diacritic_error"]) < max_examples: | |
| examples["diacritic_error"].append({"gt": gt_word, "ocr": hyp_word}) | |
| return | |
| # Classe 1 : confusion visuelle (comparaison char par char) | |
| if _is_visual_confusion(gt_norm, hyp_norm): | |
| counts["visual_confusion"] += 1 | |
| if len(examples["visual_confusion"]) < max_examples: | |
| examples["visual_confusion"].append({"gt": gt_word, "ocr": hyp_word}) | |
| return | |
| # Classe 8 : caractère hors-vocabulaire | |
| if _is_oov_word(hyp_word): | |
| counts["oov_character"] += 1 | |
| if len(examples["oov_character"]) < max_examples: | |
| examples["oov_character"].append({"gt": gt_word, "ocr": hyp_word}) | |
| return | |
| # Classe 6 : hapax (erreur résiduelle non classifiable) | |
| counts["hapax"] += 1 | |
| if len(examples["hapax"]) < max_examples: | |
| examples["hapax"].append({"gt": gt_word, "ocr": hyp_word}) | |
def _is_ligature_error(gt: str, hyp: str) -> bool:
    """Return True if the difference between the two words involves a ligature.

    For each ligature in ``LIGATURE_TABLE`` the word pair matches when either:

    * the GT contains the ligature and the OCR does not, but the OCR contains
      one of its decomposed sequences (the ligature was split); or
    * the OCR contains the ligature and the GT contains one of its decomposed
      sequences (a sequence was fused into the ligature).
    """
    # NOTE(review): assumes LIGATURE_TABLE (from picarones.core.char_scores)
    # maps each ligature to an iterable of equivalent sequences — confirm.
    for ligature, sequences in LIGATURE_TABLE.items():
        split_by_ocr = (
            ligature in gt
            and ligature not in hyp
            and any(seq in hyp for seq in sequences)
        )
        if split_by_ocr:
            return True
        fused_by_ocr = ligature in hyp and any(seq in gt for seq in sequences)
        if fused_by_ocr:
            return True
    return False
| def _is_abbreviation_error(gt: str, hyp: str) -> bool: | |
| """Vrai si le GT contient un caractère d'abréviation médiévale.""" | |
| abbreviation_chars = "\uA751\uA753\uA757" # ꝑ ꝓ ꝗ | |
| return any(c in gt for c in abbreviation_chars) | |
| def _is_diacritic_error(gt: str, hyp: str) -> bool: | |
| """Vrai si la différence est principalement due à des diacritiques.""" | |
| # Comparer les formes sans diacritiques | |
| def strip_diacritics(text: str) -> str: | |
| nfd = unicodedata.normalize("NFD", text) | |
| return "".join(c for c in nfd if unicodedata.category(c) != "Mn") | |
| gt_stripped = strip_diacritics(gt) | |
| hyp_stripped = strip_diacritics(hyp) | |
| # Si les mots sont identiques sans diacritiques → erreur diacritique | |
| if gt_stripped.casefold() == hyp_stripped.casefold() and gt != hyp: | |
| return True | |
| # Si le GT contient des diacritiques que l'OCR a perdus et que les textes | |
| # sans diacritiques sont identiques (même longueur requise) | |
| gt_has_diac = any(c in DIACRITIC_MAP for c in gt) | |
| return gt_has_diac and len(gt) == len(hyp) and gt_stripped.casefold() == hyp_stripped.casefold() | |
def _is_visual_confusion(gt: str, hyp: str) -> bool:
    """Return True if a substituted segment of the words is a known confusion.

    Words whose lengths differ by more than 2 characters are rejected. The
    two words are then aligned character by character with
    ``difflib.SequenceMatcher``. The error is a visual confusion when at
    least one substituted segment ``gt[i1:i2]`` ↔ ``hyp[j1:j2]`` forms a pair
    of ``VISUAL_CONFUSIONS`` (e.g. "rn" ↔ "m", "l" ↔ "1").

    Aligning first avoids the false positives of a plain "glyph of the pair
    present in each word" test, which would flag unrelated substitutions
    such as "cat" → "dog" (via a/o).
    """
    if abs(len(gt) - len(hyp)) > 2:
        return False
    matcher = difflib.SequenceMatcher(None, gt, hyp, autojunk=False)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag != "replace":
            continue
        if frozenset({gt[i1:i2], hyp[j1:j2]}) in VISUAL_CONFUSIONS:
            return True
    return False
def _is_oov_word(word: str) -> bool:
    """Return True if *word* contains an out-of-vocabulary character.

    A character is out-of-vocabulary when it is neither in ``_LATIN_BASIC``
    nor alphabetic (``str.isalpha``). The word is NFC-normalized first so that
    a decomposed accented letter (base letter + combining mark such as
    U+0301) is judged like its precomposed form; otherwise the non-alphabetic
    combining mark alone would flag the word.
    """
    normalized = unicodedata.normalize("NFC", word)
    return any(c not in _LATIN_BASIC and not c.isalpha() for c in normalized)
| # --------------------------------------------------------------------------- | |
| # Agrégation | |
| # --------------------------------------------------------------------------- | |
def aggregate_taxonomy(results: list[TaxonomyResult]) -> dict:
    """Aggregate per-document taxonomy results over a corpus.

    Class counts and ``total_errors`` are summed across *results*. Every class
    of ``ERROR_CLASSES`` appears in the output (0 when never seen), as does
    any extra class found in an input's ``counts``. Each class share is its
    summed count divided by the summed ``total_errors``, rounded to 4
    decimals, or 0.0 for every class when that total is 0.
    """
    combined: dict[str, int] = dict.fromkeys(ERROR_CLASSES, 0)
    grand_total = 0
    for result in results:
        grand_total += result.total_errors
        for name, n in result.counts.items():
            combined[name] = combined.get(name, 0) + n

    if grand_total > 0:
        distribution = {
            name: round(n / grand_total, 4) for name, n in combined.items()
        }
    else:
        distribution = {name: 0.0 for name in combined}

    return {
        "counts": combined,
        "total_errors": grand_total,
        "class_distribution": distribution,
    }