| """ |
| Post-Extraction Validation Engine. |
| |
| Performs programmatic validation checks on extracted financial data |
| that complement the model's anomaly detection: |
| |
1. Arithmetic Consistency — line items sum to subtotal, subtotal + tax = total
2. Required Field Completeness — checks for missing critical fields per doc type
3. Date Format Validation — ensures dates are valid and reasonable
4. Cross-Field Reference Checks — due/delivery date ordering, negative totals
5. Business Logic Checks — very large totals, suspiciously round amounts,
   negative line-item quantities
| |
| Usage: |
| from src.validator import validate_extraction |
| |
| extra_flags = validate_extraction(json_data) |
| # Returns list of additional anomaly flag dicts |
| """ |
|
|
| import re |
| from datetime import datetime |
| from typing import List, Optional |
|
|
|
|
def validate_extraction(data: dict) -> List[dict]:
    """
    Run all validation checks on extracted document data.

    Args:
        data: Parsed JSON dict from model output.

    Returns:
        List of additional anomaly flags not detected by the model.
    """
    # Checks run in a fixed order so flag ordering is deterministic.
    checks = (
        _check_arithmetic,
        _check_required_fields,
        _check_date_formats,
        _check_cross_field,
        _check_business_logic,
    )
    return [flag for check in checks for flag in check(data)]
|
|
|
|
def _check_arithmetic(data: dict) -> List[dict]:
    """Verify that math adds up in the document.

    Checks, in order: per-line-item amount vs quantity * unit_price,
    line-item sum vs stated subtotal, subtotal + tax vs grand total,
    and (bank statements only) opening balance + transactions vs
    closing balance. Non-numeric values are skipped silently rather
    than flagged — other checks report missing/malformed fields.
    """
    common = data.get("common", {})
    specific = data.get("type_specific", {})
    items = data.get("line_items", [])
    found: List[dict] = []

    def report(field: str, description: str) -> None:
        # Every arithmetic inconsistency is a high-severity validator flag.
        found.append({
            "category": "arithmetic_error",
            "field": field,
            "severity": "high",
            "description": description,
            "source": "validator",
        })

    # 1) Each line item: amount should equal quantity * unit_price
    #    (tolerance 0.02 absorbs per-item rounding).
    for idx, entry in enumerate(items or []):
        raw_qty = entry.get("quantity")
        raw_price = entry.get("unit_price")
        raw_amount = entry.get("amount")
        if raw_qty is None or raw_price is None or raw_amount is None:
            continue
        try:
            expected = round(float(raw_qty) * float(raw_price), 2)
            actual = round(float(raw_amount), 2)
        except (ValueError, TypeError):
            continue
        if abs(expected - actual) > 0.02:
            report(
                f"line_items[{idx}].amount",
                f"Line item '{entry.get('description', '?')}': "
                f"amount {actual} ≠ quantity ({raw_qty}) × unit_price ({raw_price}) = {expected}",
            )

    # 2) Line items should sum to the stated subtotal (tolerance 0.05).
    if items:
        try:
            items_sum = round(
                sum(float(e.get("amount", 0) or 0) for e in items), 2
            )
            stated = specific.get("subtotal")
            if stated is not None:
                stated = round(float(stated), 2)
                if abs(items_sum - stated) > 0.05:
                    report(
                        "type_specific.subtotal",
                        f"Sum of line items ({items_sum}) ≠ subtotal ({stated}). "
                        f"Discrepancy: {abs(items_sum - stated):.2f}",
                    )
        except (ValueError, TypeError):
            pass

    # 3) subtotal + tax should equal the grand total (tolerance 0.05).
    subtotal = specific.get("subtotal")
    tax = specific.get("tax_amount")
    total = common.get("total_amount")
    if subtotal is not None and tax is not None and total is not None:
        try:
            expected_total = round(float(subtotal) + float(tax), 2)
            actual_total = round(float(total), 2)
        except (ValueError, TypeError):
            pass
        else:
            if abs(expected_total - actual_total) > 0.05:
                report(
                    "common.total_amount",
                    f"Total ({actual_total}) ≠ subtotal ({subtotal}) + tax ({tax}) = {expected_total}. "
                    f"Discrepancy: {abs(expected_total - actual_total):.2f}",
                )

    # 4) Bank statements: opening + signed transactions must reconcile
    #    to the closing balance (looser tolerance 0.10 — more terms).
    if common.get("document_type") == "bank_statement":
        opening = specific.get("opening_balance")
        closing = specific.get("closing_balance")
        if opening is not None and closing is not None and items:
            try:
                txn_sum = sum(float(e.get("amount", 0) or 0) for e in items)
                expected_closing = round(float(opening) + txn_sum, 2)
                actual_closing = round(float(closing), 2)
            except (ValueError, TypeError):
                pass
            else:
                if abs(expected_closing - actual_closing) > 0.10:
                    report(
                        "type_specific.closing_balance",
                        f"Closing balance ({actual_closing}) ≠ opening ({opening}) + "
                        f"transactions ({txn_sum:.2f}) = {expected_closing}",
                    )

    return found
|
|
|
|
def _check_required_fields(data: dict) -> List[dict]:
    """Check for missing critical fields based on document type.

    Universal fields (date, total, issuer) are medium severity;
    type-specific fields are low severity since the document type
    itself may have been misclassified.
    """
    found: List[dict] = []
    common = data.get("common", {})
    specific = data.get("type_specific", {})
    doc_type = common.get("document_type", "")

    def missing(field_path: str, severity: str, description: str) -> None:
        found.append({
            "category": "missing_field",
            "field": field_path,
            "severity": severity,
            "description": description,
            "source": "validator",
        })

    # Fields every document type must carry.
    for key in ("date", "total_amount", "issuer"):
        if common.get(key) is None:
            path = f"common.{key}"
            missing(path, "medium", f"Required field '{path}' is missing.")

    # Per-type expectations.
    required_by_type = {
        "invoice": ("invoice_number", "due_date", "subtotal"),
        "purchase_order": ("po_number", "delivery_date"),
        "receipt": ("receipt_number",),
        "bank_statement": ("account_number", "opening_balance", "closing_balance"),
    }
    for name in required_by_type.get(doc_type, ()):
        if specific.get(name) is None:
            missing(
                f"type_specific.{name}",
                "low",
                f"Expected field '{name}' for {doc_type} is missing.",
            )

    # A structured issuer entity without a name is effectively anonymous.
    issuer = common.get("issuer")
    if isinstance(issuer, dict) and not issuer.get("name"):
        missing(
            "common.issuer.name",
            "medium",
            "Issuer entity is present but name is missing.",
        )

    return found
|
|
|
|
# ISO calendar-date shape (YYYY-MM-DD); compiled once at import time
# instead of on every loop iteration.
_DATE_PATTERN = re.compile(r'^\d{4}-\d{2}-\d{2}$')


def _check_date_formats(data: dict) -> List[dict]:
    """Validate date fields are in proper format and reasonable range.

    Flags a date that is not a YYYY-MM-DD string, not a real calendar
    date, before the year 2000, or more than 2 years in the future.
    Missing or non-string values are skipped — other checks report them.

    Args:
        data: Parsed JSON dict from model output.

    Returns:
        List of format_anomaly flag dicts.
    """
    flags: List[dict] = []
    common = data.get("common", {})
    type_specific = data.get("type_specific", {})

    date_fields = {
        "common.date": common.get("date"),
        "type_specific.due_date": type_specific.get("due_date"),
        "type_specific.delivery_date": type_specific.get("delivery_date"),
    }

    # Upper bound for plausible dates: this day, two years out.
    # BUG FIX: now.replace(year=...) raises ValueError when run on
    # Feb 29 and the target year is not a leap year; previously that
    # ValueError was swallowed by the parse handler and a *valid* date
    # got flagged as "not a valid calendar date". Clamp to Feb 28.
    now = datetime.now()
    try:
        horizon = now.replace(year=now.year + 2)
    except ValueError:
        horizon = now.replace(year=now.year + 2, day=28)

    for field_path, date_str in date_fields.items():
        if date_str is None or not isinstance(date_str, str):
            continue

        # Shape check first, so the message can name the expected format.
        if not _DATE_PATTERN.match(date_str):
            flags.append({
                "category": "format_anomaly",
                "field": field_path,
                "severity": "medium",
                "description": f"Date '{date_str}' is not in standard YYYY-MM-DD format.",
                "source": "validator",
            })
            continue

        # Right shape but possibly an impossible date (e.g. month 13).
        try:
            parsed = datetime.strptime(date_str, "%Y-%m-%d")
        except ValueError:
            flags.append({
                "category": "format_anomaly",
                "field": field_path,
                "severity": "medium",
                "description": f"Date '{date_str}' is not a valid calendar date.",
                "source": "validator",
            })
            continue

        # Plausibility window for financial documents.
        if parsed.year < 2000:
            flags.append({
                "category": "format_anomaly",
                "field": field_path,
                "severity": "low",
                "description": f"Date '{date_str}' is before year 2000, which is unusual.",
                "source": "validator",
            })
        elif parsed > horizon:
            flags.append({
                "category": "format_anomaly",
                "field": field_path,
                "severity": "medium",
                "description": f"Date '{date_str}' is more than 2 years in the future.",
                "source": "validator",
            })

    return flags
|
|
|
|
def _check_cross_field(data: dict) -> List[dict]:
    """Check for inconsistencies between related fields.

    Flags: a due date earlier than the document date, a delivery date
    earlier than the PO date, and a negative grand total. Unparseable
    date values are skipped here; format problems are reported by the
    date-format check.

    Args:
        data: Parsed JSON dict from model output.

    Returns:
        List of cross_field flag dicts.
    """
    flags: List[dict] = []
    common = data.get("common", {})
    type_specific = data.get("type_specific", {})

    doc_date = common.get("date")

    def _parse(value):
        """Best-effort ISO date parse; None when absent or unparseable.

        BUG FIX: strptime raises TypeError (not ValueError) for
        non-string input. Model output sometimes puts a number where a
        date string belongs; previously that TypeError escaped the
        except ValueError handler and crashed the validator.
        """
        if not value:
            return None
        try:
            return datetime.strptime(value, "%Y-%m-%d")
        except (ValueError, TypeError):
            return None

    base = _parse(doc_date)

    # Due date must not precede the document date.
    due_date = type_specific.get("due_date")
    due = _parse(due_date)
    if base is not None and due is not None and due < base:
        flags.append({
            "category": "cross_field",
            "field": "type_specific.due_date",
            "severity": "high",
            "description": f"Due date ({due_date}) is before document date ({doc_date}).",
            "source": "validator",
        })

    # Delivery date must not precede the PO date.
    delivery_date = type_specific.get("delivery_date")
    delivery = _parse(delivery_date)
    if base is not None and delivery is not None and delivery < base:
        flags.append({
            "category": "cross_field",
            "field": "type_specific.delivery_date",
            "severity": "medium",
            "description": f"Delivery date ({delivery_date}) is before PO date ({doc_date}).",
            "source": "validator",
        })

    # Negative totals are almost never legitimate on these documents.
    total = common.get("total_amount")
    if total is not None:
        try:
            negative = float(total) < 0
        except (ValueError, TypeError):
            negative = False
        if negative:
            flags.append({
                "category": "cross_field",
                "field": "common.total_amount",
                "severity": "high",
                "description": f"Total amount is negative ({total}), which is unusual.",
                "source": "validator",
            })

    return flags
|
|
|
|
def _check_business_logic(data: dict) -> List[dict]:
    """Check for business logic red flags.

    Flags totals over $1M, suspiciously round totals (>= $10k and an
    exact multiple of $1000), and negative line-item quantities.
    """
    found: List[dict] = []
    common = data.get("common", {})
    entries = data.get("line_items", [])

    raw_total = common.get("total_amount")
    if raw_total is not None:
        try:
            amount = float(raw_total)
        except (ValueError, TypeError):
            amount = None
        if amount is not None:
            # Very large totals always warrant a human look.
            if amount > 1_000_000:
                found.append({
                    "category": "business_logic",
                    "field": "common.total_amount",
                    "severity": "high",
                    "description": f"Total amount ${amount:,.2f} exceeds $1M — requires review.",
                    "source": "validator",
                })
            # Perfectly round large totals are a classic fraud signal.
            if amount >= 10_000 and amount == int(amount) and amount % 1000 == 0:
                found.append({
                    "category": "business_logic",
                    "field": "common.total_amount",
                    "severity": "medium",
                    "description": (
                        f"Total amount ${amount:,.2f} is a perfectly round number — "
                        f"potential fraud indicator."
                    ),
                    "source": "validator",
                })

    # Negative quantities on line items.
    for idx, entry in enumerate(entries or []):
        raw_qty = entry.get("quantity")
        if raw_qty is None:
            continue
        try:
            is_negative = float(raw_qty) < 0
        except (ValueError, TypeError):
            continue
        if is_negative:
            found.append({
                "category": "format_anomaly",
                "field": f"line_items[{idx}].quantity",
                "severity": "medium",
                "description": (
                    f"Line item '{entry.get('description', '?')}' has negative "
                    f"quantity ({raw_qty})."
                ),
                "source": "validator",
            })

    return found
|
|
|
|
def merge_flags(model_flags: list, validator_flags: list) -> list:
    """
    Merge model-detected and validator-detected flags, removing duplicates.

    Deduplication is based on (category, field) pairs; the model's
    version wins when both sources report the same pair, because model
    flags are processed first.

    Args:
        model_flags: Flags from the model output.
        validator_flags: Flags from programmatic validation.

    Returns:
        Combined list of unique flags, model flags first.
    """
    seen: set = set()
    merged: list = []

    for flag in [*model_flags, *validator_flags]:
        key = (flag.get("category", ""), flag.get("field", ""))
        if key in seen:
            continue
        seen.add(key)
        merged.append(flag)

    return merged
|
|