# AML_Shield / app.py
# Author: AJAY KASU
# Feature: Add column mapping and automated model evaluation for is_fraud datasets
# Commit: 132848a
import os

from dotenv import load_dotenv

# Load environment variables from a local .env file (local testing only);
# in hosted deployments the secrets are expected to come from the platform.
if os.path.exists(".env"):
    load_dotenv()
import streamlit as st
import pandas as pd
import plotly.express as px
import time
from datetime import datetime
# Initialize database connections properly
# Assuming database uses Supabase, if Supabase client init fails due to missing secrets, app will handle gracefully
import modules.database as db
from modules.etl import load_and_validate, engineer_features
from modules.detection import apply_detection
from modules.risk_profiling import build_customer_profiles, assign_kyc_tier
import modules.visualizations as viz
from modules.ai_agent import stream_compliance_report
from modules.pdf_report import build_pdf
# Configure the Streamlit page (title, icon, wide layout, expanded sidebar).
st.set_page_config(page_title="AML Shield", page_icon="πŸ›‘οΈ", layout="wide", initial_sidebar_state="expanded")
# --- CSS Overrides ---
# Dark card background for st.metric widgets and light colors for headings.
st.markdown("""
<style>
.stMetric { background-color: #1e1e2d; padding: 15px; border-radius: 5px; }
h1, h2, h3 { color: #f8f9fa; }
</style>
""", unsafe_allow_html=True)
# --- Sidebar ---
# `page` holds the selected tab name and drives the routing below.
st.sidebar.title("πŸ›‘οΈ AML Shield Navigation")
tabs = ["Upload & Analyze", "Dashboard", "Customer Profiles", "AI Report", "Global Analytics", "About"]
page = st.sidebar.radio("Go to", tabs)
# --- Helper logic for analysis ---
def run_pipeline(file_obj, filename="Uploaded Data"):
    """Run the end-to-end AML analysis pipeline and render the results.

    Steps: load/validate -> feature engineering -> anomaly detection ->
    customer profiling. Results are persisted via ``modules.database`` and
    cached in ``st.session_state`` for the other pages.

    Args:
        file_obj: Uploaded file object or a filesystem path to a CSV.
        filename: Display name recorded with the upload.

    Returns:
        True if the pipeline completed, False if validation failed.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()

    # Step 1: load & validate. load_and_validate returns (None, error_msg) on failure.
    status_text.text("[β–“β–‘β–‘β–‘β–‘] Loading & validating data...")
    df, msg = load_and_validate(file_obj)
    if df is None:
        st.error(msg)
        progress_bar.empty()
        status_text.empty()
        return False
    progress_bar.progress(25)
    time.sleep(0.5)  # brief pause so the progress UI is visible

    # Step 2: feature engineering.
    status_text.text("[β–“β–“β–“β–‘β–‘] Engineering features...")
    df = engineer_features(df)
    progress_bar.progress(50)
    time.sleep(0.5)

    # Step 3: rule-based + ML anomaly detection.
    status_text.text("[β–“β–“β–“β–“β–‘] Running anomaly detection...")
    df = apply_detection(df)
    progress_bar.progress(75)
    time.sleep(0.5)

    # Step 4: per-customer profiling and KYC tier assignment.
    status_text.text("[β–“β–“β–“β–“β–“] Building customer profiles...")
    profile_df = build_customer_profiles(df)
    profile_df = assign_kyc_tier(profile_df)
    progress_bar.progress(100)
    time.sleep(0.5)
    status_text.empty()
    progress_bar.empty()

    # --- Summary metrics ---
    total_tx = len(df)
    flagged = df['is_flagged'].sum()
    high_risk = len(df[df['risk_level'] == 'High'])
    med_risk = len(df[df['risk_level'] == 'Medium'])
    avg_score = df['risk_score'].mean()
    date_range = f"{df['timestamp'].dt.date.min()} to {df['timestamp'].dt.date.max()}"

    # Structuring & international stats used by the AI report.
    struct_attempts = profile_df['structuring_attempts'].sum()
    intl_high = len(df[(df['is_international'] == 1) & (df['amount'] > 25000)])
    kyc_counts = profile_df['kyc_tier'].value_counts().to_dict()

    # Most frequently triggered rules (rule_flags holds a list per row).
    rules_flat = [rule for sublist in df['rule_flags'] if isinstance(sublist, list) for rule in sublist]
    top_rules = pd.Series(rules_flat).value_counts().head(3).to_dict() if rules_flat else {}
    top_customers = profile_df.sort_values('avg_risk_score', ascending=False)['customer_id'].head(3).tolist()

    # Persist the upload header; detail rows only if the header insert succeeded.
    upload_id = db.save_upload(
        filename=filename, total=total_tx, flagged=flagged,
        high_risk=high_risk, medium_risk=med_risk,
        avg_score=avg_score, date_range=date_range
    )
    if upload_id:
        # Batch insert chunks
        db.save_transactions(df, upload_id)
        db.save_customer_profiles(profile_df, upload_id)

    # Plain-Python types so the dict is JSON-serializable downstream.
    summary_data = {
        "filename": filename,
        "total_transactions": int(total_tx),
        "flagged_count": int(flagged),
        "high_risk_count": int(high_risk),
        "medium_risk_count": int(med_risk),
        "avg_risk_score": float(avg_score),
        "date_range": date_range,
        "structuring_attempts": int(struct_attempts),
        "international_high_value_count": int(intl_high),
        "kyc_tier_breakdown": kyc_counts,
        "top_rules_triggered": top_rules,
        "top_flagged_customers": top_customers
    }

    # Cache everything for the other pages; reset any previous AI report.
    st.session_state.df_raw = df.copy()
    st.session_state.df_scored = df.copy()
    st.session_state.profile_df = profile_df.copy()
    st.session_state.upload_id = upload_id
    st.session_state.summary_data = summary_data
    st.session_state.ai_report = None

    st.success(f"βœ… {total_tx} transactions analyzed | ⚠️ {flagged} flagged | πŸ”΄ {high_risk} high risk | πŸ“Š Avg risk score: {avg_score:.1f}")

    # --- Model Evaluation (only when ground-truth labels are supplied) ---
    if "is_fraud" in df.columns:
        st.markdown("---")
        st.subheader("🎯 Model Performance Evaluation")
        # Normalize assorted truthy/falsy label spellings to {0, 1}; unknown -> 0.
        df["is_fraud_numeric"] = df["is_fraud"].astype(str).str.upper().map(
            {"TRUE": 1, "FALSE": 0, "1": 1, "0": 0, "1.0": 1, "0.0": 0}
        ).fillna(0).astype(int)
        # Local import keeps sklearn out of the startup path; unused seaborn/
        # matplotlib/confusion_matrix imports removed.
        from sklearn.metrics import classification_report

        # 'is_flagged' (from modules.detection) is the model's binary prediction.
        y_true = df["is_fraud_numeric"]
        y_pred = df["is_flagged"]
        report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
        c1, c2, c3 = st.columns(3)
        c1.metric("Accuracy", f"{report['accuracy']:.2%}")
        # The '1' class is absent from the report when neither labels nor
        # predictions contain fraud -- guard against a KeyError.
        fraud_metrics = report.get('1', {'precision': 0.0, 'recall': 0.0})
        c2.metric("Precision (Fraud)", f"{fraud_metrics['precision']:.2%}")
        c3.metric("Recall (Fraud)", f"{fraud_metrics['recall']:.2%}")
        with st.expander("Detailed Classification Report"):
            st.code(classification_report(y_true, y_pred, zero_division=0))

    st.markdown("---")
    st.subheader("Top 5 Highest Risk Transactions Preview")
    preview = df.sort_values('risk_score', ascending=False).head(5)
    cols = ['transaction_id', 'customer_id', 'amount', 'transaction_type', 'risk_score', 'risk_level', 'rule_flags']
    st.dataframe(preview[cols])
    return True
# --- PAGE ROUTING ---
if page == "Upload & Analyze":
    # Landing page: upload a CSV or analyze the bundled sample dataset.
    st.title("πŸ›‘οΈ AML Shield")
    st.write("AI-Powered Anti-Money Laundering Transaction Intelligence Platform")
    upload_col, sample_col = st.columns(2)
    with upload_col:
        uploaded_file = st.file_uploader("Upload CSV Transactions", type=['csv'])
        # Only run the pipeline once a file is chosen AND the button is pressed.
        if uploaded_file is not None and st.button("Analyze Uploaded File"):
            run_pipeline(uploaded_file, filename=uploaded_file.name)
    with sample_col:
        st.write("Or test with pre-generated synthetic data:")
        if st.button("Use Sample Dataset"):
            sample_path = "sample_data/sample_transactions.csv"
            if os.path.exists(sample_path):
                run_pipeline(sample_path, filename="sample_transactions.csv")
            else:
                st.error("Sample dataset not found. Please ensure it was generated.")
elif page == "Dashboard":
if 'df_scored' not in st.session_state:
st.warning("Please upload or load sample data first in the 'Upload & Analyze' tab.")
else:
df = st.session_state.df_scored.copy()
summ = st.session_state.summary_data
# Dashboard Filters in Sidebar
st.sidebar.markdown("---")
st.sidebar.subheader("Dashboard Filters")
risk_filter = st.sidebar.multiselect("Risk Level", options=['High', 'Medium', 'Low'], default=['High', 'Medium', 'Low'])
type_filter = st.sidebar.multiselect("Transaction Type", options=df['transaction_type'].unique(), default=df['transaction_type'].unique())
min_date = df['timestamp'].min().date()
max_date = df['timestamp'].max().date()
date_filter = st.sidebar.slider("Date Range", min_value=min_date, max_value=max_date, value=(min_date, max_date))
# Apply filters
df_filtered = df[
(df['risk_level'].isin(risk_filter)) &
(df['transaction_type'].isin(type_filter)) &
(df['timestamp'].dt.date >= date_filter[0]) &
(df['timestamp'].dt.date <= date_filter[1])
]
# KPIs
c1, c2, c3, c4 = st.columns(4)
c1.metric("Total Transactions", summ['total_transactions'])
flagged_pct = (summ['flagged_count'] / summ['total_transactions']) * 100 if summ['total_transactions'] > 0 else 0
c2.metric("Flagged", summ['flagged_count'], delta=f"{flagged_pct:.1f}%")
c3.metric("High Risk", summ['high_risk_count'])
c4.metric("Avg Risk Score", f"{summ['avg_risk_score']:.1f}")
# Charts Row 1
r1c1, r1c2 = st.columns(2)
with r1c1:
st.subheader("Risk Distribution")
st.plotly_chart(viz.risk_distribution_chart(df_filtered), use_container_width=True)
with r1c2:
st.subheader("Daily Flagged Transactions")
st.plotly_chart(viz.flagged_transactions_timeline(df_filtered), use_container_width=True)
# Charts Row 2
st.subheader("Amount vs Risk Score Scatter")
st.plotly_chart(viz.amount_vs_risk_scatter(df_filtered), use_container_width=True)
# Charts Row 3
r3c1, r3c2 = st.columns(2)
with r3c1:
st.subheader("Transaction Types (Flagged vs Clean)")
st.plotly_chart(viz.transaction_type_breakdown(df_filtered), use_container_width=True)
with r3c2:
st.subheader("Rule Trigger Frequency")
st.plotly_chart(viz.rule_trigger_frequency(df_filtered), use_container_width=True)
# Charts Row 4
st.subheader("Top Flagged Customers")
st.plotly_chart(viz.top_flagged_customers_chart(df_filtered), use_container_width=True)
# Table
st.subheader("Flagged Transactions Explorer")
flagged_df = df_filtered[df_filtered['is_flagged'] == 1].copy()
# Convert rule_flags list to string for display/CSV
flagged_df['rule_flags_str'] = flagged_df['rule_flags'].apply(lambda x: ", ".join(x) if isinstance(x, list) else str(x))
disp_cols = ['transaction_id', 'customer_id', 'amount', 'transaction_type', 'risk_score', 'risk_level', 'rule_flags_str']
st.dataframe(flagged_df[disp_cols])
csv_data = flagged_df[disp_cols].to_csv(index=False).encode('utf-8')
st.download_button("Download Flagged Transactions CSV", data=csv_data, file_name="flagged_transactions.csv", mime="text/csv")
elif page == "Customer Profiles":
if 'profile_df' not in st.session_state:
st.warning("Please upload data first to analyze customer profiles.")
else:
profile_df = st.session_state.profile_df.copy()
df = st.session_state.df_scored
st.title("Customer KYC Profiles")
col1, col2 = st.columns([1, 2])
with col1:
st.subheader("KYC Tier Distribution")
st.plotly_chart(viz.kyc_tier_distribution(profile_df), use_container_width=True)
with col2:
st.subheader("All Customer Profiles")
st.dataframe(profile_df)
st.markdown("---")
st.subheader("Customer Drill-down")
selected_cust = st.selectbox("Select Customer ID", options=profile_df['customer_id'].unique())
cust_profile = profile_df[profile_df['customer_id'] == selected_cust].iloc[0]
cust_tx = df[df['customer_id'] == selected_cust].sort_values('timestamp', ascending=False)
cust_flags = cust_tx[cust_tx['is_flagged'] == 1]
c1, c2, c3 = st.columns(3)
c1.metric("KYC Tier", cust_profile['kyc_tier'])
c2.metric("Total Volume", f"${cust_profile['total_volume']:,.2f}")
c3.metric("Avg Risk Score", f"{cust_profile['avg_risk_score']:.1f}")
st.write("### Transaction History")
st.dataframe(cust_tx[['transaction_id', 'timestamp', 'amount', 'transaction_type', 'risk_score', 'risk_level']])
st.write("### Repeated Suspicious Behavior")
if len(cust_flags) > 0:
st.dataframe(cust_flags[['transaction_id', 'amount', 'rule_flags']])
else:
st.write("None detected.")
elif page == "AI Report":
if 'summary_data' not in st.session_state:
st.warning("Please upload data first to generate an AI report.")
else:
st.title("πŸ€– AI Compliance Report Generation")
summ = st.session_state.summary_data
st.info(f"**Dataset loaded:** {summ['filename']} | **Total Transactions:** {summ['total_transactions']} | **Flagged:** {summ['flagged_count']}")
if st.button("πŸ€– Generate AI Compliance Report", type="primary"):
if not os.environ.get("BYTEZ_API_KEY"):
st.error("BYTEZ_API_KEY requires to be set to generate AI report.")
else:
with st.spinner("Connecting to AI analyst..."):
placeholder = st.empty()
report_text = stream_compliance_report(summ, placeholder)
if report_text and not report_text.startswith("Error"):
st.success("βœ… Report generated using meta-llama/Llama-3.1-8B-Instruct via Bytez")
st.session_state.ai_report = report_text
if st.session_state.upload_id:
db.save_ai_report(st.session_state.upload_id, report_text, "meta-llama/Llama-3.1-8B-Instruct")
if st.session_state.get('ai_report'):
st.markdown("---")
st.write("### Actions")
# PDF generation
flagged_df = st.session_state.df_scored[st.session_state.df_scored['is_flagged'] == 1].copy()
pdf_bytes = build_pdf(st.session_state.ai_report, summ, flagged_df)
date_str = datetime.now().strftime("%Y%m%d_%H%M")
st.download_button("πŸ“„ Download PDF Report", data=pdf_bytes, file_name=f"AML_Shield_Report_{date_str}.pdf", mime="application/pdf")
st.markdown("---")
st.markdown(st.session_state.ai_report)
elif page == "Global Analytics":
st.title("🌍 Global Analytics")
with st.spinner("Fetching global stats from Supabase..."):
try:
stats = db.get_global_stats()
uploads = db.get_all_uploads()
uploads_df = pd.DataFrame(uploads)
except Exception as e:
st.error(f"Could not connect to Supabase: {e}")
stats = None
uploads_df = pd.DataFrame()
if stats:
c1, c2, c3, c4 = st.columns(4)
c1.metric("All-time Transactions", stats['total_transactions_ever'])
c2.metric("Total Uploads", stats['total_uploads'])
c3.metric("All-time Flagged", stats['total_flagged_ever'])
c4.metric("Global Avg Risk", f"{stats['avg_risk_score_global']:.1f}")
st.markdown("---")
col1, col2 = st.columns(2)
with col1:
st.subheader("Global Trend: Flagged per Upload")
if not uploads_df.empty:
if 'uploaded_at' in uploads_df.columns:
uploads_df['date'] = pd.to_datetime(uploads_df['uploaded_at']).dt.date
trend_df = uploads_df.groupby('date')['flagged_count'].sum().reset_index()
fig = px.line(trend_df, x='date', y='flagged_count', markers=True)
fig.update_traces(line_color=viz.COLOR_MED)
st.plotly_chart(viz.apply_theme(fig), use_container_width=True)
else:
st.write("No historical data available.")
with col2:
st.subheader("Most Common Rule Triggered")
st.info(stats.get('most_common_rule_triggered', 'N/A'))
st.write("*(Approximation based on available metric patterns)*")
st.subheader("Uploads History")
if not uploads_df.empty:
st.dataframe(uploads_df[['filename', 'uploaded_at', 'total_transactions', 'flagged_count', 'high_risk_count', 'avg_risk_score']])
elif page == "About":
st.title("ℹ️ About AML Shield")
st.write("""
### AI-Powered Anti-Money Laundering Transaction Intelligence Platform
AML Shield is built to demonstrate production-grade AML compliance analytics skills for financial services roles.
#### How it works:
1. **Upload CSV** β†’ ETL validation & pre-processing.
2. **Rule-based AML flags** β†’ applied to all inputs.
3. **Isolation Forest ML** β†’ anomaly detection logic.
4. **Risk scoring (0-100)** β†’ deterministic algorithm based on flags+ML.
5. **KYC customer profiling** β†’ KMeans clustering into tiers.
6. **LangChain + Bytez** β†’ streams a formal regulatory compliance report utilizing meta-llama/Llama-3.1-8B-Instruct.
7. **ReportLab** β†’ renders professional downloadable PDF.
8. **Supabase** β†’ All data natively persisted.
""")
with st.expander("AML Rules Explained"):
st.write("""
- **Structuring**: Transactions intentionally sizing just beneath the $10,000 CTR reporting requirement ($9000 - $9999).
- **Rapid Fire Transactions**: Accounts showing an abnormally high transaction velocity.
- **Large Cash Out**: Immediate cash liquidations above $50,000.
- **Dormant Account Spike**: High amounts triggered by newly created or previously dormant accounts (< 30 days).
- **International High Value**: Large wire transfers sent outside of the domestic region.
- **Suspicious Round Amount**: High net round payments generally uncharacteristic of organic spending.
""")
st.write("""
**Regulatory Frameworks Considered:**
- BSA (Bank Secrecy Act)
- FinCEN SAR (Suspicious Activity Report) requirements
- FATF Recommendation 16 (Wire transfers)
""")
st.markdown("---")
st.write("**Tech Stack:** Streamlit | Pandas | Scikit-learn | Plotly | ReportLab | LangChain | Bytez | Supabase")
st.write("**Deployments:** Live on Hugging Face Spaces")