# financial-intelligence-ai / scripts/ingest_kaggle_data.py
# (uploaded to the Hugging Face Hub via huggingface_hub; commit 10ff0db)
"""
Real-World Dataset Ingestion Pipeline.
Ingests financial document datasets from Kaggle/HuggingFace and generates
ground truth labels using GPT-4o-mini in our exact Pydantic schema.
Supported datasets:
1. SROIE v2 (Kaggle: urbikn/sroie-datasetv2) β€” 973 scanned receipts
2. CORD v1 (HuggingFace: naver-clova-ix/cord-v1) β€” 1,000 receipts
Usage:
# Process SROIE data (download from Kaggle first)
python scripts/ingest_kaggle_data.py --source sroie \
--sroie-path data/raw/sroie \
--output data/training/real_sroie.jsonl \
--max-docs 200
# Process CORD data (downloads from HuggingFace automatically)
python scripts/ingest_kaggle_data.py --source cord \
--output data/training/real_cord.jsonl \
--max-docs 100
# Merge all sources into final training set
python scripts/ingest_kaggle_data.py --source merge \
--output data/training/merged_raw.jsonl
"""
import os
import sys
import json
import glob
import time
import argparse
from typing import Optional, List

from dotenv import load_dotenv

# Pull OPENAI_API_KEY (and any other settings) from a local .env file.
load_dotenv()

# Add project root to sys.path so sibling packages resolve when this script
# is run directly (python scripts/ingest_kaggle_data.py).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# GPT-4o-mini prompt (same as generate_ground_truth.py but optimized for noisy OCR)
LABELING_PROMPT = """You are a financial document extraction expert. Given raw OCR text from a scanned receipt or invoice, you must:
1. Identify the document type: "invoice", "purchase_order", "receipt", or "bank_statement"
2. Extract ALL fields into the exact JSON schema below
3. Detect any anomalies present in the document
IMPORTANT RULES:
- Output ONLY valid JSON β€” no markdown, no explanations, no code blocks
- This is REAL OCR text β€” expect noise, typos, and formatting issues. Do your best.
- If a field is not present or unreadable, use null
- If no anomalies exist, return an empty "flags" array: []
- All dates should be normalized to YYYY-MM-DD format
- All monetary amounts should be numbers (not strings)
- confidence_score should reflect OCR quality and your certainty (0.0 to 1.0)
Anomaly categories to check:
- arithmetic_error: math that doesn't add up
- missing_field: required fields absent from the document
- format_anomaly: impossible dates, negative quantities, duplicate entries
- business_logic: round-number fraud, extreme amounts, unusual terms
- cross_field: mismatched references, currency conflicts
Required JSON Schema:
{
"common": {
"document_type": "receipt",
"date": "YYYY-MM-DD or null",
"issuer": {"name": "string or null", "address": "string or null"},
"recipient": null,
"total_amount": number_or_null,
"currency": "USD"
},
"line_items": [
{"description": "string", "quantity": number, "unit_price": number, "amount": number}
],
"type_specific": {
"receipt_number": "string or null",
"payment_method": "string or null",
"store_location": "string or null",
"cashier": "string or null"
},
"flags": [],
"confidence_score": 0.85
}"""
def load_sroie_data(sroie_path: str, max_docs: Optional[int] = None) -> List[dict]:
    """
    Load the SROIE dataset from a local directory.

    Expected structure:
        sroie_path/
        ├── train/
        │   ├── img/       # Receipt images (we skip these)
        │   ├── box/       # Bounding box + text (OCR output)
        │   └── entities/  # Key-value labels (company, date, address, total)
        └── test/
            ├── img/
            └── box/

    A flat layout (``*.txt`` files directly under train/ or test/) is also
    accepted. We read from box/ (OCR text) and entities/ (reference labels).

    Args:
        sroie_path: Root directory of the downloaded SROIE dataset.
        max_docs: If given, return at most this many documents.

    Returns:
        List of dicts with keys: source, doc_id, raw_text, original_labels.
    """
    documents = []
    for split in ["train", "test"]:
        box_dir = os.path.join(sroie_path, split, "box")
        entities_dir = os.path.join(sroie_path, split, "entities")
        if not os.path.exists(box_dir):
            # Try alternative structure (flat: txt files directly under split/)
            box_dir = os.path.join(sroie_path, split)
            entities_dir = os.path.join(sroie_path, split)
        if not os.path.exists(box_dir):
            print(f"  [SKIP] Directory not found: {box_dir}")
            continue
        txt_files = sorted(glob.glob(os.path.join(box_dir, "*.txt")))
        print(f"  Found {len(txt_files)} OCR text files in {split}/")
        for txt_file in txt_files:
            basename = os.path.splitext(os.path.basename(txt_file))[0]
            try:
                with open(txt_file, "r", encoding="utf-8", errors="ignore") as f:
                    lines = f.readlines()
                # SROIE box format: x1,y1,x2,y2,x3,y3,x4,y4,text — keep only
                # the text portion (which may itself contain commas).
                text_parts = []
                for line in lines:
                    line = line.strip()
                    if not line:
                        continue
                    parts = line.split(",")
                    if len(parts) > 8:
                        text = ",".join(parts[8:]).strip()
                        if text:
                            text_parts.append(text)
                    elif len(parts) == 1:
                        # Some files contain plain text with no coordinates.
                        text_parts.append(line)
                raw_text = "\n".join(text_parts)
                # Skip empty / near-empty OCR output.
                if not raw_text.strip() or len(raw_text) < 10:
                    continue
                # Read entity labels if available (kept for reference only).
                entity_file = os.path.join(entities_dir, basename + ".txt")
                original_labels = {}
                if os.path.exists(entity_file):
                    # FIX: was `except (json.JSONDecodeError, Exception)` —
                    # Exception already subsumes JSONDecodeError.
                    try:
                        with open(entity_file, "r", encoding="utf-8") as ef:
                            original_labels = json.load(ef)
                    except Exception:
                        pass  # best-effort: labels are optional metadata
                documents.append({
                    "source": "sroie",
                    "doc_id": basename,
                    "raw_text": raw_text,
                    "original_labels": original_labels,
                })
            except Exception as e:
                print(f"  [SKIP] {basename}: {e}")
                continue
    print(f"  Loaded {len(documents)} SROIE documents total")
    # FIX: only trim (and report trimming) when there is actually an excess;
    # the original printed "Trimmed to N" even when fewer than N were loaded.
    if max_docs is not None and len(documents) > max_docs:
        documents = documents[:max_docs]
        print(f"  Trimmed to {max_docs} documents")
    return documents
def _cord_parse_to_text(gt_parse: dict) -> str:
    """Reconstruct receipt-like raw text from a CORD ``gt_parse`` dict.

    Sections are emitted in order: store header, menu items (with sub-items),
    totals, then payment info. Returns an empty string when none are present.
    """
    text_parts: List[str] = []
    # Store header (name / branch / address / phone).
    store_info = gt_parse.get("store_info", {})
    if store_info:
        for key in ["name", "branch", "address", "tel"]:
            val = store_info.get(key, "")
            if val:
                text_parts.append(str(val))
    # Menu items; each item may carry nested sub-items under "sub".
    menu = gt_parse.get("menu", [])
    if menu:
        text_parts.append("\n--- ITEMS ---")
        for item in menu:
            name = item.get("nm", "")
            cnt = item.get("cnt", "")
            price = item.get("price", "")
            line = f" {name}"
            if cnt:
                line += f" x{cnt}"
            if price:
                line += f" {price}"
            text_parts.append(line)
            for sub in (item.get("sub", []) or []):
                sub_name = sub.get("nm", "")
                sub_price = sub.get("price", "")
                text_parts.append(f" - {sub_name} {sub_price}")
    # Totals section.
    total_info = gt_parse.get("total", {})
    if total_info:
        text_parts.append("\n--- TOTALS ---")
        for key in ["subtotal_price", "tax_price", "total_price",
                    "discount_price", "service_price"]:
            val = total_info.get(key, "")
            if val:
                label = key.replace("_", " ").title()
                text_parts.append(f" {label}: {val}")
    # Payment section.
    pay_info = gt_parse.get("payment", {})
    if pay_info:
        text_parts.append("\n--- PAYMENT ---")
        for key in ["cash_price", "change_price", "credit_card_price"]:
            val = pay_info.get(key, "")
            if val:
                label = key.replace("_", " ").title()
                text_parts.append(f" {label}: {val}")
    return "\n".join(text_parts)


def load_cord_data(max_docs: Optional[int] = None) -> List[dict]:
    """
    Load the CORD v1 receipt dataset from HuggingFace.

    Downloads automatically on first use. Each sample's ``ground_truth``
    JSON is parsed and flattened to pseudo-OCR text via _cord_parse_to_text.

    Args:
        max_docs: If given, process at most this many samples (0 means none).

    Returns:
        List of dicts with keys: source, doc_id, raw_text, original_labels.
        Empty list if the 'datasets' library or the download is unavailable.
    """
    try:
        from datasets import load_dataset
    except ImportError:
        print("  [ERROR] 'datasets' library not installed. Run: pip install datasets")
        return []
    print("  Downloading CORD v1 from HuggingFace...")
    try:
        dataset = load_dataset("naver-clova-ix/cord-v1", split="train")
    except Exception as e:
        print(f"  [ERROR] Failed to load CORD: {e}")
        return []
    print(f"  Loaded {len(dataset)} CORD receipts")
    documents = []
    # FIX: was `max_docs if max_docs else len(dataset)`, which treated an
    # explicit max_docs=0 as "no limit".
    limit = max_docs if max_docs is not None else len(dataset)
    for i, sample in enumerate(dataset):
        if i >= limit:
            break
        try:
            # CORD stores the parsed annotation as a JSON string.
            gt_str = sample.get("ground_truth", "")
            gt_data = json.loads(gt_str) if isinstance(gt_str, str) else gt_str
            # The useful payload is under "gt_parse" (fall back to the root).
            gt_parse = gt_data.get("gt_parse", gt_data)
            raw_text = _cord_parse_to_text(gt_parse)
            # Skip samples that reconstruct to (nearly) nothing.
            if len(raw_text.strip()) < 10:
                continue
            documents.append({
                "source": "cord",
                "doc_id": f"cord_{i}",
                "raw_text": raw_text,
                "original_labels": gt_parse,
            })
        except Exception:
            # Best-effort ingestion: malformed samples are silently dropped.
            continue
    print(f"  Processed {len(documents)} CORD documents")
    return documents
def label_with_gpt(
    documents: List[dict],
    model: str = "gpt-4o-mini",
    batch_delay: float = 0.15,
) -> List[dict]:
    """
    Use an OpenAI chat model to generate ground-truth labels in our schema.

    Each document's raw_text is sent alongside LABELING_PROMPT; the JSON
    reply is validated, lightly repaired (default flags / confidence_score),
    and collected. Up to 3 attempts are made per document.

    Args:
        documents: Dicts with at least a 'raw_text' field.
        model: OpenAI model to use.
        batch_delay: Delay between documents in seconds (rate limiting).

    Returns:
        List of labeled documents in our training format; empty list when
        the openai package or the API key is unavailable.
    """
    try:
        from openai import OpenAI
    except ImportError:
        print("  [ERROR] 'openai' library not installed. Run: pip install openai")
        return []
    api_key = os.getenv("OPENAI_API_KEY")
    # Reject both a missing key and the "sk-your..." placeholder from .env.example.
    if not api_key or api_key.startswith("sk-your"):
        print("\n  ❌ Error: OPENAI_API_KEY not set in .env!")
        return []
    client = OpenAI(api_key=api_key)
    labeled = []
    success = 0
    failed = 0
    print(f"\n  Labeling {len(documents)} documents with {model}...")
    for i, doc in enumerate(documents):
        raw_text = doc["raw_text"]
        user_msg = f"Extract structured data from this financial document:\n\n---\n{raw_text}\n---"
        print(f"  [{i+1}/{len(documents)}] {doc.get('source', '?')}/{doc.get('doc_id', '?')}...", end="")
        for attempt in range(3):
            try:
                response = client.chat.completions.create(
                    model=model,
                    messages=[
                        {"role": "system", "content": LABELING_PROMPT},
                        {"role": "user", "content": user_msg},
                    ],
                    temperature=0.1,  # near-deterministic extraction
                    max_tokens=2048,
                    response_format={"type": "json_object"},
                )
                parsed = json.loads(response.choices[0].message.content)
                # Minimal schema validation; repair optional fields in place.
                if "common" not in parsed:
                    raise ValueError("Missing 'common' field")
                if "document_type" not in parsed.get("common", {}):
                    raise ValueError("Missing 'document_type'")
                parsed.setdefault("flags", [])
                parsed.setdefault("confidence_score", 0.85)
                # Build training document in our format.
                labeled.append({
                    "doc_type": parsed["common"]["document_type"],
                    "raw_text": raw_text,
                    "ground_truth": parsed,
                    "source": doc.get("source", "unknown"),
                })
                print(f" ✅ ({len(parsed['flags'])} flags)")
                success += 1
                break
            except Exception as e:
                # FIX: the original had a separate JSONDecodeError handler that
                # bound `e` but never printed it; report the detail uniformly.
                print(f" [RETRY {attempt+1}: {str(e)[:50]}]", end="")
                if attempt < 2:
                    time.sleep(0.5)  # FIX: no pointless sleep after last attempt
        else:
            # All 3 attempts exhausted without a break.
            print(" ❌ failed")
            failed += 1
        time.sleep(batch_delay)
    print(f"\n  Labeling complete: {success} success, {failed} failed")
    # Rough cost estimate (gpt-4o-mini: $0.15/M input, $0.60/M output tokens,
    # assuming ~1500 input and ~700 output tokens per document).
    est_input = success * 1500 / 1_000_000
    est_output = success * 700 / 1_000_000
    est_cost = est_input * 0.15 + est_output * 0.60
    print(f"  Estimated cost: ~${est_cost:.3f}")
    return labeled
def merge_datasets(output_path: str):
    """
    Merge all data sources into one master JSONL file.

    Sources (missing files are skipped with a warning):
      - data/training/synthetic_raw.jsonl (our Faker-generated docs)
      - data/training/real_sroie.jsonl    (SROIE receipts, GPT-labeled)
      - data/training/real_cord.jsonl     (CORD receipts, GPT-labeled)

    Args:
        output_path: Destination JSONL file for the merged dataset.
    """
    sources = [
        ("data/training/synthetic_raw.jsonl", "synthetic"),
        ("data/training/real_sroie.jsonl", "sroie"),
        ("data/training/real_cord.jsonl", "cord"),
    ]
    all_docs = []
    for filepath, source_name in sources:
        if not os.path.exists(filepath):
            print(f"  [SKIP] {filepath} not found")
            continue
        count = 0
        with open(filepath, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                # FIX: skip blank lines instead of crashing on json.loads("")
                # (e.g. a trailing newline at the end of a JSONL file).
                if not line:
                    continue
                doc = json.loads(line)
                # Tag provenance when missing so the summary below works.
                doc.setdefault("source", source_name)
                all_docs.append(doc)
                count += 1
        print(f"  Loaded {count} docs from {source_name}")
    if not all_docs:
        print("  ❌ No documents found to merge!")
        return
    # FIX: os.makedirs("") raises FileNotFoundError when output_path is a
    # bare filename with no directory component.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        for doc in all_docs:
            f.write(json.dumps(doc, ensure_ascii=False) + "\n")
    # Print distribution summary by source and by document type.
    source_counts = {}
    type_counts = {}
    for doc in all_docs:
        s = doc.get("source", "unknown")
        t = doc.get("doc_type", "unknown")
        source_counts[s] = source_counts.get(s, 0) + 1
        type_counts[t] = type_counts.get(t, 0) + 1
    print("\n  Merged Dataset Summary:")
    print(f"  {'─' * 40}")
    print("  By Source:")
    for src, cnt in sorted(source_counts.items()):
        print(f"    {src:<20}: {cnt}")
    print("  By Document Type:")
    for dtype, cnt in sorted(type_counts.items()):
        print(f"    {dtype:<20}: {cnt}")
    print(f"  {'─' * 40}")
    print(f"  Total: {len(all_docs)} documents")
    print(f"  Saved to: {output_path}")
def _save_jsonl(output: str, docs: List[dict]) -> None:
    """Write docs to output as JSONL, creating parent directories as needed."""
    # FIX: os.makedirs("") raises for bare filenames; guard the dirname.
    out_dir = os.path.dirname(output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output, "w", encoding="utf-8") as f:
        for doc in docs:
            f.write(json.dumps(doc, ensure_ascii=False) + "\n")


def main():
    """CLI entry point: dispatch to SROIE / CORD ingestion, or merge sources."""
    parser = argparse.ArgumentParser(
        description="Ingest real-world financial datasets and generate training labels"
    )
    parser.add_argument("--source", type=str, required=True,
                        choices=["sroie", "cord", "merge"],
                        help="Data source to process")
    parser.add_argument("--sroie-path", type=str, default="data/raw/sroie",
                        help="Path to downloaded SROIE dataset")
    parser.add_argument("--output", type=str, default=None,
                        help="Output JSONL file path")
    parser.add_argument("--max-docs", type=int, default=None,
                        help="Maximum documents to process")
    parser.add_argument("--model", type=str, default="gpt-4o-mini",
                        help="OpenAI model for labeling")
    args = parser.parse_args()

    print(f"\n{'='*50}")
    print("  Real-World Dataset Ingestion Pipeline")
    print(f"{'='*50}\n")

    # Load documents for the chosen source; the merge path returns early
    # because it has no labeling/saving step of its own.
    if args.source == "sroie":
        output = args.output or "data/training/real_sroie.jsonl"
        print(f"  Source: SROIE v2 ({args.sroie_path})")
        print(f"  Output: {output}\n")
        docs = load_sroie_data(args.sroie_path, args.max_docs)
    elif args.source == "cord":
        output = args.output or "data/training/real_cord.jsonl"
        print("  Source: CORD v1 (HuggingFace)")
        print(f"  Output: {output}\n")
        docs = load_cord_data(args.max_docs)
    else:  # args.source == "merge" (argparse choices guarantee this)
        output = args.output or "data/training/merged_raw.jsonl"
        print("  Merging all data sources...")
        print(f"  Output: {output}\n")
        merge_datasets(output)
        print()
        return

    # FIX: the label-and-save logic below was duplicated verbatim in both
    # the sroie and cord branches; it is now shared via _save_jsonl.
    if docs:
        labeled = label_with_gpt(docs, model=args.model)
        if labeled:
            _save_jsonl(output, labeled)
            print(f"\n  ✅ Saved {len(labeled)} labeled documents to {output}")
    print()
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()