# scripts/prepare_doctr_data.py
"""
Prepare training data using DocTR OCR output.

This script:
1. Iterates through SROIE training/test images
2. Runs DocTR OCR to get words and boxes
3. Aligns DocTR output with ground truth labels using fuzzy matching
4. Saves the aligned dataset to a pickle file for training

This ensures the model learns from DocTR's actual output (with its specific
errors) rather than from perfect ground truth, which it will never see in
production.

Heavy third-party dependencies (torch, DocTR, Pillow, tqdm) are imported
inside the functions that need them, so the pure text-alignment helpers can be
imported (e.g. by unit tests) without those packages or a GPU.
"""
import json
import os
import pickle
import sys
from difflib import SequenceMatcher
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Make the project root importable when this file is run as a script.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# --- CONFIGURATION ---
SROIE_DATA_PATH = "data/sroie"
OUTPUT_CACHE_PATH = "data/doctr_trained_cache.pkl"

# Image file extensions (lower-case) picked up from each split's image folder.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Ground truth field names and their corresponding BIO labels.
# Iteration order matters: earlier fields claim DocTR words first in align_labels.
GT_FIELD_MAPPING = {
    "company": "COMPANY",
    "date": "DATE",
    "address": "ADDRESS",
    "total": "TOTAL",
}

# Per-label fuzzy-match thresholds: lenient for messy (often multi-line)
# addresses, strict for short precision fields where a wrong match is costly.
FIELD_MATCH_THRESHOLDS = {"ADDRESS": 0.45, "DATE": 0.7, "TOTAL": 0.7}
DEFAULT_MATCH_THRESHOLD = 0.6


def load_doctr_predictor():
    """Build the DocTR OCR predictor and move it to the GPU when one is available.

    Uses the ``db_resnet50`` text detector and the ``crnn_vgg16_bn`` recognizer
    with pretrained weights.

    Returns:
        The DocTR predictor; it is called on a ``DocumentFile`` in process_split.
    """
    # Lazy imports: see the module docstring.
    import torch
    from doctr.models import ocr_predictor

    print("Loading DocTR OCR predictor...")
    predictor = ocr_predictor(
        det_arch='db_resnet50',
        reco_arch='crnn_vgg16_bn',
        pretrained=True,
    )

    if torch.cuda.is_available():
        print("🚀 Moving DocTR to GPU (CUDA)...")
        predictor.cuda()
    else:
        print("⚠️ GPU not found. Running on CPU (this will be slow).")

    print("DocTR OCR predictor ready.")
    return predictor


def _to_1000(value: float) -> int:
    """Scale a relative (0-1) coordinate to the 0-1000 grid, clamped to [0, 1000]."""
    return max(0, min(1000, int(value * 1000)))


def _geometry_to_box(geometry) -> List[int]:
    """Convert one DocTR word geometry to an axis-aligned [x0, y0, x1, y1] box (0-1000).

    Accepts the straight-page form ``((x_min, y_min), (x_max, y_max))`` as well
    as a polygon given as a sequence of (x, y) points (DocTR's rotated-page
    output). The box is the min/max over all points, which for the two-point
    form equals the original corners.
    """
    xs = [float(point[0]) for point in geometry]
    ys = [float(point[1]) for point in geometry]
    return [_to_1000(min(xs)), _to_1000(min(ys)), _to_1000(max(xs)), _to_1000(max(ys))]


def parse_doctr_output(doctr_result, img_width: int, img_height: int) -> Tuple[List[str], List[List[int]]]:
    """
    Parse DocTR output into words and normalized boxes (0-1000 scale).

    DocTR geometries are already relative to the page size, so ``img_width`` and
    ``img_height`` are not used for the conversion; they are kept so existing
    callers keep working.

    Args:
        doctr_result: DocTR document result (pages -> blocks -> lines -> words).
        img_width: Image width in pixels (unused).
        img_height: Image height in pixels (unused).

    Returns:
        words: List of word strings (whitespace-only words are dropped)
        normalized_boxes: List of [x0, y0, x1, y1] in 0-1000 scale, clamped
    """
    words: List[str] = []
    normalized_boxes: List[List[int]] = []

    for page in doctr_result.pages:
        for block in page.blocks:
            for line in block.lines:
                for word in line.words:
                    if not word.value.strip():
                        continue
                    words.append(word.value)
                    normalized_boxes.append(_geometry_to_box(word.geometry))

    return words, normalized_boxes


def fuzzy_match_score(s1: str, s2: str) -> float:
    """Return the case-insensitive SequenceMatcher similarity ratio (0.0-1.0) of two strings."""
    return SequenceMatcher(None, s1.lower(), s2.lower()).ratio()


def find_entity_in_words(
    entity_text: str,
    words: List[str],
    start_idx: int = 0,
    threshold: float = 0.7,
) -> Optional[Tuple[int, int]]:
    """
    Find a ground truth entity in the DocTR words using fuzzy matching.

    Single-word entities are first matched against individual words. Otherwise,
    or if no single word matches, contiguous windows of ``n-3`` .. ``n+5`` words
    are compared with the entity text (``n`` = number of entity words). This
    tolerates OCR splitting one word into several, or merging several into one.

    Args:
        entity_text: Ground truth entity value.
        words: DocTR words to search.
        start_idx: First word index considered.
        threshold: Minimum similarity score for a match to be accepted.

    Returns:
        Inclusive ``(start, end)`` word indices of the best match, or None if
        no candidate reaches ``threshold``.
    """
    entity_words = entity_text.split()
    n_target = len(entity_words)

    # 1. Single-word entity: best individual word at or above the threshold.
    if n_target == 1:
        best_score = 0.0
        best_idx = -1
        for i in range(start_idx, len(words)):
            score = fuzzy_match_score(entity_text, words[i])
            if score > best_score and score >= threshold:
                best_score = score
                best_idx = i
        if best_idx >= 0:
            return (best_idx, best_idx)
        # Otherwise fall through: OCR may have split the word into pieces.

    # 2. Flexible window search: windows of N-3 .. N+5 words catch OCR merges
    #    (fewer words than the ground truth) and splits (more words).
    best_match_score = 0.0
    best_match_indices = None

    min_len = max(1, n_target - 3)
    max_len = min(len(words) - start_idx, n_target + 5)
    combined_entity_text = " ".join(entity_words)

    for window_size in range(min_len, max_len + 1):
        for i in range(start_idx, len(words) - window_size + 1):
            window_text = " ".join(words[i:i + window_size])
            score = fuzzy_match_score(combined_entity_text, window_text)

            # The caller's threshold applies to every acceptance path,
            # including the near-perfect early exit below.
            if score < threshold:
                continue

            # Near-perfect match: accept immediately (smallest window wins).
            if score > 0.95:
                return (i, i + window_size - 1)

            if score > best_match_score:
                best_match_score = score
                best_match_indices = (i, i + window_size - 1)

    return best_match_indices


def _store_span(entities: Dict[str, str], entity_type: Optional[str], tokens: List[str]) -> None:
    """Record a finished BIO span under its lower-cased type (a later span of the same type overwrites)."""
    if entity_type and tokens:
        entities[entity_type.lower()] = " ".join(tokens)


def load_ground_truth(json_path: Path) -> Dict[str, str]:
    """
    Load ground truth entities from the tagged JSON.

    The SROIE tagged JSON has: {"words": [...], "bbox": [...], "labels": [...]}
    The entity values are reconstructed from words + BIO labels. Decoding is
    lenient, in the style of conlleval: an ``I-X`` tag that cannot continue the
    open span (no span open, or an open span of another type) starts a new X
    span instead of being dropped. Labels other than ``B-``/``I-`` close the
    open span.

    Args:
        json_path: Path to the tagged JSON file.

    Returns:
        Mapping of lower-cased entity type (e.g. "company") to its text. If a
        type occurs in several spans, the last span wins.
    """
    with open(json_path, encoding="utf-8") as f:
        data = json.load(f)

    words = data.get("words", [])
    labels = data.get("labels", [])

    entities: Dict[str, str] = {}
    current_entity: Optional[str] = None
    current_text: List[str] = []

    for word, label in zip(words, labels):
        if label.startswith("I-") and label[2:] == current_entity:
            # Continuation of the open span.
            current_text.append(word)
        elif label.startswith(("B-", "I-")):
            # New span (explicit B-, or an I- that cannot continue the open span).
            _store_span(entities, current_entity, current_text)
            current_entity = label[2:]
            current_text = [word]
        else:
            # "O" (or any other tag) closes the open span.
            _store_span(entities, current_entity, current_text)
            current_entity = None
            current_text = []

    # Don't forget the last entity.
    _store_span(entities, current_entity, current_text)
    return entities


def align_labels(
    doctr_words: List[str],
    ground_truth: Dict[str, str],
) -> List[str]:
    """Assign BIO labels to DocTR words by fuzzy-matching each ground truth field.

    Fields are processed in GT_FIELD_MAPPING order, each with its threshold from
    FIELD_MATCH_THRESHOLDS (default DEFAULT_MATCH_THRESHOLD). A word can belong
    to only one entity: a match overlapping words already claimed by an earlier
    field is skipped.

    Args:
        doctr_words: Words produced by DocTR.
        ground_truth: Output of load_ground_truth (lower-cased field -> text).

    Returns:
        One label per DocTR word ("O", "B-<LABEL>" or "I-<LABEL>").
    """
    labels = ["O"] * len(doctr_words)
    used_indices = set()

    for gt_field, bio_label in GT_FIELD_MAPPING.items():
        entity_text = ground_truth.get(gt_field)
        if not entity_text or not entity_text.strip():
            continue

        threshold = FIELD_MATCH_THRESHOLDS.get(bio_label, DEFAULT_MATCH_THRESHOLD)
        match = find_entity_in_words(entity_text, doctr_words, start_idx=0, threshold=threshold)
        if match is None:
            continue

        start_idx, end_idx = match
        span = range(start_idx, end_idx + 1)
        # First field to claim a word wins; conflicting later matches are dropped.
        if any(i in used_indices for i in span):
            continue

        labels[start_idx] = f"B-{bio_label}"
        for i in span[1:]:
            labels[i] = f"I-{bio_label}"
        used_indices.update(span)

    return labels


def _first_existing_dir(parent: Path, candidates: Tuple[str, ...]) -> Optional[Path]:
    """Return the first ``parent / name`` (in candidate order) that exists, or None."""
    for name in candidates:
        path = parent / name
        if path.exists():
            return path
    return None


def process_split(
    split_path: Path,
    predictor,
    split_name: str,
) -> List[Dict[str, Any]]:
    """Run DocTR OCR and label alignment on every image in one split directory.

    Images are read from ``images/`` (or ``img/``). Annotations are
    ``<image stem>.json`` files in ``tagged/`` (or ``box/``).
    NOTE(review): if the ``box/`` fallback holds .txt files rather than JSON
    (verify for your SROIE copy), every image is skipped; the skip count is
    printed at the end.

    Processing is best-effort: an image that fails is reported and skipped.

    Args:
        split_path: Directory of the split (e.g. data/sroie/train).
        predictor: DocTR predictor from load_doctr_predictor.
        split_name: Name shown in progress output.

    Returns:
        One dict per successfully processed image, with keys "image_path",
        "words", "bboxes", "ner_tags" and "ground_truth".
    """
    # Lazy imports: see the module docstring.
    from doctr.io import DocumentFile
    from PIL import Image
    from tqdm import tqdm

    img_dir = _first_existing_dir(split_path, ("images", "img"))
    if img_dir is None:
        print(f" ⚠️ No image directory found in {split_path}")
        return []

    ann_dir = _first_existing_dir(split_path, ("tagged", "box"))
    if ann_dir is None:
        print(f" ⚠️ No annotation directory found in {split_path}")
        return []

    image_files = sorted(f for f in img_dir.iterdir() if f.suffix.lower() in IMAGE_EXTENSIONS)
    print(f" Processing {len(image_files)} images in {split_name}...")

    examples: List[Dict[str, Any]] = []
    missing_annotations = 0

    for img_file in tqdm(image_files, desc=f" {split_name}"):
        json_path = ann_dir / f"{img_file.stem}.json"
        if not json_path.exists():
            missing_annotations += 1
            continue

        try:
            # Image dimensions (currently unused by parse_doctr_output).
            with Image.open(img_file) as img:
                width, height = img.size

            doc = DocumentFile.from_images(str(img_file))
            doctr_result = predictor(doc)
            words, boxes = parse_doctr_output(doctr_result, width, height)
            if not words:
                continue

            ground_truth = load_ground_truth(json_path)
            aligned_labels = align_labels(words, ground_truth)
        except Exception as e:  # best-effort: one bad image must not abort the split
            # tqdm.write keeps the progress bar intact.
            tqdm.write(f" ❌ Error processing {img_file.name}: {e}")
            continue

        examples.append({
            "image_path": str(img_file),
            "words": words,
            "bboxes": boxes,
            "ner_tags": aligned_labels,
            "ground_truth": ground_truth,  # Keep for debugging
        })

    if missing_annotations:
        print(f" ⚠️ Skipped {missing_annotations} images with no matching .json annotation in {ann_dir}")

    return examples


def main():
    """Process the train/test SROIE splits with DocTR and pickle the aligned dataset."""
    print("=" * 60)
    print("📦 DocTR Training Data Preparation")
    print("=" * 60)

    sroie_path = Path(SROIE_DATA_PATH)
    if not sroie_path.exists():
        print(f"❌ SROIE path not found: {sroie_path}")
        return

    predictor = load_doctr_predictor()

    dataset: Dict[str, List[Dict[str, Any]]] = {"train": [], "test": []}

    for split in ("train", "test"):
        split_path = sroie_path / split
        if not split_path.exists():
            print(f" ⚠️ Split not found: {split}")
            continue

        print(f"\n📂 Processing {split} split...")
        examples = process_split(split_path, predictor, split)
        dataset[split] = examples

        # Stats: one "B-" tag per aligned entity.
        total_entities = sum(
            label.startswith("B-") for ex in examples for label in ex["ner_tags"]
        )
        print(f" ✅ {len(examples)} images processed")
        print(f" 📊 {total_entities} entities aligned")

    print(f"\n💾 Saving cache to {OUTPUT_CACHE_PATH}...")
    output_path = Path(OUTPUT_CACHE_PATH)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "wb") as f:
        pickle.dump(dataset, f)

    print("✅ Cache saved!")
    print(f" - Train examples: {len(dataset['train'])}")
    print(f" - Test examples: {len(dataset['test'])}")
    print("=" * 60)


if __name__ == "__main__":
    main()