# Corp_AI / src/auditenv/datasets/schema_mapper.py
# Arpit Deep
# feat: initial AuditEnv submission
# a617acd
from __future__ import annotations

import numbers
from typing import Any, Dict, List
def map_procurement_to_documents(
    tables: Dict[str, Any],
    task_prefix: str,
    max_docs: int,
    column_mapping: Dict[str, str] | None = None,
) -> List[Dict[str, str]]:
    """Render the first ``max_docs`` invoice rows as plain-text documents.

    Each invoice field is looked up first under the column named in
    ``column_mapping`` and otherwise under a list of common spellings; both
    lookups go through ``_resolve_column`` (case-insensitive). Fields whose
    column cannot be resolved, and cells holding a scalar null (``None``,
    ``NaN``, ``pd.NA``, ``NaT``), render as an empty value rather than as
    ``"nan"``/``"None"``.

    Args:
        tables: Mapping of table name to table. Only ``tables["invoices"]`` is
            read; it is used like a pandas DataFrame (``.columns``,
            ``.head()``, ``.iterrows()``).
        task_prefix: Prefix for the generated document ids.
        max_docs: Row limit passed straight to ``invoices.head``.
        column_mapping: Optional column-name overrides keyed by
            ``"invoice_id"``, ``"amount"``, ``"vendor_id"``,
            ``"department_id"`` and ``"invoice_date"``.

    Returns:
        One ``{"id": "<task_prefix>-DOC-NNNN", "type": "invoice", "text": ...}``
        dict per rendered row. ``NNNN`` is the row's integer index label, or
        its position among the rendered rows when the label is not an integer.
        A missing or null invoice id is rendered as ``SRC-NNNNN``.

    Raises:
        KeyError: If ``tables`` has no ``"invoices"`` entry.
    """
    invoices = tables["invoices"]
    column_mapping = column_mapping or {}
    # Logical field -> fallback column spellings, tried in order when the
    # mapping has no usable override for that field.
    fallbacks = {
        "invoice_id": ["invoice_id", "id", "InvoiceID", "invoice_number"],
        "amount": ["amount", "invoice_amount", "Amount"],
        "vendor_id": ["supplier_id", "vendor_id", "SupplierID"],
        "department_id": ["department_id", "DepartmentID"],
        "invoice_date": ["invoice_date", "date", "InvoiceDate"],
    }
    cols = {
        field: _resolve_column(invoices.columns, column_mapping.get(field), names)
        for field, names in fallbacks.items()
    }

    docs: List[Dict[str, str]] = []
    for position, (idx, row) in enumerate(invoices.head(max_docs).iterrows()):
        # Prefer the index label so ids stay stable when the caller passes a
        # pre-filtered frame; non-integer labels (e.g. after ``set_index``)
        # would make the ``:d`` format specs below raise ValueError, so the
        # row position is used for those instead.
        label = idx if isinstance(idx, numbers.Integral) else position
        # The SRC- placeholder covers both an unresolved id column and a
        # null id cell.
        source_id = _cell_text(row, cols["invoice_id"], f"SRC-{label:05d}")
        amount = _cell_text(row, cols["amount"])
        vendor = _cell_text(row, cols["vendor_id"])
        dept = _cell_text(row, cols["department_id"])
        date = _cell_text(row, cols["invoice_date"])
        # Read by exact column name (no mapping, no case folding); empty when
        # the column is absent.
        dup_flag = _cell_text(row, "is_duplicate_invoice_id")
        dup_group_size = _cell_text(row, "duplicate_invoice_group_size")
        text = (
            f"invoice_id={source_id}; amount={amount}; vendor={vendor}; department={dept}; "
            f"invoice_date={date}; is_duplicate_invoice_id={dup_flag}; "
            f"duplicate_invoice_group_size={dup_group_size}"
        )
        docs.append({"id": f"{task_prefix}-DOC-{label:04d}", "type": "invoice", "text": text})
    return docs


def _cell_text(row: Any, column: str | None, default: str = "") -> str:
    """Return ``row[column]`` as a string, or ``default`` if it is unavailable.

    ``default`` is used when ``column`` is falsy, when ``row`` has no entry for
    it, or when the cell holds a scalar null recognised by ``pandas.isna``.
    """
    if not column:
        return default
    import pandas as pd  # local import: this module has no top-level pandas dependency

    value = row.get(column)
    if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
        return default
    return str(value)
def map_hf_fraud_rows_to_signals(
    rows: list[dict[str, Any]],
    max_rows: int,
    *,
    keys: tuple[str, ...] = ("Company", "Label", "Fillings", "Filing", "text"),
    max_chars: int = 120,
) -> list[str]:
    """Summarise fraud-dataset rows as ``"key=value | key=value"`` signal strings.

    Args:
        rows: Row dicts to summarise; only ``rows[:max_rows]`` is read
            (standard slice semantics, so a negative value drops rows from
            the end).
        max_rows: Maximum number of rows to consider.
        keys: Row keys to include, in output order. Keys that are absent or
            whose value is ``None`` are skipped.
        max_chars: Each value's string form is truncated to this many
            characters.

    Returns:
        One signal string per row that contributed at least one key; rows
        with no usable keys produce nothing.
    """
    signals: list[str] = []
    for row in rows[:max_rows]:
        pieces = [
            f"{key}={str(row[key])[:max_chars]}"
            for key in keys
            if key in row and row[key] is not None
        ]
        if pieces:
            signals.append(" | ".join(pieces))
    return signals
def _first_existing_column(columns: Any, candidates: list[str]) -> str | None:
lower_map = {str(c).lower(): str(c) for c in columns}
for cand in candidates:
if cand.lower() in lower_map:
return lower_map[cand.lower()]
return None
def _resolve_column(columns: Any, preferred: str | None, fallbacks: list[str]) -> str | None:
if preferred:
lower_map = {str(c).lower(): str(c) for c in columns}
hit = lower_map.get(preferred.lower())
if hit:
return hit
return _first_existing_column(columns, fallbacks)