# financial-intelligence-ai / scripts / convert_datasets.py
# (Hugging Face upload metadata: uploaded by Vaibuzz via huggingface_hub, commit 10ff0db)
"""
Deterministic Dataset Converter β€” No GPT Required.
Converts structured datasets (CORD, SROIE) directly into our training
format using programmatic mapping. No API calls, no costs, runs offline.
Supports:
1. CORD v1 (HuggingFace) β€” rich receipt annotations with line items
2. SROIE v2 (Kaggle) β€” receipts with company, date, address, total
Usage:
# Convert CORD (auto-downloads from HuggingFace)
python scripts/convert_datasets.py --source cord --max-docs 200 \
--output data/training/real_cord.jsonl
# Convert SROIE (must download first)
python scripts/convert_datasets.py --source sroie --sroie-path data/raw/sroie \
--max-docs 200 --output data/training/real_sroie.jsonl
# Merge all sources
python scripts/convert_datasets.py --source merge --output data/training/merged_raw.jsonl
"""
import os
import sys
import json
import glob
import random
import argparse
import hashlib
from typing import Optional, List
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
def convert_cord(max_docs: Optional[int] = None) -> List[dict]:
    """
    Convert CORD v1 structured annotations into our training format.

    CORD gt_parse has:
      - store_info: {name, branch, address, tel}
      - menu: [{nm, cnt, price, sub: [{nm, price}]}]
      - total: {subtotal_price, tax_price, total_price, ...}
      - payment: {cash_price, change_price, credit_card_price}

    We map this directly to our schema, synthesizing a pseudo-OCR
    ``raw_text`` from the annotations plus arithmetic-consistency flags.

    Args:
        max_docs: Optional cap on how many receipts to convert.

    Returns:
        List of training-format documents. Empty when the ``datasets``
        package is not installed or the download fails.
    """
    try:
        from datasets import load_dataset
    except ImportError:
        print(" [ERROR] Install: pip install datasets")
        return []
    print(" Downloading CORD v1 from HuggingFace...")
    try:
        dataset = load_dataset("naver-clova-ix/cord-v1", split="train")
    except Exception as e:
        print(f" [ERROR] {e}")
        return []
    print(f" Loaded {len(dataset)} CORD receipts")
    documents = []
    skipped = 0
    limit = max_docs if max_docs else len(dataset)
    for i, sample in enumerate(dataset):
        if i >= limit:
            break
        try:
            gt_str = sample.get("ground_truth", "")
            gt_data = json.loads(gt_str) if isinstance(gt_str, str) else gt_str
            gt_parse = gt_data.get("gt_parse", gt_data)
            # --- Build raw text (simulating OCR output) ---
            text_parts = []
            store = gt_parse.get("store_info", {})
            # Store fields may be str, dict, or list in CORD; normalize all
            # three via the shared extractor. (Previously `tel` was missing
            # the list case and a list-typed tel was stringified as the raw
            # annotation structure.)
            store_name = _extract_text(store.get("name"))
            store_addr = _extract_text(store.get("address"))
            store_tel = _extract_text(store.get("tel"))
            if store_name:
                text_parts.append(store_name)
            if store_addr:
                text_parts.append(store_addr)
            if store_tel:
                text_parts.append(f"Tel: {store_tel}")
            text_parts.append("")
            text_parts.append("RECEIPT")
            text_parts.append("-" * 40)
            # Menu items
            menu = gt_parse.get("menu", [])
            line_items = []
            for item in (menu or []):
                nm = _extract_text(item.get("nm", ""))
                cnt = _extract_text(item.get("cnt", "1"))
                price = _extract_text(item.get("price", "0"))
                try:
                    qty = float(cnt) if cnt else 1
                except (ValueError, TypeError):
                    qty = 1
                try:
                    amount = float(price.replace(",", "").replace(" ", "")) if price else 0
                except (ValueError, TypeError):
                    amount = 0
                # Derive unit price; guard against zero/invalid quantity.
                unit_price = round(amount / qty, 2) if qty > 0 else amount
                text_parts.append(f" {nm:<30} x{int(qty):<5} {amount:>10.2f}")
                line_items.append({
                    "description": nm or "Unknown Item",
                    "quantity": qty,
                    "unit_price": unit_price,
                    "amount": amount,
                })
                # Sub-items appear in the text only, not as line items.
                for sub in (item.get("sub", []) or []):
                    sub_nm = _extract_text(sub.get("nm", ""))
                    sub_price = _extract_text(sub.get("price", ""))
                    if sub_nm:
                        text_parts.append(f" + {sub_nm:<28} {sub_price:>10}")
            text_parts.append("-" * 40)
            # Totals
            total_info = gt_parse.get("total", {})
            subtotal = _parse_amount(total_info.get("subtotal_price"))
            tax = _parse_amount(total_info.get("tax_price"))
            total = _parse_amount(total_info.get("total_price"))
            discount = _parse_amount(total_info.get("discount_price"))
            if subtotal:
                text_parts.append(f" Subtotal: {subtotal:.2f}")
            if discount:
                text_parts.append(f" Discount: -{discount:.2f}")
            if tax:
                text_parts.append(f" Tax: {tax:.2f}")
            if total:
                text_parts.append(f" TOTAL: {total:.2f}")
            # Payment
            pay_info = gt_parse.get("payment", {})
            cash = _parse_amount(pay_info.get("cash_price"))
            change = _parse_amount(pay_info.get("change_price"))
            card = _parse_amount(pay_info.get("credit_card_price"))
            payment_method = "unknown"
            if cash and cash > 0:
                text_parts.append(f" Cash: {cash:.2f}")
                payment_method = "cash"
            if change and change > 0:
                text_parts.append(f" Change: {change:.2f}")
            if card and card > 0:
                text_parts.append(f" Card: {card:.2f}")
                # Card wins if both cash and card amounts are present.
                payment_method = "credit_card"
            raw_text = "\n".join(text_parts)
            if len(raw_text.strip()) < 15:
                continue  # too little text to be a usable sample
            # --- Build ground truth in our schema ---
            # Calculate what total SHOULD be for anomaly detection
            computed_total = sum(item["amount"] for item in line_items)
            flags = []
            # Check if subtotal β‰  sum of items
            if subtotal and line_items and abs(computed_total - subtotal) > 1.0:
                flags.append({
                    "category": "arithmetic_error",
                    "field": "type_specific.subtotal",
                    "severity": "medium",
                    "description": f"Sum of items ({computed_total:.2f}) β‰  subtotal ({subtotal:.2f})"
                })
            # Check subtotal + tax = total
            if subtotal and tax is not None and total:
                expected_total = subtotal + (tax or 0) - (discount or 0)
                if abs(expected_total - total) > 1.0:
                    flags.append({
                        "category": "arithmetic_error",
                        "field": "common.total_amount",
                        "severity": "high",
                        "description": f"Total ({total:.2f}) β‰  subtotal ({subtotal:.2f}) + tax ({tax:.2f}) = {expected_total:.2f}"
                    })
            # Generate a deterministic receipt number from the text hash
            receipt_num = f"RCP-{hashlib.md5(raw_text.encode()).hexdigest()[:8].upper()}"
            ground_truth = {
                "common": {
                    "document_type": "receipt",
                    "date": None,  # CORD doesn't have dates
                    "issuer": {
                        "name": store_name or None,
                        "address": store_addr or None,
                    },
                    "recipient": None,
                    "total_amount": total if total else computed_total,
                    "currency": "USD",
                },
                "line_items": line_items,
                "type_specific": {
                    "receipt_number": receipt_num,
                    "payment_method": payment_method,
                    "store_location": store_addr or None,
                    "cashier": None,
                    "subtotal": subtotal,
                    "tax_amount": tax,
                },
                "flags": flags,
                "confidence_score": 0.88,
            }
            documents.append({
                "doc_type": "receipt",
                "raw_text": raw_text,
                "ground_truth": ground_truth,
                "source": "cord",
            })
        except Exception:
            # Skip malformed samples but count them so failures are visible
            # (previously they were swallowed silently).
            skipped += 1
            continue
    if skipped:
        print(f" [WARN] Skipped {skipped} malformed CORD samples")
    print(f" βœ… Converted {len(documents)} CORD documents")
    return documents
def convert_sroie(sroie_path: str, max_docs: Optional[int] = None) -> List[dict]:
    """
    Convert SROIE dataset annotations into our training format.

    SROIE labels are: {company, date, address, total}
    We map these to our schema with receipt type.

    Args:
        sroie_path: Root of the downloaded SROIE dataset. Expects
            <root>/<split>/box and <root>/<split>/entities; falls back
            to a flat <root>/<split> layout when box/ is absent.
        max_docs: Optional cap. Conversion stops as soon as the cap is
            reached (previously every file was converted and the list
            trimmed afterwards, wasting work).

    Returns:
        List of training-format documents (possibly empty).
    """
    documents = []
    for split in ["train", "test"]:
        if max_docs and len(documents) >= max_docs:
            break
        box_dir = os.path.join(sroie_path, split, "box")
        entities_dir = os.path.join(sroie_path, split, "entities")
        if not os.path.exists(box_dir):
            # Flat layout: OCR and entity files share one directory.
            box_dir = os.path.join(sroie_path, split)
            entities_dir = os.path.join(sroie_path, split)
        if not os.path.exists(box_dir):
            print(f" [SKIP] {box_dir} not found")
            continue
        txt_files = sorted(glob.glob(os.path.join(box_dir, "*.txt")))
        print(f" Found {len(txt_files)} OCR files in {split}/")
        for txt_file in txt_files:
            if max_docs and len(documents) >= max_docs:
                break
            basename = os.path.splitext(os.path.basename(txt_file))[0]
            try:
                # Read OCR text. Each SROIE box line is
                # "x1,y1,x2,y2,x3,y3,x4,y4,text" -- 8 coordinates and then
                # the transcription, which may itself contain commas.
                with open(txt_file, "r", encoding="utf-8", errors="ignore") as f:
                    lines = f.readlines()
                text_parts = []
                for line in lines:
                    line = line.strip()
                    if not line:
                        continue
                    parts = line.split(",")
                    if len(parts) > 8:
                        text = ",".join(parts[8:]).strip()
                        if text:
                            text_parts.append(text)
                    elif len(parts) == 1:
                        # No coordinates at all -- treat as plain text.
                        text_parts.append(line)
                raw_text = "\n".join(text_parts)
                if len(raw_text.strip()) < 10:
                    continue  # too little OCR text to be useful
                # Read entity labels (best-effort: the file may be missing
                # or not valid JSON).
                entity_file = os.path.join(entities_dir, basename + ".txt")
                company, date_str, address, total_str = "", None, "", None
                if os.path.exists(entity_file):
                    try:
                        with open(entity_file, "r", encoding="utf-8") as ef:
                            entity_data = json.load(ef)
                        company = entity_data.get("company", "")
                        date_str = entity_data.get("date", None)
                        address = entity_data.get("address", "")
                        total_str = entity_data.get("total", None)
                    except Exception:
                        pass
                # Parse total
                total = None
                if total_str:
                    try:
                        total = float(str(total_str).replace(",", "").replace("$", "").strip())
                    except (ValueError, TypeError):
                        pass
                # Normalize date
                normalized_date = _normalize_date(date_str)
                receipt_num = f"SROIE-{basename}"
                ground_truth = {
                    "common": {
                        "document_type": "receipt",
                        "date": normalized_date,
                        "issuer": {
                            "name": company or None,
                            "address": address or None,
                        },
                        "recipient": None,
                        "total_amount": total,
                        "currency": "USD",
                    },
                    "line_items": [],
                    "type_specific": {
                        "receipt_number": receipt_num,
                        "payment_method": None,
                        "store_location": address or None,
                        "cashier": None,
                    },
                    "flags": [],
                    "confidence_score": 0.82,
                }
                documents.append({
                    "doc_type": "receipt",
                    "raw_text": raw_text,
                    "ground_truth": ground_truth,
                    "source": "sroie",
                })
            except Exception:
                # Skip unreadable files; SROIE OCR dumps are noisy.
                continue
    print(f" βœ… Converted {len(documents)} SROIE documents")
    return documents
def merge_all_datasets(output_path: str):
    """Merge all data sources into one master JSONL.

    Reads the known training JSONL files (skipping any that are missing),
    tags each document with its source, shuffles with a fixed seed for
    reproducibility, writes the combined set, and prints summary stats.

    Args:
        output_path: Destination JSONL file. Parent directories are
            created if the path has a directory component.
    """
    sources = {
        "data/training/synthetic_raw.jsonl": "synthetic",
        "data/training/with_anomalies.jsonl": "synthetic_anomalies",
        "data/training/real_cord.jsonl": "cord",
        "data/training/real_sroie.jsonl": "sroie",
    }
    all_docs = []
    # Prefer with_anomalies over synthetic_raw (same docs but enriched).
    # If both exist, skip synthetic_raw.
    has_anomalies = os.path.exists("data/training/with_anomalies.jsonl")
    for filepath, source_name in sources.items():
        if source_name == "synthetic" and has_anomalies:
            print(f" [SKIP] {filepath} (using with_anomalies.jsonl instead)")
            continue
        if not os.path.exists(filepath):
            print(f" [SKIP] {filepath} not found")
            continue
        count = 0
        with open(filepath, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    doc = json.loads(line)
                except json.JSONDecodeError:
                    # One corrupt line should not abort the whole merge.
                    print(f" [WARN] Skipping bad JSON line in {filepath}")
                    continue
                if "source" not in doc:
                    doc["source"] = source_name
                all_docs.append(doc)
                count += 1
        print(f" βœ… Loaded {count} docs from {source_name}")
    if not all_docs:
        print(" ❌ No documents to merge!")
        return
    # Shuffle for training (fixed seed => deterministic order)
    random.seed(42)
    random.shuffle(all_docs)
    # dirname() is "" for a bare filename; os.makedirs("") raises,
    # so only create directories when there is a directory component.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        for doc in all_docs:
            f.write(json.dumps(doc, ensure_ascii=False) + "\n")
    # Stats
    by_source = {}
    by_type = {}
    for doc in all_docs:
        s = doc.get("source", "?")
        t = doc.get("doc_type", "?")
        by_source[s] = by_source.get(s, 0) + 1
        by_type[t] = by_type.get(t, 0) + 1
    print(f"\n {'═'*45}")
    print(f" πŸ“Š MERGED DATASET")
    print(f" {'═'*45}")
    print(f" By Source:")
    for src, cnt in sorted(by_source.items()):
        print(f" {src:<25}: {cnt:>4}")
    print(f" By Document Type:")
    for dtype, cnt in sorted(by_type.items()):
        print(f" {dtype:<25}: {cnt:>4}")
    print(f" {'─'*45}")
    print(f" Total: {len(all_docs)} documents")
    print(f" Saved: {output_path}")
# === Helper Functions ===
def _extract_text(val) -> str:
"""Extract text from CORD's nested annotation format."""
if val is None:
return ""
if isinstance(val, str):
return val.strip()
if isinstance(val, dict):
return str(val.get("text", val.get("value", ""))).strip()
if isinstance(val, list):
texts = []
for v in val:
if isinstance(v, dict):
texts.append(str(v.get("text", v.get("value", ""))).strip())
else:
texts.append(str(v).strip())
return " ".join(texts)
return str(val).strip()
def _parse_amount(val) -> Optional[float]:
    """Parse a monetary amount from various formats.

    Accepts anything `_extract_text` can normalize, strips thousands
    separators, spaces, and dollar signs, and returns a float --
    or None when the value is missing or not numeric.
    """
    if val is None:
        return None
    text = _extract_text(val)
    if not text:
        return None
    # Drop ',', ' ', and '$' in one C-level pass.
    cleaned = text.translate(str.maketrans("", "", ", $"))
    try:
        return float(cleaned)
    except (ValueError, TypeError):
        return None
def _normalize_date(date_str) -> Optional[str]:
"""Try to normalize date string to YYYY-MM-DD format."""
if not date_str:
return None
import re
from datetime import datetime
date_str = str(date_str).strip()
formats = [
"%Y-%m-%d", "%d/%m/%Y", "%m/%d/%Y", "%d-%m-%Y",
"%d %b %Y", "%d %B %Y", "%b %d, %Y", "%B %d, %Y",
"%Y/%m/%d", "%d.%m.%Y",
]
for fmt in formats:
try:
parsed = datetime.strptime(date_str, fmt)
return parsed.strftime("%Y-%m-%d")
except ValueError:
continue
return date_str # Return as-is if can't parse
def _write_jsonl(path: str, docs: List[dict]) -> None:
    """Write docs to path as JSONL, creating parent dirs when present.

    Guards against os.makedirs("") for bare filenames (dirname is ""
    when the path has no directory component).
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for doc in docs:
            f.write(json.dumps(doc, ensure_ascii=False) + "\n")


def main():
    """CLI entry point: dispatch to the CORD/SROIE converters or merge."""
    parser = argparse.ArgumentParser(
        description="Convert real-world datasets to our training format (no API needed)"
    )
    parser.add_argument("--source", required=True, choices=["cord", "sroie", "merge"])
    parser.add_argument("--sroie-path", default="data/raw/sroie")
    parser.add_argument("--output", default=None)
    parser.add_argument("--max-docs", type=int, default=None)
    args = parser.parse_args()
    print(f"\n{'='*50}")
    print(" Dataset Converter (No API Required)")
    print(f"{'='*50}\n")
    if args.source == "cord":
        output = args.output or "data/training/real_cord.jsonl"
        print(" Source: CORD v1 (HuggingFace)")
        print(f" Output: {output}\n")
        docs = convert_cord(args.max_docs)
        if docs:
            _write_jsonl(output, docs)
            print(f"\n βœ… Saved {len(docs)} documents to {output}")
    elif args.source == "sroie":
        output = args.output or "data/training/real_sroie.jsonl"
        print(f" Source: SROIE v2 ({args.sroie_path})")
        print(f" Output: {output}\n")
        docs = convert_sroie(args.sroie_path, args.max_docs)
        if docs:
            _write_jsonl(output, docs)
            print(f"\n βœ… Saved {len(docs)} documents to {output}")
    elif args.source == "merge":
        output = args.output or "data/training/merged_raw.jsonl"
        print(" Merging all data sources...")
        print(f" Output: {output}\n")
        merge_all_datasets(output)
    print()


if __name__ == "__main__":
    main()