# financial-intelligence-ai / scripts/prepare_training_data.py
# Uploaded by Vaibuzzz via huggingface_hub (commit 10ff0db, verified).
"""
Training Data Formatter.
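Converts raw documents (each paired with its ground-truth extraction) into
chat-format JSONL ready for Unsloth/SFT training, then performs a stratified
train/test split so the test set contains both clean and anomalous examples.
A ground_truth.json file for evaluation is written next to the test set.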
Usage:
python scripts/prepare_training_data.py --input data/training/with_anomalies.jsonl \
--train-output data/training/train.jsonl \
--test-output data/test/test.jsonl \
--test-size 30
"""
import json
import random
import argparse
import os
# Fixed seed so the random.shuffle calls used for the train/test split are
# reproducible across runs. NOTE: this runs at import time, so importing this
# module also resets the global `random` state.
random.seed(42)
# System message placed first in every chat-format training example.
# It describes the JSON schema the assistant target (the ground truth) follows.
# NOTE(review): presumably the same prompt must be used at inference/evaluation
# time for the fine-tuned model to behave as trained — confirm with the caller.
SYSTEM_PROMPT = """You are a financial document extraction expert. Your task is to:
1. Identify the document type (invoice, purchase_order, receipt, or bank_statement).
2. Extract all relevant fields into a structured JSON object following this exact schema:
- "common": document_type, date, issuer (name, address), recipient (name, address), total_amount, currency
- "line_items": array of {description, quantity, unit_price, amount}
- "type_specific": fields specific to the document type
- "flags": array of detected anomalies, each with {category, field, severity, description}
- "confidence_score": your confidence in the extraction (0.0 to 1.0)
3. Analyze the document for anomalies across these categories:
- arithmetic_error: Mathematical calculations that don't add up
- missing_field: Required fields that are absent from the document
- format_anomaly: Inconsistent formats, negative quantities, duplicate entries
- business_logic: Unusual amounts, suspicious patterns, round-number fraud indicators
- cross_field: Mismatched references between related fields or documents
4. If no anomalies are found, return an empty "flags" array.
Output ONLY valid JSON. No explanations, no markdown, no code blocks — just the raw JSON object."""
def format_as_chat(doc: dict) -> dict:
    """
    Build one chat-style SFT example from a source document.

    Args:
        doc: Mapping providing 'raw_text' (the document text shown to the
            model in the user turn) and 'ground_truth' (the target extraction).
    Returns:
        {"messages": [system, user, assistant]}, where each message is a
        {"role": ..., "content": ...} dict and the assistant content is the
        ground truth serialized as compact JSON (no spaces after separators,
        non-ASCII characters kept verbatim).
    Raises:
        KeyError: if 'raw_text' or 'ground_truth' is missing from ``doc``.
    """
    # Wrap the document between '---' fences beneath the instruction line.
    fenced_document = f"---\n{doc['raw_text']}\n---"
    prompt = "Extract structured data from this financial document:\n\n" + fenced_document

    # Compact encoding keeps the training target short.
    target = json.dumps(doc["ground_truth"], ensure_ascii=False, separators=(",", ":"))

    turns = (
        ("system", SYSTEM_PROMPT),
        ("user", prompt),
        ("assistant", target),
    )
    return {"messages": [{"role": role, "content": text} for role, text in turns]}
def validate_example(example: dict) -> bool:
    """
    Check that a chat-format training example is well-formed.

    An example is accepted only when:
      * it has exactly three messages,
      * the third (assistant) message's content parses as a JSON *object*,
      * that object contains "common", "flags" and "confidence_score",
      * "common" is itself an object containing "document_type".

    Args:
        example: Dict expected to hold a 'messages' list as produced by
            format_as_chat.
    Returns:
        True if every check passes, False otherwise. Malformed input of any
        shape yields False rather than an exception.
    """
    try:
        messages = example.get("messages", [])
        if len(messages) != 3:
            return False
        parsed = json.loads(messages[2]["content"])
    except (AttributeError, IndexError, KeyError, TypeError, json.JSONDecodeError):
        # AttributeError: example is not a dict; TypeError: messages/content
        # have the wrong type; the rest: missing pieces or invalid JSON.
        return False

    # A JSON scalar or array (e.g. "null") would otherwise make the `in`
    # checks below raise TypeError or test list membership.
    if not isinstance(parsed, dict):
        return False

    if any(key not in parsed for key in ("common", "flags", "confidence_score")):
        return False

    # "common" must be an object; a string would otherwise pass via a
    # substring match and None would raise.
    common = parsed["common"]
    return isinstance(common, dict) and "document_type" in common
def compute_token_estimate(text: str) -> int:
    """
    Approximate how many tokens ``text`` contains.

    Uses the rule of thumb of roughly four characters per token (typical for
    English); any leftover characters short of a full token are discarded.
    """
    chars_per_token = 4
    whole_tokens, _leftover = divmod(len(text), chars_per_token)
    return whole_tokens
def _has_flags(example: dict) -> bool:
    """Return True if the example's assistant answer lists at least one anomaly flag."""
    return bool(json.loads(example["messages"][2]["content"])["flags"])


def _load_jsonl(path: str) -> list:
    """Read one JSON value per non-blank line of the UTF-8 file at ``path``."""
    records = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. extra newlines at EOF) are not records.
                continue
            try:
                records.append(json.loads(stripped))
            except json.JSONDecodeError as err:
                raise ValueError(f"{path}:{line_no}: invalid JSON ({err})") from err
    return records


def _ensure_parent_dir(path: str) -> None:
    """Create the directory that will contain ``path``, if it names one."""
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so skip bare filenames.
    if parent:
        os.makedirs(parent, exist_ok=True)


def _write_jsonl(path: str, rows: list) -> None:
    """Write ``rows`` to ``path`` as UTF-8 JSONL, creating parent directories."""
    _ensure_parent_dir(path)
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


def _stratified_split(examples: list, requested_size: int, anomaly_ratio: float) -> tuple:
    """Split examples into (train, test) so the test set mixes clean and anomalous ones."""
    clean, anomalous = [], []
    for ex in examples:
        (anomalous if _has_flags(ex) else clean).append(ex)
    # Same shuffle order as before, so seeded splits of normal inputs are unchanged.
    random.shuffle(clean)
    random.shuffle(anomalous)

    # The test set is capped at 20% of the valid data and never negative.
    test_size = max(0, min(requested_size, len(examples) // 5))
    if test_size == 0:
        # Without this guard, max(1, ...) below would yield n_clean == -1, and
        # clean[:-1] would move almost every clean example into the test set.
        n_anomalous = n_clean = 0
    else:
        # Aim for `anomaly_ratio` anomalous examples (at least one, if available).
        n_anomalous = min(max(1, int(test_size * anomaly_ratio)), len(anomalous))
        # Fill the rest with clean examples. If clean ones run short, top up
        # with anomalous ones so the test set still reaches test_size.
        n_clean = min(test_size - n_anomalous, len(clean))
        n_anomalous = min(test_size - n_clean, len(anomalous))

    test = anomalous[:n_anomalous] + clean[:n_clean]
    train = anomalous[n_anomalous:] + clean[n_clean:]
    random.shuffle(test)
    random.shuffle(train)
    return train, test


def main():
    """
    CLI entry point: load raw documents, convert them to chat-format examples,
    drop malformed ones, split them into train/test sets and write the outputs.

    Writes the train JSONL, the test JSONL, and a ``ground_truth.json``
    (input text plus expected output for every test example) in the test
    output's directory, then prints a split summary.
    """
    parser = argparse.ArgumentParser(description="Prepare training data for fine-tuning")
    parser.add_argument("--input", type=str, default="data/training/with_anomalies.jsonl")
    parser.add_argument("--train-output", type=str, default="data/training/train.jsonl")
    parser.add_argument("--test-output", type=str, default="data/test/test.jsonl")
    parser.add_argument(
        "--test-size",
        type=int,
        default=30,
        help="Requested test-set size (capped at 20%% of the valid examples).",
    )
    parser.add_argument(
        "--test-anomaly-ratio",
        type=float,
        default=0.4,
        help="Target fraction of anomalous examples in the test set (0-1).",
    )
    args = parser.parse_args()
    if not 0.0 <= args.test_anomaly_ratio <= 1.0:
        parser.error("--test-anomaly-ratio must be between 0 and 1")

    print(f"\n{'='*50}")
    print(" Training Data Formatter")
    print(f"{'='*50}\n")

    documents = _load_jsonl(args.input)
    print(f" Loaded {len(documents)} documents")

    # Format and validate in one pass. A record that cannot even be formatted
    # (missing 'raw_text'/'ground_truth', or not an object) counts as invalid
    # instead of aborting the whole run.
    valid_examples = []
    invalid_count = 0
    for doc in documents:
        try:
            example = format_as_chat(doc)
        except (KeyError, TypeError):
            invalid_count += 1
            continue
        if validate_example(example):
            valid_examples.append(example)
        else:
            invalid_count += 1
    print(f" Valid examples: {len(valid_examples)}")
    if invalid_count > 0:
        print(f" Invalid (skipped): {invalid_count}")

    train_examples, test_examples = _stratified_split(
        valid_examples, args.test_size, args.test_anomaly_ratio
    )

    _write_jsonl(args.train_output, train_examples)
    _write_jsonl(args.test_output, test_examples)

    # Test ground truth as a standalone file for evaluation scripts.
    ground_truth_path = os.path.join(os.path.dirname(args.test_output), "ground_truth.json")
    test_ground_truths = [
        {
            "input": ex["messages"][1]["content"],
            "expected_output": json.loads(ex["messages"][2]["content"]),
        }
        for ex in test_examples
    ]
    with open(ground_truth_path, "w", encoding="utf-8") as f:
        json.dump(test_ground_truths, f, indent=2, ensure_ascii=False)

    # Estimate tokens on the text as written (non-ASCII kept verbatim rather
    # than expanded into \uXXXX escapes, which would inflate the count).
    train_tokens = sum(
        compute_token_estimate(json.dumps(ex, ensure_ascii=False)) for ex in train_examples
    )
    test_tokens = sum(
        compute_token_estimate(json.dumps(ex, ensure_ascii=False)) for ex in test_examples
    )
    train_anomalous = sum(1 for ex in train_examples if _has_flags(ex))
    test_anomalous = sum(1 for ex in test_examples if _has_flags(ex))

    print("\n Split Summary:")
    print(f" Train: {len(train_examples)} examples (~{train_tokens:,} tokens)")
    print(f" - Clean: {len(train_examples) - train_anomalous}")
    print(f" - With anomalies: {train_anomalous}")
    print(f" Test: {len(test_examples)} examples (~{test_tokens:,} tokens)")
    print(f" - Clean: {len(test_examples) - test_anomalous}")
    print(f" - With anomalies: {test_anomalous}")
    if len(test_examples) < args.test_size:
        print(
            f" Note: test set has {len(test_examples)} examples "
            f"(requested {args.test_size}; capped at 20% of valid data)"
        )
    print("\n Saved:")
    print(f" Train: {args.train_output}")
    print(f" Test: {args.test_output}")
    print(f" Ground truth: {ground_truth_path}\n")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()