|
|
|
|
|
|
|
|
""" |
|
|
Prepare training data using DocTR OCR output. |
|
|
|
|
|
This script: |
|
|
1. Iterates through SROIE training/test images |
|
|
2. Runs DocTR OCR to get words and boxes |
|
|
3. Aligns DocTR output with ground truth labels using fuzzy matching |
|
|
4. Saves the aligned dataset to a pickle file for training |
|
|
|
|
|
This ensures the model learns from DocTR's actual output (with its specific errors) |
|
|
rather than from perfect ground truth which it will never see in production. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import sys |
|
|
import os |
|
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) |
|
|
|
|
|
import json |
|
|
import pickle |
|
|
from pathlib import Path |
|
|
from PIL import Image |
|
|
from tqdm import tqdm |
|
|
from difflib import SequenceMatcher |
|
|
from typing import List, Dict, Any, Tuple, Optional |
|
|
|
|
|
from doctr.io import DocumentFile |
|
|
from doctr.models import ocr_predictor |
|
|
|
|
|
|
|
|
# Root of the SROIE dataset; expected to contain train/ and test/ splits.
SROIE_DATA_PATH = "data/sroie"
# Destination pickle holding the aligned {words, bboxes, ner_tags} dataset.
OUTPUT_CACHE_PATH = "data/doctr_trained_cache.pkl"


# Maps lowercase ground-truth field names (keys produced by
# load_ground_truth) to the entity suffix used in BIO tags
# (e.g. "company" -> B-COMPANY / I-COMPANY).
GT_FIELD_MAPPING = {
    "company": "COMPANY",
    "date": "DATE",
    "address": "ADDRESS",
    "total": "TOTAL",
}
|
|
|
|
|
|
|
|
def load_doctr_predictor():
    """Build a pretrained DocTR OCR pipeline and place it on the GPU if one exists.

    Uses db_resnet50 for text detection and crnn_vgg16_bn for recognition.

    Returns:
        A ready-to-call DocTR ``ocr_predictor`` instance.
    """
    print("Loading DocTR OCR predictor...")

    ocr = ocr_predictor(
        det_arch='db_resnet50',
        reco_arch='crnn_vgg16_bn',
        pretrained=True,
    )

    # CPU inference over a full split is very slow, so prefer CUDA when present.
    if torch.cuda.is_available():
        print("🚀 Moving DocTR to GPU (CUDA)...")
        ocr.cuda()
    else:
        print("⚠️ GPU not found. Running on CPU (this will be slow).")

    print("DocTR OCR predictor ready.")
    return ocr
|
|
|
|
|
|
|
|
def parse_doctr_output(doctr_result, img_width: int, img_height: int) -> Tuple[List[str], List[List[int]]]:
    """Flatten a DocTR result into parallel word and box lists.

    DocTR word geometries are relative ((x_min, y_min), (x_max, y_max)) pairs
    in [0, 1]; they are rescaled onto the 0-1000 integer grid expected by
    LayoutLM-style models. ``img_width`` / ``img_height`` are kept for
    interface compatibility but are not needed for the rescaling.

    Returns:
        (words, boxes) where boxes[i] is [x0, y0, x1, y1] on the 0-1000
        scale for words[i]. Whitespace-only detections are dropped.
    """
    def to_grid(coord: float) -> int:
        # Clamp into [0, 1000] to guard against tiny numeric overshoot.
        return max(0, min(1000, int(coord * 1000)))

    texts: List[str] = []
    boxes: List[List[int]] = []

    for page in doctr_result.pages:
        for blk in page.blocks:
            for ln in blk.lines:
                for w in ln.words:
                    if not w.value.strip():
                        continue
                    texts.append(w.value)
                    (x0, y0), (x1, y1) = w.geometry
                    boxes.append([to_grid(x0), to_grid(y0), to_grid(x1), to_grid(y1)])

    return texts, boxes
|
|
|
|
|
|
|
|
def fuzzy_match_score(s1: str, s2: str) -> float:
    """Return a case-insensitive similarity ratio in [0, 1] between two strings."""
    a, b = s1.lower(), s2.lower()
    return SequenceMatcher(None, a, b).ratio()
|
|
|
|
|
|
|
|
def find_entity_in_words(
    entity_text: str,
    words: List[str],
    start_idx: int = 0,
    threshold: float = 0.7
) -> Optional[Tuple[int, int]]:
    """Locate a ground-truth entity inside the OCR word stream.

    Returns an inclusive (first_index, last_index) span, or None when no
    candidate scores at least ``threshold``. Window sizes a few words
    shorter/longer than the ground-truth word count are tried so that OCR
    word splits and merges can still be matched.
    """
    target_tokens = entity_text.split()
    n_target = len(target_tokens)

    # Fast path: single-token entities are matched word-by-word first.
    if n_target == 1:
        best_idx, best_score = -1, 0.0
        for idx in range(start_idx, len(words)):
            s = fuzzy_match_score(entity_text, words[idx])
            if s >= threshold and s > best_score:
                best_idx, best_score = idx, s
        if best_idx >= 0:
            return (best_idx, best_idx)
        # No direct hit: fall through to the windowed search, which can
        # recover an entity that OCR split into several words.

    top_score = 0.0
    top_span: Optional[Tuple[int, int]] = None

    # Search windows from a few tokens shorter to a few tokens longer than
    # the ground-truth length, smallest windows first.
    smallest = max(1, n_target - 3)
    largest = min(len(words) - start_idx, n_target + 5)

    target_joined = " ".join(target_tokens)

    for size in range(smallest, largest + 1):
        for start in range(start_idx, len(words) - size + 1):
            candidate = " ".join(words[start : start + size])
            s = fuzzy_match_score(target_joined, candidate)

            # Near-perfect window: accept immediately.
            if s > 0.95:
                return (start, start + size - 1)

            if s >= threshold and s > top_score:
                top_score = s
                top_span = (start, start + size - 1)

    return top_span
|
|
|
|
|
|
|
|
def load_ground_truth(json_path: Path) -> Dict[str, str]:
    """Reconstruct entity strings from a SROIE tagged JSON file.

    The tagged JSON stores parallel "words" and BIO "labels" lists; this
    walks them, joining each B-/I- run back into a single entity value
    keyed by the lowercased entity name (e.g. "company", "total").
    """
    with open(json_path, encoding="utf-8") as f:
        payload = json.load(f)

    words = payload.get("words", [])
    labels = payload.get("labels", [])

    entities: Dict[str, str] = {}
    active: Optional[str] = None   # entity type currently being collected
    pieces: List[str] = []         # words collected for it so far

    def flush() -> None:
        # Commit the in-progress entity (if any) to the result dict.
        if active and pieces:
            entities[active.lower()] = " ".join(pieces)

    for token, tag in zip(words, labels):
        if tag.startswith("B-"):
            # New entity begins: close out the previous one.
            flush()
            active = tag[2:]
            pieces = [token]
        elif tag.startswith("I-") and active:
            if tag[2:] == active:
                pieces.append(token)
            else:
                # Inconsistent continuation tag: close the current entity
                # and drop the stray token.
                flush()
                active = None
                pieces = []
        else:
            # "O" (or a dangling I-) terminates any open entity.
            flush()
            active = None
            pieces = []

    # Flush a run that extends to the end of the document.
    flush()

    return entities
|
|
|
|
|
|
|
|
def align_labels(
    doctr_words: List[str],
    ground_truth: Dict[str, str]
) -> List[str]:
    """Produce BIO tags for the OCR words by fuzzy-matching each ground-truth
    entity into the word stream.

    Words that match nothing stay "O"; a span already claimed by an earlier
    field is never relabelled.
    """
    tags = ["O"] * len(doctr_words)
    claimed = set()

    for field, entity in GT_FIELD_MAPPING.items():
        if field not in ground_truth:
            continue

        value = ground_truth[field]
        if not value or not value.strip():
            continue

        # Per-field thresholds: addresses are long and noisy, so match
        # leniently; dates/totals are short, so demand a closer match.
        if entity == "ADDRESS":
            match_threshold = 0.45
        elif entity in ["DATE", "TOTAL"]:
            match_threshold = 0.7
        else:
            match_threshold = 0.6

        span = find_entity_in_words(value, doctr_words, start_idx=0, threshold=match_threshold)
        if not span:
            continue

        first, last = span

        # Skip if any word in the span already belongs to another entity.
        if claimed.intersection(range(first, last + 1)):
            continue

        tags[first] = f"B-{entity}"
        for j in range(first + 1, last + 1):
            tags[j] = f"I-{entity}"

        claimed.update(range(first, last + 1))

    return tags
|
|
|
|
|
|
|
|
def process_split(
    split_path: Path,
    predictor,
    split_name: str
) -> List[Dict[str, Any]]:
    """Process all images in a split directory.

    Runs DocTR OCR on every annotated image under ``split_path`` and aligns
    the OCR words with the ground-truth entity labels.

    Args:
        split_path: Root directory of the split (e.g. ``data/sroie/train``).
        predictor: Loaded DocTR OCR predictor, called on a DocumentFile.
        split_name: Display name used for the progress bar and log lines.

    Returns:
        List of example dicts with keys ``image_path``, ``words``,
        ``bboxes``, ``ner_tags`` and ``ground_truth``. Returns ``[]`` when
        the expected directory layout is missing.
    """

    # SROIE copies in the wild use either "images" or "img" for the scans.
    if (split_path / "images").exists():
        img_dir = split_path / "images"
    elif (split_path / "img").exists():
        img_dir = split_path / "img"
    else:
        print(f" ⚠️ No image directory found in {split_path}")
        return []

    # Annotations likewise live in "tagged" (BIO JSON) or "box".
    if (split_path / "tagged").exists():
        ann_dir = split_path / "tagged"
    elif (split_path / "box").exists():
        ann_dir = split_path / "box"
    else:
        print(f" ⚠️ No annotation directory found in {split_path}")
        return []

    examples = []
    image_files = sorted([f for f in img_dir.iterdir() if f.suffix.lower() in [".jpg", ".png"]])

    print(f" Processing {len(image_files)} images in {split_name}...")

    for img_file in tqdm(image_files, desc=f" {split_name}"):
        try:
            # Skip images with no matching ground-truth annotation file.
            json_path = ann_dir / f"{img_file.stem}.json"
            if not json_path.exists():
                continue

            # Read the pixel size for parse_doctr_output's interface; DocTR
            # geometries are already relative, so it is informational only.
            with Image.open(img_file) as img:
                width, height = img.size

            # Run the full OCR pipeline on this single image.
            doc = DocumentFile.from_images(str(img_file))
            doctr_result = predictor(doc)

            words, boxes = parse_doctr_output(doctr_result, width, height)

            # Nothing detected: there are no tokens to label.
            if not words:
                continue

            # Fuzzy-align the ground-truth entity strings onto the OCR words.
            ground_truth = load_ground_truth(json_path)
            aligned_labels = align_labels(words, ground_truth)

            examples.append({
                "image_path": str(img_file),
                "words": words,
                "bboxes": boxes,
                "ner_tags": aligned_labels,
                "ground_truth": ground_truth
            })

        except Exception as e:
            # Best-effort: a corrupt image or OCR failure should not abort
            # the whole split, so log the error and move on.
            print(f"\n ❌ Error processing {img_file.name}: {e}")
            continue

    return examples
|
|
|
|
|
|
|
|
def main():
    """Entry point: OCR every SROIE split, align labels, and pickle the result."""
    print("=" * 60)
    print("📦 DocTR Training Data Preparation")
    print("=" * 60)

    sroie_path = Path(SROIE_DATA_PATH)
    if not sroie_path.exists():
        print(f"❌ SROIE path not found: {sroie_path}")
        return

    predictor = load_doctr_predictor()

    dataset = {"train": [], "test": []}

    for split in ("train", "test"):
        split_path = sroie_path / split
        if not split_path.exists():
            print(f" ⚠️ Split not found: {split}")
            continue

        print(f"\n📂 Processing {split} split...")
        split_examples = process_split(split_path, predictor, split)
        dataset[split] = split_examples

        # Quick sanity stats: each B- tag marks one aligned entity.
        n_entities = sum(
            sum(1 for tag in ex["ner_tags"] if tag.startswith("B-"))
            for ex in split_examples
        )
        print(f" ✅ {len(split_examples)} images processed")
        print(f" 📊 {n_entities} entities aligned")

    print(f"\n💾 Saving cache to {OUTPUT_CACHE_PATH}...")
    output_path = Path(OUTPUT_CACHE_PATH)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, "wb") as f:
        pickle.dump(dataset, f)

    print("✅ Cache saved!")
    print(f" - Train examples: {len(dataset['train'])}")
    print(f" - Test examples: {len(dataset['test'])}")
    print("=" * 60)


if __name__ == "__main__":
    main()
|
|
|