| """ |
| Training Data Formatter. |
| |
| Combines raw documents (with ground truth) into chat-format JSONL |
| ready for Unsloth/SFT training, then splits into train/test sets. |
| |
| Usage: |
| python scripts/prepare_training_data.py --input data/training/with_anomalies.jsonl \ |
| --train-output data/training/train.jsonl \ |
| --test-output data/test/test.jsonl \ |
| --test-size 30 |
| """ |
|
|
| import json |
| import random |
| import argparse |
| import os |
|
|
|
|
# Fixed seed so the shuffles below produce the same train/test split on
# every run (reproducible dataset preparation).
random.seed(42)


# System prompt placed as the first message of every training example.
# It pins the extraction schema and the anomaly taxonomy the fine-tuned
# model is expected to emit; changing it changes the training target.
SYSTEM_PROMPT = """You are a financial document extraction expert. Your task is to:

1. Identify the document type (invoice, purchase_order, receipt, or bank_statement).
2. Extract all relevant fields into a structured JSON object following this exact schema:
- "common": document_type, date, issuer (name, address), recipient (name, address), total_amount, currency
- "line_items": array of {description, quantity, unit_price, amount}
- "type_specific": fields specific to the document type
- "flags": array of detected anomalies, each with {category, field, severity, description}
- "confidence_score": your confidence in the extraction (0.0 to 1.0)

3. Analyze the document for anomalies across these categories:
- arithmetic_error: Mathematical calculations that don't add up
- missing_field: Required fields that are absent from the document
- format_anomaly: Inconsistent formats, negative quantities, duplicate entries
- business_logic: Unusual amounts, suspicious patterns, round-number fraud indicators
- cross_field: Mismatched references between related fields or documents

4. If no anomalies are found, return an empty "flags" array.

Output ONLY valid JSON. No explanations, no markdown, no code blocks — just the raw JSON object."""
|
|
|
|
def format_as_chat(doc: dict) -> dict:
    """
    Build a three-turn chat training example from a raw document.

    Args:
        doc: Mapping with 'raw_text' (the document text) and 'ground_truth'
            (the structured extraction the model should learn to produce).

    Returns:
        Dict with a 'messages' list: the shared system prompt, the user
        request wrapping the raw text, and the ground truth serialized as
        compact JSON for the assistant turn.
    """
    prompt = (
        "Extract structured data from this financial document:\n\n---\n"
        + doc["raw_text"]
        + "\n---"
    )
    # Compact separators keep the training target short (fewer tokens).
    target = json.dumps(doc["ground_truth"], separators=(",", ":"), ensure_ascii=False)

    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": target},
        ]
    }
|
|
|
|
def validate_example(example: dict) -> bool:
    """
    Check that a training example is well-formed.

    A valid example has exactly three chat messages, an assistant message
    whose content parses as JSON, and top-level keys required by the
    training schema (including a document_type under "common").

    Args:
        example: Candidate chat-format example (see format_as_chat).

    Returns:
        True if the example is usable for training, False otherwise —
        including for structurally malformed input (never raises).
    """
    try:
        messages = example.get("messages", [])
        if len(messages) != 3:
            return False

        # The assistant turn is the training target; it must be valid JSON.
        assistant_content = messages[2]["content"]
        parsed = json.loads(assistant_content)

        required = ["common", "flags", "confidence_score"]
        for key in required:
            if key not in parsed:
                return False

        if "document_type" not in parsed.get("common", {}):
            return False

        return True
    # TypeError/AttributeError cover malformed structures (message not a
    # dict, content not a string, parsed JSON not an object) — previously
    # these crashed the validator instead of marking the example invalid.
    except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
        return False
|
|
|
|
def compute_token_estimate(text: str) -> int:
    """Estimate the token count of *text* via the ~4-chars-per-token heuristic."""
    tokens, _remainder = divmod(len(text), 4)
    return tokens
|
|
|
|
def _load_documents(path: str) -> list:
    """Read one JSON document per non-empty line from a JSONL file."""
    docs = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # Tolerate blank lines (e.g. a trailing newline) instead of
            # crashing json.loads on an empty string.
            if line:
                docs.append(json.loads(line))
    return docs


def _write_jsonl(path: str, examples: list) -> None:
    """Write examples to *path* as JSONL, creating parent dirs as needed."""
    parent = os.path.dirname(path)
    # dirname is "" for a bare filename; os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for ex in examples:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")


def _has_flags(ex: dict) -> bool:
    """True if the example's ground truth contains any anomaly flags."""
    return bool(json.loads(ex["messages"][2]["content"])["flags"])


def main():
    """
    Load raw documents, format and validate them as chat examples, split
    into stratified train/test sets, and write the outputs (train JSONL,
    test JSONL, and a ground-truth JSON file for evaluation).
    """
    parser = argparse.ArgumentParser(description="Prepare training data for fine-tuning")
    parser.add_argument("--input", type=str, default="data/training/with_anomalies.jsonl")
    parser.add_argument("--train-output", type=str, default="data/training/train.jsonl")
    parser.add_argument("--test-output", type=str, default="data/test/test.jsonl")
    parser.add_argument("--test-size", type=int, default=30)
    args = parser.parse_args()

    print(f"\n{'='*50}")
    print(f" Training Data Formatter")
    print(f"{'='*50}\n")

    documents = _load_documents(args.input)
    print(f" Loaded {len(documents)} documents")

    examples = [format_as_chat(doc) for doc in documents]

    valid_examples = [ex for ex in examples if validate_example(ex)]
    invalid_count = len(examples) - len(valid_examples)
    print(f" Valid examples: {len(valid_examples)}")
    if invalid_count > 0:
        print(f" Invalid (skipped): {invalid_count}")

    # Stratify: keep anomalous and clean examples separate so the test set
    # gets a controlled share of each.
    clean = [ex for ex in valid_examples if not _has_flags(ex)]
    anomalous = [ex for ex in valid_examples if _has_flags(ex)]
    random.shuffle(clean)
    random.shuffle(anomalous)

    # Cap the test set at 20% of the available data.
    test_size = min(args.test_size, len(valid_examples) // 5)

    if test_size > 0:
        # Target ~40% anomalous in the test set, but never more than exist
        # and never more than the test set itself holds. Clamping the clean
        # count at 0 fixes a bug where a negative count made clean[:-1]
        # silently dump almost every clean example into the test split.
        test_anomalous_count = min(max(1, int(test_size * 0.4)), test_size, len(anomalous))
        test_clean_count = min(test_size - test_anomalous_count, len(clean))
    else:
        test_anomalous_count = 0
        test_clean_count = 0

    test_examples = anomalous[:test_anomalous_count] + clean[:test_clean_count]
    train_examples = anomalous[test_anomalous_count:] + clean[test_clean_count:]

    random.shuffle(test_examples)
    random.shuffle(train_examples)

    _write_jsonl(args.train_output, train_examples)
    _write_jsonl(args.test_output, test_examples)

    # Companion file pairing each test input with its expected output,
    # consumed by downstream evaluation scripts.
    ground_truth_path = os.path.join(os.path.dirname(args.test_output), "ground_truth.json")
    test_ground_truths = [
        {
            "input": ex["messages"][1]["content"],
            "expected_output": json.loads(ex["messages"][2]["content"]),
        }
        for ex in test_examples
    ]
    with open(ground_truth_path, "w", encoding="utf-8") as f:
        json.dump(test_ground_truths, f, indent=2, ensure_ascii=False)

    # Rough size estimates (~4 chars/token) for planning training runs.
    train_tokens = sum(compute_token_estimate(json.dumps(ex)) for ex in train_examples)
    test_tokens = sum(compute_token_estimate(json.dumps(ex)) for ex in test_examples)

    train_anomalous = sum(1 for ex in train_examples if _has_flags(ex))
    test_anomalous = sum(1 for ex in test_examples if _has_flags(ex))

    print(f"\n Split Summary:")
    print(f" Train: {len(train_examples)} examples (~{train_tokens:,} tokens)")
    print(f" - Clean: {len(train_examples) - train_anomalous}")
    print(f" - With anomalies: {train_anomalous}")
    print(f" Test: {len(test_examples)} examples (~{test_tokens:,} tokens)")
    print(f" - Clean: {len(test_examples) - test_anomalous}")
    print(f" - With anomalies: {test_anomalous}")
    print(f"\n Saved:")
    print(f" Train: {args.train_output}")
    print(f" Test: {args.test_output}")
    print(f" Ground truth: {ground_truth_path}\n")


if __name__ == "__main__":
    main()
|
|