"""
Post-Extraction Validation Engine.

Performs programmatic validation checks on extracted financial data that
complement the model's anomaly detection:

1. Arithmetic Consistency — quantity × unit price = amount, line items sum
   to subtotal, subtotal + tax = total, bank statement balances reconcile
2. Required Field Completeness — checks for missing critical fields per doc type
3. Date Format Validation — ensures dates are valid and reasonable
4. Cross-Field Checks — due/delivery date not before the document date,
   total amount not negative
5. Business Logic — very large or perfectly round totals, negative quantities

Usage:
    from src.validator import validate_extraction
    extra_flags = validate_extraction(json_data)
    # Returns list of additional anomaly flag dicts
"""

import re
from datetime import datetime
from typing import List, Optional

# Canonical date representation expected in extracted documents.
_DATE_FORMAT = "%Y-%m-%d"
# Used with fullmatch() so a trailing newline cannot slip through.
_DATE_PATTERN = re.compile(r"\d{4}-\d{2}-\d{2}")


def _flag(category: str, field: str, severity: str, description: str) -> dict:
    """Build one validator flag dict with the standard key layout."""
    return {
        "category": category,
        "field": field,
        "severity": severity,
        "description": description,
        "source": "validator",
    }


def _section(data: dict, key: str) -> dict:
    """Return ``data[key]`` if it is a dict, otherwise an empty dict.

    ``data.get(key, {})`` is not enough: a key present with a JSON ``null``
    yields ``None`` and every later ``.get`` call would raise.
    """
    value = data.get(key)
    return value if isinstance(value, dict) else {}


def _indexed_line_items(data: dict) -> list:
    """Return ``(original_index, item)`` pairs for the dict-shaped line items.

    Non-dict entries are skipped, but indices are kept so that flag field
    paths such as ``line_items[2].amount`` still point at the source position.
    """
    items = data.get("line_items")
    if not isinstance(items, (list, tuple)):
        return []
    return [(i, item) for i, item in enumerate(items) if isinstance(item, dict)]


def _to_floats(*values) -> Optional[List[float]]:
    """Convert every value to ``float``; return ``None`` if any one fails.

    ``None`` inputs fail the conversion, so callers get "all present and
    numeric" in a single test. ``OverflowError`` covers integers too large
    for a float.
    """
    converted = []
    for value in values:
        try:
            converted.append(float(value))
        except (TypeError, ValueError, OverflowError):
            return None
    return converted


def _sum_amounts(indexed_items: list) -> Optional[float]:
    """Sum the ``amount`` of each line item (missing/falsy counts as 0).

    Returns ``None`` when any amount is not numeric, in which case the
    calling check is skipped rather than run on partial data.
    """
    amounts = _to_floats(*(item.get("amount") or 0 for _, item in indexed_items))
    return None if amounts is None else sum(amounts)


def _parse_date(value) -> Optional[datetime]:
    """Parse a ``YYYY-MM-DD`` string; return ``None`` for anything else."""
    if not isinstance(value, str):
        return None
    try:
        return datetime.strptime(value, _DATE_FORMAT)
    except ValueError:
        return None


def _add_years(moment: datetime, years: int) -> datetime:
    """Shift ``moment`` by whole years, mapping Feb 29 to Feb 28 if needed."""
    try:
        return moment.replace(year=moment.year + years)
    except ValueError:
        # Only reachable for Feb 29 when the target year is not a leap year.
        return moment.replace(year=moment.year + years, day=28)


def validate_extraction(data: dict, *, now: Optional[datetime] = None) -> List[dict]:
    """
    Run all validation checks on extracted document data.

    Args:
        data: Parsed JSON dict from model output.
        now: Reference time for the "date too far in the future" check.
            Defaults to the current local time; pass a fixed value for
            reproducible results. Any tzinfo is ignored.

    Returns:
        List of additional anomaly flags not detected by the model.
    """
    flags = []
    flags.extend(_check_arithmetic(data))
    flags.extend(_check_required_fields(data))
    flags.extend(_check_date_formats(data, now))
    flags.extend(_check_cross_field(data))
    flags.extend(_check_business_logic(data))
    return flags


def _check_arithmetic(data: dict) -> List[dict]:
    """Verify that math adds up in the document.

    Each check is best-effort: if an operand is missing or not numeric the
    check is silently skipped (missing fields are reported elsewhere).
    """
    flags = []
    common = _section(data, "common")
    type_specific = _section(data, "type_specific")
    items = _indexed_line_items(data)

    # Check 1: Line item amounts = quantity × unit_price (2 cent tolerance)
    for i, item in items:
        qty = item.get("quantity")
        price = item.get("unit_price")
        numbers = _to_floats(qty, price, item.get("amount"))
        if numbers is None:
            continue
        expected = round(numbers[0] * numbers[1], 2)
        actual = round(numbers[2], 2)
        if abs(expected - actual) > 0.02:
            flags.append(_flag(
                "arithmetic_error",
                f"line_items[{i}].amount",
                "high",
                f"Line item '{item.get('description', '?')}': "
                f"amount {actual} ≠ quantity ({qty}) × unit_price ({price}) = {expected}",
            ))

    # Computed once; shared by the subtotal and bank statement checks.
    raw_sum = _sum_amounts(items) if items else None

    # Check 2: Line items sum to subtotal (5 cent tolerance)
    subtotal = type_specific.get("subtotal")
    subtotal_num = _to_floats(subtotal)
    if raw_sum is not None and subtotal_num is not None:
        items_sum = round(raw_sum, 2)
        rounded_subtotal = round(subtotal_num[0], 2)
        if abs(items_sum - rounded_subtotal) > 0.05:
            flags.append(_flag(
                "arithmetic_error",
                "type_specific.subtotal",
                "high",
                f"Sum of line items ({items_sum}) ≠ subtotal ({rounded_subtotal}). "
                f"Discrepancy: {abs(items_sum - rounded_subtotal):.2f}",
            ))

    # Check 3: Subtotal + tax = total (5 cent tolerance)
    tax = type_specific.get("tax_amount")
    numbers = _to_floats(subtotal, tax, common.get("total_amount"))
    if numbers is not None:
        expected_total = round(numbers[0] + numbers[1], 2)
        actual_total = round(numbers[2], 2)
        if abs(expected_total - actual_total) > 0.05:
            flags.append(_flag(
                "arithmetic_error",
                "common.total_amount",
                "high",
                f"Total ({actual_total}) ≠ subtotal ({subtotal}) + tax ({tax}) = {expected_total}. "
                f"Discrepancy: {abs(expected_total - actual_total):.2f}",
            ))

    # Check 4: Bank statement — opening + transactions = closing (10 cent tolerance)
    if common.get("document_type") == "bank_statement" and raw_sum is not None:
        opening = type_specific.get("opening_balance")
        balances = _to_floats(opening, type_specific.get("closing_balance"))
        if balances is not None:
            expected_closing = round(balances[0] + raw_sum, 2)
            actual_closing = round(balances[1], 2)
            if abs(expected_closing - actual_closing) > 0.10:
                flags.append(_flag(
                    "arithmetic_error",
                    "type_specific.closing_balance",
                    "high",
                    f"Closing balance ({actual_closing}) ≠ opening ({opening}) + "
                    f"transactions ({raw_sum:.2f}) = {expected_closing}",
                ))

    return flags


def _check_required_fields(data: dict) -> List[dict]:
    """Check for missing critical fields based on document type.

    A field counts as missing only when its value is ``None`` (absent or
    JSON ``null``); empty strings are not flagged here.
    """
    flags = []
    common = _section(data, "common")
    type_specific = _section(data, "type_specific")
    doc_type = common.get("document_type", "")

    # Universal required fields
    for name in ("date", "total_amount", "issuer"):
        if common.get(name) is None:
            field_path = f"common.{name}"
            flags.append(_flag(
                "missing_field",
                field_path,
                "medium",
                f"Required field '{field_path}' is missing.",
            ))

    # Type-specific required fields
    required_by_type = {
        "invoice": ["invoice_number", "due_date", "subtotal"],
        "purchase_order": ["po_number", "delivery_date"],
        "receipt": ["receipt_number"],
        "bank_statement": ["account_number", "opening_balance", "closing_balance"],
    }
    # Unknown or unhashable document types simply have no extra requirements.
    required = required_by_type.get(doc_type, []) if isinstance(doc_type, str) else []
    for field_name in required:
        if type_specific.get(field_name) is None:
            flags.append(_flag(
                "missing_field",
                f"type_specific.{field_name}",
                "low",
                f"Expected field '{field_name}' for {doc_type} is missing.",
            ))

    # Check issuer has at least a name
    issuer = common.get("issuer")
    if isinstance(issuer, dict) and not issuer.get("name"):
        flags.append(_flag(
            "missing_field",
            "common.issuer.name",
            "medium",
            "Issuer entity is present but name is missing.",
        ))

    return flags


def _check_date_formats(data: dict, now: Optional[datetime] = None) -> List[dict]:
    """Validate date fields are in proper format and reasonable range.

    Args:
        data: Parsed JSON dict from model output.
        now: Reference time for the future-date check; defaults to
            ``datetime.now()``. Any tzinfo is dropped so the comparison with
            the naive parsed dates cannot raise.

    Returns:
        Flags for dates that are not ``YYYY-MM-DD``, are not real calendar
        dates, fall before year 2000, or lie more than 2 years after ``now``.
        Non-string values (including ``None``) are skipped.
    """
    flags = []
    common = _section(data, "common")
    type_specific = _section(data, "type_specific")

    reference = (datetime.now() if now is None else now).replace(tzinfo=None)
    # Computed once, outside the try/except of the per-field parsing: on
    # Feb 29 a plain replace(year=...) raises ValueError, which must not be
    # mistaken for an invalid document date.
    latest_plausible = _add_years(reference, 2)

    date_fields = {
        "common.date": common.get("date"),
        "type_specific.due_date": type_specific.get("due_date"),
        "type_specific.delivery_date": type_specific.get("delivery_date"),
    }

    for field_path, date_str in date_fields.items():
        if not isinstance(date_str, str):
            continue

        # Check YYYY-MM-DD format
        if not _DATE_PATTERN.fullmatch(date_str):
            flags.append(_flag(
                "format_anomaly",
                field_path,
                "medium",
                f"Date '{date_str}' is not in standard YYYY-MM-DD format.",
            ))
            continue

        # Check if date is actually valid, then whether it is in a sane range
        parsed = _parse_date(date_str)
        if parsed is None:
            flags.append(_flag(
                "format_anomaly",
                field_path,
                "medium",
                f"Date '{date_str}' is not a valid calendar date.",
            ))
        elif parsed.year < 2000:
            flags.append(_flag(
                "format_anomaly",
                field_path,
                "low",
                f"Date '{date_str}' is before year 2000, which is unusual.",
            ))
        elif parsed > latest_plausible:
            flags.append(_flag(
                "format_anomaly",
                field_path,
                "medium",
                f"Date '{date_str}' is more than 2 years in the future.",
            ))

    return flags


def _check_cross_field(data: dict) -> List[dict]:
    """Check for inconsistencies between related fields.

    Date comparisons are skipped when either side is missing, not a string,
    or not parseable; those problems are reported by the date format check.
    """
    flags = []
    common = _section(data, "common")
    type_specific = _section(data, "type_specific")

    doc_date = common.get("date")
    doc_dt = _parse_date(doc_date)

    # Check: due_date should not be before the document date
    due_date = type_specific.get("due_date")
    due_dt = _parse_date(due_date)
    if doc_dt is not None and due_dt is not None and due_dt < doc_dt:
        flags.append(_flag(
            "cross_field",
            "type_specific.due_date",
            "high",
            f"Due date ({due_date}) is before document date ({doc_date}).",
        ))

    # Check: delivery date should not be before the PO date
    delivery_date = type_specific.get("delivery_date")
    delivery_dt = _parse_date(delivery_date)
    if doc_dt is not None and delivery_dt is not None and delivery_dt < doc_dt:
        flags.append(_flag(
            "cross_field",
            "type_specific.delivery_date",
            "medium",
            f"Delivery date ({delivery_date}) is before PO date ({doc_date}).",
        ))

    # Check: negative total amount
    total = common.get("total_amount")
    total_num = _to_floats(total)
    if total_num is not None and total_num[0] < 0:
        flags.append(_flag(
            "cross_field",
            "common.total_amount",
            "high",
            f"Total amount is negative ({total}), which is unusual.",
        ))

    return flags


def _check_business_logic(data: dict) -> List[dict]:
    """Check for business logic red flags.

    Flags totals above $1M, totals of at least $10,000 that are exact
    multiples of $1,000, and line items with a negative quantity.
    """
    flags = []
    common = _section(data, "common")

    total_num = _to_floats(common.get("total_amount"))
    if total_num is not None:
        total_val = total_num[0]

        # Extremely large amounts
        if total_val > 1_000_000:
            flags.append(_flag(
                "business_logic",
                "common.total_amount",
                "high",
                f"Total amount ${total_val:,.2f} exceeds $1M — requires review.",
            ))

        # Perfectly round large amounts (potential fraud indicator).
        # A multiple of 1000 is necessarily a whole number, so no int()
        # round-trip is needed (int() would raise on an infinite total).
        if total_val >= 10_000 and total_val % 1000 == 0:
            flags.append(_flag(
                "business_logic",
                "common.total_amount",
                "medium",
                f"Total amount ${total_val:,.2f} is a perfectly round number — "
                f"potential fraud indicator.",
            ))

    # Check for negative quantities in line items
    for i, item in _indexed_line_items(data):
        qty = item.get("quantity")
        qty_num = _to_floats(qty)
        if qty_num is not None and qty_num[0] < 0:
            flags.append(_flag(
                "format_anomaly",
                f"line_items[{i}].quantity",
                "medium",
                f"Line item '{item.get('description', '?')}' has negative "
                f"quantity ({qty}).",
            ))

    return flags


def merge_flags(model_flags: list, validator_flags: list) -> list:
    """
    Merge model-detected and validator-detected flags, removing duplicates.

    Deduplication is based on (category, field) pairs; the first flag seen
    for a pair wins, and model flags are visited first so they take priority.

    Args:
        model_flags: Flags from the model output (``None`` is treated as empty).
        validator_flags: Flags from programmatic validation (``None`` is
            treated as empty).

    Returns:
        Combined list of unique flags, model flags first.
    """
    seen = set()
    merged = []

    for flag in (*(model_flags or []), *(validator_flags or [])):
        key = (flag.get("category", ""), flag.get("field", ""))
        if key not in seen:
            seen.add(key)
            merged.append(flag)

    return merged