aml-intelligence-app / scripts /train_all.py
soupstick's picture
rebuild: 5-agent fraud intelligence suite with trained models + FastAPI
cc1ad5a
from __future__ import annotations
import importlib
import json
import sys
from pathlib import Path
import pandas as pd
# Repository root: the parent of the directory containing this script.
# It is prepended to sys.path (once) so the project-level ``agents`` / ``data``
# imports resolve even when this file is run directly as a script.
ROOT_DIR = Path(__file__).resolve().parents[1]
_ROOT_ON_PATH = str(ROOT_DIR)
if _ROOT_ON_PATH not in sys.path:
    sys.path.insert(0, _ROOT_ON_PATH)
from agents.common import APP_VERSION, DATA_DIR, save_metadata, utc_now_iso
from agents.credit_risk.model import train_model as train_credit_model
from agents.kyc_identity.model import train_model as train_kyc_model
from agents.transaction_fraud.model import train_model as train_transaction_model
from data.generate_all import main as generate_all_datasets
def _typo(name: str) -> str:
parts = name.split()
if not parts:
return name
target = max(parts, key=len)
if len(target) < 5:
return name
chars = list(target)
chars[1], chars[2] = chars[2], chars[1]
mutated = "".join(chars)
return name.replace(target, mutated, 1)
def _validate_sanctions() -> dict[str, object]:
    """Smoke-test the sanctions/PEP name matcher against the generated list.

    Reads ``DATA_DIR / "sanctions_pep_list.csv"``, drops missing names, and
    screens three groups of queries through ``matcher.screen``:

      * the first 5 listed names verbatim: must match with score 1.0;
      * the next 5 listed names passed through ``_typo``: must match with
        score >= 0.7;
      * 5 hard-coded names presumed absent from the list: must either report
        no match or score below 0.6.

    Returns:
        Metadata entry with ``version``, ``artifact`` (list path relative to
        ``ROOT_DIR``) and ``metrics.hit_rate`` rounded to 4 decimals.

    Raises:
        RuntimeError: if the fraction of passing queries is below 0.95. The
            message includes the failing queries.
    """
    # Re-execute the matcher module so it sees the freshly regenerated data.
    # NOTE(review): presumably it loads the list at import time -- confirm.
    matcher = importlib.reload(importlib.import_module("agents.sanctions_pep.matcher"))

    list_path = DATA_DIR / "sanctions_pep_list.csv"
    # Drop missing values and coerce to str so _typo() never receives NaN.
    listed = pd.read_csv(list_path)["full_name"].dropna().astype(str).tolist()
    misses = ["Alice Walker", "Daniel Mercer", "Nina Holloway", "Peter Whitmore", "Alicia Stone"]

    def is_exact(result: dict) -> bool:
        # Compare with a tolerance instead of ``== 1.0`` so float rounding in
        # the scorer cannot fail a verbatim match.
        return bool(result["match_found"]) and abs(result["best_match"]["match_score"] - 1.0) <= 1e-9

    def is_fuzzy(result: dict) -> bool:
        return bool(result["match_found"]) and result["best_match"]["match_score"] >= 0.7

    def is_miss(result: dict) -> bool:
        # best_match is only inspected when a match was reported.
        return (not result["match_found"]) or result["best_match"]["match_score"] < 0.6

    cases = (
        [(name, is_exact) for name in listed[:5]]
        + [(_typo(name), is_fuzzy) for name in listed[5:10]]
        + [(name, is_miss) for name in misses]
    )
    failures = [
        name
        for name, check in cases
        if not check(matcher.screen({"query_name": name}))
    ]
    # `cases` always has at least the 5 miss queries, so this never divides by 0.
    hit_rate = (len(cases) - len(failures)) / len(cases)
    # With all 15 queries present, one failure gives 14/15 ~= 0.933 < 0.95,
    # so in practice every query must pass.
    if hit_rate < 0.95:
        raise RuntimeError(
            f"sanctions_pep hit_rate below threshold: {hit_rate:.4f} "
            f"(failed queries: {failures})"
        )
    return {
        "version": APP_VERSION,
        "artifact": str(list_path.relative_to(ROOT_DIR)),
        "metrics": {"hit_rate": round(hit_rate, 4)},
    }
def main() -> None:
    """Regenerate data, train/validate every agent, then persist the metadata.

    Order of operations:
      1. ``generate_all_datasets()`` runs first. Presumably it (re)writes the
         CSVs under ``DATA_DIR`` that the later steps read.
      2. The transaction-fraud, credit-risk and KYC models are trained on
         their ``*_train.csv`` / ``*_test.csv`` pairs. Each trainer's return
         value becomes that agent's metadata entry.
      3. The sanctions/PEP matcher is checked via ``_validate_sanctions()``.
      4. A fixed entry describes the risk-consultant agent (nothing trained).

    The combined metadata holds ``APP_VERSION``, the ``utc_now_iso()``
    timestamp and the per-agent entries. It is stored with ``save_metadata``
    and also printed to stdout as indented JSON.
    """
    generate_all_datasets()

    # Filled in below; aliased into `metadata` so both names see every entry.
    models: dict[str, object] = {}
    metadata = {
        "version": APP_VERSION,
        "training_date": utc_now_iso(),
        "models": models,
    }

    training_jobs = (
        ("transaction_fraud", train_transaction_model,
         "transaction_fraud_train.csv", "transaction_fraud_test.csv"),
        ("credit_risk", train_credit_model,
         "credit_risk_train.csv", "credit_risk_test.csv"),
        ("kyc_identity", train_kyc_model,
         "kyc_identity_train.csv", "kyc_identity_test.csv"),
    )
    for agent_key, trainer, train_csv, test_csv in training_jobs:
        models[agent_key] = trainer(DATA_DIR / train_csv, DATA_DIR / test_csv)

    models["sanctions_pep"] = _validate_sanctions()
    models["risk_consultant"] = {
        "version": APP_VERSION,
        "artifact": "env:LLM_API_KEY or static FAQ fallback",
        "metrics": {"source": "llm_or_static_faq"},
    }

    save_metadata(metadata)
    print(json.dumps(metadata, indent=2))
if __name__ == "__main__":
    # Run the full generate -> train -> validate -> save pipeline only when
    # this file is executed as a script, not when it is imported.
    main()