File size: 7,970 Bytes
10ff0db | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 | """
Training Data Formatter.
Combines raw documents (with ground truth) into chat-format JSONL
ready for Unsloth/SFT training, then splits into train/test sets.
Usage:
python scripts/prepare_training_data.py --input data/training/with_anomalies.jsonl \
--train-output data/training/train.jsonl \
--test-output data/test/test.jsonl \
--test-size 30
"""
import json
import random
import argparse
import os
# Seed the module-level RNG at import time so the random.shuffle() calls in
# main() produce the same ordering on every run for a given input.
random.seed(42)
# System message used as the first entry of every chat example built by
# format_as_chat(). It describes the output schema, the anomaly categories,
# and the "raw JSON only" output constraint. This text is training data:
# changing any character changes every generated example.
SYSTEM_PROMPT = """You are a financial document extraction expert. Your task is to:
1. Identify the document type (invoice, purchase_order, receipt, or bank_statement).
2. Extract all relevant fields into a structured JSON object following this exact schema:
- "common": document_type, date, issuer (name, address), recipient (name, address), total_amount, currency
- "line_items": array of {description, quantity, unit_price, amount}
- "type_specific": fields specific to the document type
- "flags": array of detected anomalies, each with {category, field, severity, description}
- "confidence_score": your confidence in the extraction (0.0 to 1.0)
3. Analyze the document for anomalies across these categories:
- arithmetic_error: Mathematical calculations that don't add up
- missing_field: Required fields that are absent from the document
- format_anomaly: Inconsistent formats, negative quantities, duplicate entries
- business_logic: Unusual amounts, suspicious patterns, round-number fraud indicators
- cross_field: Mismatched references between related fields or documents
4. If no anomalies are found, return an empty "flags" array.
Output ONLY valid JSON. No explanations, no markdown, no code blocks — just the raw JSON object."""
def format_as_chat(doc: dict) -> dict:
    """
    Build one chat-format training example from a source document.

    Args:
        doc: Mapping that provides 'raw_text' (the document text shown to the
            model) and 'ground_truth' (the JSON-serialisable target output).

    Returns:
        Dict with a single 'messages' key holding the system, user and
        assistant turns, in that order.

    Raises:
        KeyError: If 'raw_text' or 'ground_truth' is missing from ``doc``.
    """
    # The user turn wraps the raw document text between '---' fences.
    prompt = "Extract structured data from this financial document:\n\n---\n{}\n---".format(
        doc["raw_text"]
    )

    # The target is serialised without any padding after ',' or ':' and
    # with non-ASCII characters kept verbatim.
    target = json.dumps(
        doc["ground_truth"],
        ensure_ascii=False,
        separators=(",", ":"),
    )

    turns = (
        ("system", SYSTEM_PROMPT),
        ("user", prompt),
        ("assistant", target),
    )
    return {"messages": [{"role": role, "content": content} for role, content in turns]}
def validate_example(example: dict) -> bool:
    """
    Check that a training example is well-formed.

    An example is accepted only when all of the following hold:
      * it is a dict whose 'messages' value is a sequence of exactly three items;
      * the third message is a dict whose 'content' parses as JSON;
      * the parsed JSON is an object containing 'common', 'flags' and
        'confidence_score';
      * 'common' is itself an object containing 'document_type'.

    Args:
        example: Candidate chat-format example (normally the output of
            format_as_chat).

    Returns:
        True if the example passes every check, False otherwise. Malformed
        input never raises; it is reported as False so the caller can count
        and skip it.
    """
    if not isinstance(example, dict):
        return False

    messages = example.get("messages", [])
    if not isinstance(messages, (list, tuple)) or len(messages) != 3:
        return False

    assistant = messages[2]
    if not isinstance(assistant, dict):
        return False

    # json.loads raises TypeError for non-string content (e.g. None) and
    # ValueError (JSONDecodeError / UnicodeDecodeError) for unparsable text.
    try:
        parsed = json.loads(assistant.get("content"))
    except (TypeError, ValueError):
        return False

    # Valid JSON is not necessarily an object: "null", "5" or a bare string
    # would break the key checks below, so reject them explicitly.
    if not isinstance(parsed, dict):
        return False

    required = ("common", "flags", "confidence_score")
    if any(key not in parsed for key in required):
        return False

    # 'common' must be an object; on a string, `in` would be a substring
    # test and could accept garbage such as "common": "document_type".
    common = parsed["common"]
    return isinstance(common, dict) and "document_type" in common
def compute_token_estimate(text: str) -> int:
    """Approximate the token count of ``text`` as one token per 4 characters (rounded down)."""
    chars_per_token = 4
    char_count = len(text)
    return char_count // chars_per_token
def _has_flags(example: dict) -> bool:
    """Return True when the example's assistant JSON carries a non-empty "flags" value."""
    return bool(json.loads(example["messages"][2]["content"])["flags"])


def _write_jsonl(path: str, examples: list) -> None:
    """Write ``examples`` to ``path`` as one JSON object per line, creating the parent directory if the path has one."""
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so a bare filename such as
    # "train.jsonl" must skip directory creation.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(ex, ensure_ascii=False) + "\n" for ex in examples)


def main():
    """
    Command-line entry point: format, validate, split and save the data set.

    Steps:
      1. Read the input JSONL (one document per non-blank line).
      2. Convert each document with format_as_chat and keep those accepted
         by validate_example.
      3. Split the valid examples into train and test sets, drawing a fixed
         share of the test set from examples that have anomaly flags.
      4. Write the train JSONL, the test JSONL, and a ground_truth.json file
         next to the test output.
      5. Print a summary with rough token counts.

    Raises:
        OSError: If the input cannot be read or an output cannot be written.
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
    """
    parser = argparse.ArgumentParser(description="Prepare training data for fine-tuning")
    parser.add_argument("--input", type=str, default="data/training/with_anomalies.jsonl")
    parser.add_argument("--train-output", type=str, default="data/training/train.jsonl")
    parser.add_argument("--test-output", type=str, default="data/test/test.jsonl")
    parser.add_argument("--test-size", type=int, default=30)
    args = parser.parse_args()

    print(f"\n{'='*50}")
    print(f" Training Data Formatter")
    print(f"{'='*50}\n")

    # Load documents, skipping blank lines so stray empty lines in the
    # JSONL do not crash json.loads.
    documents = []
    with open(args.input, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                documents.append(json.loads(line))
    print(f" Loaded {len(documents)} documents")

    # Format as chat examples and keep only the well-formed ones.
    examples = [format_as_chat(doc) for doc in documents]
    valid_examples = [ex for ex in examples if validate_example(ex)]
    invalid_count = len(examples) - len(valid_examples)
    print(f" Valid examples: {len(valid_examples)}")
    if invalid_count > 0:
        print(f" Invalid (skipped): {invalid_count}")

    # Stratified split: partition by presence of flags (parsed once per
    # example), then shuffle each group independently.
    clean = []
    anomalous = []
    for ex in valid_examples:
        (anomalous if _has_flags(ex) else clean).append(ex)
    random.shuffle(clean)
    random.shuffle(anomalous)

    # The test set is at most one fifth of the valid examples and never
    # negative (guards against a negative --test-size).
    test_size = max(0, min(args.test_size, len(valid_examples) // 5))

    # Draw a fixed 40% of the test set from anomalous examples, but at least
    # one when the test set is non-empty. With an empty test set both counts
    # must be 0: forcing one anomalous example would make the clean count -1,
    # and clean[:-1] would move almost every clean example into the test set.
    test_anomalous_fraction = 0.4
    if test_size > 0:
        test_anomalous_count = max(1, int(test_size * test_anomalous_fraction))
    else:
        test_anomalous_count = 0
    test_clean_count = test_size - test_anomalous_count

    # Ensure we don't exceed what is available in either group.
    test_anomalous_count = min(test_anomalous_count, len(anomalous))
    test_clean_count = min(test_clean_count, len(clean))

    test_examples = anomalous[:test_anomalous_count] + clean[:test_clean_count]
    train_examples = anomalous[test_anomalous_count:] + clean[test_clean_count:]
    random.shuffle(test_examples)
    random.shuffle(train_examples)

    # Save train and test sets.
    _write_jsonl(args.train_output, train_examples)
    _write_jsonl(args.test_output, test_examples)

    # Save test ground truth separately (for evaluation), alongside the
    # test file; its directory already exists after the write above.
    ground_truth_path = os.path.join(os.path.dirname(args.test_output), "ground_truth.json")
    test_ground_truths = [
        {
            "input": ex["messages"][1]["content"],
            "expected_output": json.loads(ex["messages"][2]["content"]),
        }
        for ex in test_examples
    ]
    with open(ground_truth_path, "w", encoding="utf-8") as f:
        json.dump(test_ground_truths, f, indent=2, ensure_ascii=False)

    # Statistics
    train_tokens = sum(compute_token_estimate(json.dumps(ex)) for ex in train_examples)
    test_tokens = sum(compute_token_estimate(json.dumps(ex)) for ex in test_examples)

    # The anomalous counts follow directly from the slices taken above.
    test_anomalous = test_anomalous_count
    train_anomalous = len(anomalous) - test_anomalous_count

    print(f"\n Split Summary:")
    print(f" Train: {len(train_examples)} examples (~{train_tokens:,} tokens)")
    print(f" - Clean: {len(train_examples) - train_anomalous}")
    print(f" - With anomalies: {train_anomalous}")
    print(f" Test: {len(test_examples)} examples (~{test_tokens:,} tokens)")
    print(f" - Clean: {len(test_examples) - test_anomalous}")
    print(f" - With anomalies: {test_anomalous}")
    print(f"\n Saved:")
    print(f" Train: {args.train_output}")
    print(f" Test: {args.test_output}")
    print(f" Ground truth: {ground_truth_path}\n")
# Run the formatter only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
|