""" Training Data Formatter. Combines raw documents (with ground truth) into chat-format JSONL ready for Unsloth/SFT training, then splits into train/test sets. Usage: python scripts/prepare_training_data.py --input data/training/with_anomalies.jsonl \ --train-output data/training/train.jsonl \ --test-output data/test/test.jsonl \ --test-size 30 """ import json import random import argparse import os random.seed(42) SYSTEM_PROMPT = """You are a financial document extraction expert. Your task is to: 1. Identify the document type (invoice, purchase_order, receipt, or bank_statement). 2. Extract all relevant fields into a structured JSON object following this exact schema: - "common": document_type, date, issuer (name, address), recipient (name, address), total_amount, currency - "line_items": array of {description, quantity, unit_price, amount} - "type_specific": fields specific to the document type - "flags": array of detected anomalies, each with {category, field, severity, description} - "confidence_score": your confidence in the extraction (0.0 to 1.0) 3. Analyze the document for anomalies across these categories: - arithmetic_error: Mathematical calculations that don't add up - missing_field: Required fields that are absent from the document - format_anomaly: Inconsistent formats, negative quantities, duplicate entries - business_logic: Unusual amounts, suspicious patterns, round-number fraud indicators - cross_field: Mismatched references between related fields or documents 4. If no anomalies are found, return an empty "flags" array. Output ONLY valid JSON. No explanations, no markdown, no code blocks — just the raw JSON object.""" def format_as_chat(doc: dict) -> dict: """ Convert a document dict into chat-format training example. Args: doc: Dict with 'raw_text' and 'ground_truth'. Returns: Chat-format dict with 'messages' array. """ user_message = f"Extract structured data from this financial document:\n\n---\n{doc['raw_text']}\n---" # Compact JSON for the assistant response (no extra whitespace) assistant_response = json.dumps(doc["ground_truth"], separators=(",", ":"), ensure_ascii=False) return { "messages": [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_message}, {"role": "assistant", "content": assistant_response}, ] } def validate_example(example: dict) -> bool: """Check that a training example is well-formed.""" try: messages = example.get("messages", []) if len(messages) != 3: return False # Verify assistant response is valid JSON assistant_content = messages[2]["content"] parsed = json.loads(assistant_content) # Check required keys required = ["common", "flags", "confidence_score"] for key in required: if key not in parsed: return False # Check common has document_type if "document_type" not in parsed.get("common", {}): return False return True except (json.JSONDecodeError, KeyError): return False def compute_token_estimate(text: str) -> int: """Rough token count estimate (1 token ≈ 4 chars for English).""" return len(text) // 4 def main(): parser = argparse.ArgumentParser(description="Prepare training data for fine-tuning") parser.add_argument("--input", type=str, default="data/training/with_anomalies.jsonl") parser.add_argument("--train-output", type=str, default="data/training/train.jsonl") parser.add_argument("--test-output", type=str, default="data/test/test.jsonl") parser.add_argument("--test-size", type=int, default=30) args = parser.parse_args() print(f"\n{'='*50}") print(f" Training Data Formatter") print(f"{'='*50}\n") # Load documents documents = [] with open(args.input, "r", encoding="utf-8") as f: for line in f: documents.append(json.loads(line.strip())) print(f" Loaded {len(documents)} documents") # Format as chat examples examples = [format_as_chat(doc) for doc in documents] # Validate valid_examples = [] invalid_count = 0 for ex in examples: if validate_example(ex): valid_examples.append(ex) else: invalid_count += 1 print(f" Valid examples: {len(valid_examples)}") if invalid_count > 0: print(f" Invalid (skipped): {invalid_count}") # Stratified split: ensure test set has both clean and anomalous examples clean = [ex for ex in valid_examples if not json.loads(ex["messages"][2]["content"])["flags"]] anomalous = [ex for ex in valid_examples if json.loads(ex["messages"][2]["content"])["flags"]] random.shuffle(clean) random.shuffle(anomalous) test_size = min(args.test_size, len(valid_examples) // 5) # Allocate ~40% anomalous in test set (matching the overall ratio) test_anomalous_count = max(1, int(test_size * 0.4)) test_clean_count = test_size - test_anomalous_count # Ensure we don't exceed available test_anomalous_count = min(test_anomalous_count, len(anomalous)) test_clean_count = min(test_clean_count, len(clean)) test_examples = anomalous[:test_anomalous_count] + clean[:test_clean_count] train_examples = anomalous[test_anomalous_count:] + clean[test_clean_count:] random.shuffle(test_examples) random.shuffle(train_examples) # Save train set os.makedirs(os.path.dirname(args.train_output), exist_ok=True) with open(args.train_output, "w", encoding="utf-8") as f: for ex in train_examples: f.write(json.dumps(ex, ensure_ascii=False) + "\n") # Save test set os.makedirs(os.path.dirname(args.test_output), exist_ok=True) with open(args.test_output, "w", encoding="utf-8") as f: for ex in test_examples: f.write(json.dumps(ex, ensure_ascii=False) + "\n") # Save test ground truth separately (for evaluation) ground_truth_path = os.path.join(os.path.dirname(args.test_output), "ground_truth.json") test_ground_truths = [] for ex in test_examples: gt = json.loads(ex["messages"][2]["content"]) test_ground_truths.append({ "input": ex["messages"][1]["content"], "expected_output": gt, }) with open(ground_truth_path, "w", encoding="utf-8") as f: json.dump(test_ground_truths, f, indent=2, ensure_ascii=False) # Statistics train_tokens = sum(compute_token_estimate(json.dumps(ex)) for ex in train_examples) test_tokens = sum(compute_token_estimate(json.dumps(ex)) for ex in test_examples) # Count anomalies in each set train_anomalous = sum( 1 for ex in train_examples if json.loads(ex["messages"][2]["content"])["flags"] ) test_anomalous = sum( 1 for ex in test_examples if json.loads(ex["messages"][2]["content"])["flags"] ) print(f"\n Split Summary:") print(f" Train: {len(train_examples)} examples (~{train_tokens:,} tokens)") print(f" - Clean: {len(train_examples) - train_anomalous}") print(f" - With anomalies: {train_anomalous}") print(f" Test: {len(test_examples)} examples (~{test_tokens:,} tokens)") print(f" - Clean: {len(test_examples) - test_anomalous}") print(f" - With anomalies: {test_anomalous}") print(f"\n Saved:") print(f" Train: {args.train_output}") print(f" Test: {args.test_output}") print(f" Ground truth: {ground_truth_path}\n") if __name__ == "__main__": main()