Spaces:
Running
Running
import os
from dotenv import load_dotenv

# Local development convenience: a .env file in the working directory is read
# only when it exists. Otherwise variables such as BYTEZ_API_KEY are expected
# to already be present in the process environment.
_LOCAL_ENV_FILE = ".env"
if os.path.exists(_LOCAL_ENV_FILE):
    load_dotenv()
| import streamlit as st | |
| import pandas as pd | |
| import plotly.express as px | |
| import time | |
| from datetime import datetime | |
| # Initialize database connections properly | |
| # Assuming database uses Supabase, if Supabase client init fails due to missing secrets, app will handle gracefully | |
| import modules.database as db | |
| from modules.etl import load_and_validate, engineer_features | |
| from modules.detection import apply_detection | |
| from modules.risk_profiling import build_customer_profiles, assign_kyc_tier | |
| import modules.visualizations as viz | |
| from modules.ai_agent import stream_compliance_report | |
| from modules.pdf_report import build_pdf | |
st.set_page_config(
    page_title="AML Shield",
    page_icon="π‘οΈ",
    layout="wide",
    initial_sidebar_state="expanded",
)

# --- CSS Overrides ---
# Dark card background behind st.metric widgets and light heading text.
_CUSTOM_CSS = """
<style>
.stMetric { background-color: #1e1e2d; padding: 15px; border-radius: 5px; }
h1, h2, h3 { color: #f8f9fa; }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# --- Sidebar ---
# `page` holds the label of the selected tab; the page routing switches on it.
st.sidebar.title("π‘οΈ AML Shield Navigation")
tabs = [
    "Upload & Analyze",
    "Dashboard",
    "Customer Profiles",
    "AI Report",
    "Global Analytics",
    "About",
]
page = st.sidebar.radio("Go to", tabs)
| # --- Helper logic for analysis --- | |
def run_pipeline(file_obj, filename="Uploaded Data"):
    """Run the AML analysis pipeline on one dataset and render the results.

    Stages: load/validate -> feature engineering -> anomaly detection ->
    customer profiling and KYC tiering. The run is summarised, persisted via
    ``modules.database`` and cached in ``st.session_state`` for other pages.

    Args:
        file_obj: Anything ``load_and_validate`` accepts. This app passes
            either a Streamlit uploaded file or a CSV path string.
        filename: Label stored with the upload record and used in reports.

    Returns:
        True when the pipeline completed; False when validation failed (the
        validation message is shown with ``st.error``).
    """
    progress_bar = st.progress(0)
    status_text = st.empty()

    # Step 1
    status_text.text("[βββββ] Loading & validating data...")
    df, msg = load_and_validate(file_obj)
    if df is None:
        st.error(msg)
        progress_bar.empty()
        status_text.empty()
        return False
    progress_bar.progress(25)
    time.sleep(0.5)

    # Step 2
    status_text.text("[βββββ] Engineering features...")
    df = engineer_features(df)
    progress_bar.progress(50)
    time.sleep(0.5)

    # Step 3
    status_text.text("[βββββ] Running anomaly detection...")
    df = apply_detection(df)
    progress_bar.progress(75)
    time.sleep(0.5)

    # Step 4
    status_text.text("[βββββ] Building customer profiles...")
    profile_df = build_customer_profiles(df)
    profile_df = assign_kyc_tier(profile_df)
    progress_bar.progress(100)
    time.sleep(0.5)
    status_text.empty()
    progress_bar.empty()

    summary_data = _build_summary(df, profile_df, filename)
    total_tx = summary_data["total_transactions"]
    flagged = summary_data["flagged_count"]
    high_risk = summary_data["high_risk_count"]
    avg_score = summary_data["avg_risk_score"]

    # Save the upload record. Values are built-in int/float rather than NumPy
    # scalars (np.int64 is not JSON-serialisable).
    upload_id = db.save_upload(
        filename=filename, total=total_tx, flagged=flagged,
        high_risk=high_risk, medium_risk=summary_data["medium_risk_count"],
        avg_score=avg_score, date_range=summary_data["date_range"],
    )
    # Row-level data is only saved when an upload id came back (falsy = failed).
    if upload_id:
        db.save_transactions(df, upload_id)
        db.save_customer_profiles(profile_df, upload_id)

    # Session State (consumed by the Dashboard, Customer Profiles and AI Report pages)
    st.session_state.df_raw = df.copy()
    st.session_state.df_scored = df.copy()
    st.session_state.profile_df = profile_df.copy()
    st.session_state.upload_id = upload_id
    st.session_state.summary_data = summary_data
    st.session_state.ai_report = None  # invalidate any report for a previous dataset

    st.success(f"β {total_tx} transactions analyzed | β οΈ {flagged} flagged | π΄ {high_risk} high risk | π Avg risk score: {avg_score:.1f}")

    # --- Model Evaluation (only when ground-truth labels are present) ---
    if "is_fraud" in df.columns:
        _render_model_evaluation(df)

    st.markdown("---")
    st.subheader("Top 5 Highest Risk Transactions Preview")
    preview = df.sort_values('risk_score', ascending=False).head(5)
    cols = ['transaction_id', 'customer_id', 'amount', 'transaction_type', 'risk_score', 'risk_level', 'rule_flags']
    st.dataframe(preview[cols])
    return True


def _build_summary(df, profile_df, filename):
    """Aggregate scored transactions and profiles into a plain-Python stats dict.

    Every count is a built-in ``int`` and the average a built-in ``float``, so
    the dict holds no NumPy scalar types.
    """
    total_tx = len(df)
    flagged = int(df['is_flagged'].sum())
    high_risk = int((df['risk_level'] == 'High').sum())
    med_risk = int((df['risk_level'] == 'Medium').sum())
    avg_score = float(df['risk_score'].mean())
    date_range = f"{df['timestamp'].dt.date.min()} to {df['timestamp'].dt.date.max()}"

    # Structuring & international stats for the report
    struct_attempts = int(profile_df['structuring_attempts'].sum())
    intl_high = int(((df['is_international'] == 1) & (df['amount'] > 25000)).sum())
    kyc_counts = {tier: int(n) for tier, n in profile_df['kyc_tier'].value_counts().items()}

    # Top flagged rules: rule_flags cells that are lists of rule names are
    # flattened; any non-list cell is ignored.
    rules_flat = [rule for flags in df['rule_flags'] if isinstance(flags, list) for rule in flags]
    top_rules = (
        {rule: int(n) for rule, n in pd.Series(rules_flat).value_counts().head(3).items()}
        if rules_flat else {}
    )
    top_customers = profile_df.sort_values('avg_risk_score', ascending=False)['customer_id'].head(3).tolist()

    return {
        "filename": filename,
        "total_transactions": total_tx,
        "flagged_count": flagged,
        "high_risk_count": high_risk,
        "medium_risk_count": med_risk,
        "avg_risk_score": avg_score,
        "date_range": date_range,
        "structuring_attempts": struct_attempts,
        "international_high_value_count": intl_high,
        "kyc_tier_breakdown": kyc_counts,
        "top_rules_triggered": top_rules,
        "top_flagged_customers": top_customers,
    }


def _render_model_evaluation(df):
    """Score the pipeline's ``is_flagged`` predictions against ``is_fraud`` labels."""
    from sklearn.metrics import accuracy_score, classification_report, precision_score, recall_score

    st.markdown("---")
    st.subheader("π― Model Performance Evaluation")

    # Normalize fraud labels: booleans, "true"/"false" (any case, surrounding
    # whitespace ignored) and 0/1 as int, float or string. Anything
    # unrecognised counts as non-fraud.
    y_true = (
        df["is_fraud"].astype(str).str.strip().str.upper()
        .map({"TRUE": 1, "FALSE": 0, "1": 1, "0": 0, "1.0": 1, "0.0": 0})
        .fillna(0)
        .astype(int)
    )
    # The pipeline's binary flag is the prediction being evaluated.
    y_pred = df["is_flagged"].astype(int)

    c1, c2, c3 = st.columns(3)
    c1.metric("Accuracy", f"{accuracy_score(y_true, y_pred):.2%}")
    # zero_division=0 keeps the fraud-class metrics defined when no row is
    # labelled or flagged as fraud; a classification_report dict would have no
    # '1' entry in that case.
    c2.metric("Precision (Fraud)", f"{precision_score(y_true, y_pred, zero_division=0):.2%}")
    c3.metric("Recall (Fraud)", f"{recall_score(y_true, y_pred, zero_division=0):.2%}")
    with st.expander("Detailed Classification Report"):
        st.code(classification_report(y_true, y_pred, zero_division=0))
| # --- PAGE ROUTING --- | |
if page == "Upload & Analyze":
    st.title("π‘οΈ AML Shield")
    st.write("AI-Powered Anti-Money Laundering Transaction Intelligence Platform")
    upload_col, sample_col = st.columns(2)

    # Left: analyse a user-supplied CSV. The button only appears once a file
    # has been chosen (short-circuit keeps st.button from rendering earlier).
    with upload_col:
        uploaded_file = st.file_uploader("Upload CSV Transactions", type=['csv'])
        if uploaded_file is not None and st.button("Analyze Uploaded File"):
            run_pipeline(uploaded_file, filename=uploaded_file.name)

    # Right: run the same pipeline against the bundled synthetic CSV.
    with sample_col:
        st.write("Or test with pre-generated synthetic data:")
        if st.button("Use Sample Dataset"):
            sample_path = "sample_data/sample_transactions.csv"
            if not os.path.exists(sample_path):
                st.error("Sample dataset not found. Please ensure it was generated.")
            else:
                run_pipeline(sample_path, filename="sample_transactions.csv")
| elif page == "Dashboard": | |
| if 'df_scored' not in st.session_state: | |
| st.warning("Please upload or load sample data first in the 'Upload & Analyze' tab.") | |
| else: | |
| df = st.session_state.df_scored.copy() | |
| summ = st.session_state.summary_data | |
| # Dashboard Filters in Sidebar | |
| st.sidebar.markdown("---") | |
| st.sidebar.subheader("Dashboard Filters") | |
| risk_filter = st.sidebar.multiselect("Risk Level", options=['High', 'Medium', 'Low'], default=['High', 'Medium', 'Low']) | |
| type_filter = st.sidebar.multiselect("Transaction Type", options=df['transaction_type'].unique(), default=df['transaction_type'].unique()) | |
| min_date = df['timestamp'].min().date() | |
| max_date = df['timestamp'].max().date() | |
| date_filter = st.sidebar.slider("Date Range", min_value=min_date, max_value=max_date, value=(min_date, max_date)) | |
| # Apply filters | |
| df_filtered = df[ | |
| (df['risk_level'].isin(risk_filter)) & | |
| (df['transaction_type'].isin(type_filter)) & | |
| (df['timestamp'].dt.date >= date_filter[0]) & | |
| (df['timestamp'].dt.date <= date_filter[1]) | |
| ] | |
| # KPIs | |
| c1, c2, c3, c4 = st.columns(4) | |
| c1.metric("Total Transactions", summ['total_transactions']) | |
| flagged_pct = (summ['flagged_count'] / summ['total_transactions']) * 100 if summ['total_transactions'] > 0 else 0 | |
| c2.metric("Flagged", summ['flagged_count'], delta=f"{flagged_pct:.1f}%") | |
| c3.metric("High Risk", summ['high_risk_count']) | |
| c4.metric("Avg Risk Score", f"{summ['avg_risk_score']:.1f}") | |
| # Charts Row 1 | |
| r1c1, r1c2 = st.columns(2) | |
| with r1c1: | |
| st.subheader("Risk Distribution") | |
| st.plotly_chart(viz.risk_distribution_chart(df_filtered), use_container_width=True) | |
| with r1c2: | |
| st.subheader("Daily Flagged Transactions") | |
| st.plotly_chart(viz.flagged_transactions_timeline(df_filtered), use_container_width=True) | |
| # Charts Row 2 | |
| st.subheader("Amount vs Risk Score Scatter") | |
| st.plotly_chart(viz.amount_vs_risk_scatter(df_filtered), use_container_width=True) | |
| # Charts Row 3 | |
| r3c1, r3c2 = st.columns(2) | |
| with r3c1: | |
| st.subheader("Transaction Types (Flagged vs Clean)") | |
| st.plotly_chart(viz.transaction_type_breakdown(df_filtered), use_container_width=True) | |
| with r3c2: | |
| st.subheader("Rule Trigger Frequency") | |
| st.plotly_chart(viz.rule_trigger_frequency(df_filtered), use_container_width=True) | |
| # Charts Row 4 | |
| st.subheader("Top Flagged Customers") | |
| st.plotly_chart(viz.top_flagged_customers_chart(df_filtered), use_container_width=True) | |
| # Table | |
| st.subheader("Flagged Transactions Explorer") | |
| flagged_df = df_filtered[df_filtered['is_flagged'] == 1].copy() | |
| # Convert rule_flags list to string for display/CSV | |
| flagged_df['rule_flags_str'] = flagged_df['rule_flags'].apply(lambda x: ", ".join(x) if isinstance(x, list) else str(x)) | |
| disp_cols = ['transaction_id', 'customer_id', 'amount', 'transaction_type', 'risk_score', 'risk_level', 'rule_flags_str'] | |
| st.dataframe(flagged_df[disp_cols]) | |
| csv_data = flagged_df[disp_cols].to_csv(index=False).encode('utf-8') | |
| st.download_button("Download Flagged Transactions CSV", data=csv_data, file_name="flagged_transactions.csv", mime="text/csv") | |
| elif page == "Customer Profiles": | |
| if 'profile_df' not in st.session_state: | |
| st.warning("Please upload data first to analyze customer profiles.") | |
| else: | |
| profile_df = st.session_state.profile_df.copy() | |
| df = st.session_state.df_scored | |
| st.title("Customer KYC Profiles") | |
| col1, col2 = st.columns([1, 2]) | |
| with col1: | |
| st.subheader("KYC Tier Distribution") | |
| st.plotly_chart(viz.kyc_tier_distribution(profile_df), use_container_width=True) | |
| with col2: | |
| st.subheader("All Customer Profiles") | |
| st.dataframe(profile_df) | |
| st.markdown("---") | |
| st.subheader("Customer Drill-down") | |
| selected_cust = st.selectbox("Select Customer ID", options=profile_df['customer_id'].unique()) | |
| cust_profile = profile_df[profile_df['customer_id'] == selected_cust].iloc[0] | |
| cust_tx = df[df['customer_id'] == selected_cust].sort_values('timestamp', ascending=False) | |
| cust_flags = cust_tx[cust_tx['is_flagged'] == 1] | |
| c1, c2, c3 = st.columns(3) | |
| c1.metric("KYC Tier", cust_profile['kyc_tier']) | |
| c2.metric("Total Volume", f"${cust_profile['total_volume']:,.2f}") | |
| c3.metric("Avg Risk Score", f"{cust_profile['avg_risk_score']:.1f}") | |
| st.write("### Transaction History") | |
| st.dataframe(cust_tx[['transaction_id', 'timestamp', 'amount', 'transaction_type', 'risk_score', 'risk_level']]) | |
| st.write("### Repeated Suspicious Behavior") | |
| if len(cust_flags) > 0: | |
| st.dataframe(cust_flags[['transaction_id', 'amount', 'rule_flags']]) | |
| else: | |
| st.write("None detected.") | |
| elif page == "AI Report": | |
| if 'summary_data' not in st.session_state: | |
| st.warning("Please upload data first to generate an AI report.") | |
| else: | |
| st.title("π€ AI Compliance Report Generation") | |
| summ = st.session_state.summary_data | |
| st.info(f"**Dataset loaded:** {summ['filename']} | **Total Transactions:** {summ['total_transactions']} | **Flagged:** {summ['flagged_count']}") | |
| if st.button("π€ Generate AI Compliance Report", type="primary"): | |
| if not os.environ.get("BYTEZ_API_KEY"): | |
| st.error("BYTEZ_API_KEY requires to be set to generate AI report.") | |
| else: | |
| with st.spinner("Connecting to AI analyst..."): | |
| placeholder = st.empty() | |
| report_text = stream_compliance_report(summ, placeholder) | |
| if report_text and not report_text.startswith("Error"): | |
| st.success("β Report generated using meta-llama/Llama-3.1-8B-Instruct via Bytez") | |
| st.session_state.ai_report = report_text | |
| if st.session_state.upload_id: | |
| db.save_ai_report(st.session_state.upload_id, report_text, "meta-llama/Llama-3.1-8B-Instruct") | |
| if st.session_state.get('ai_report'): | |
| st.markdown("---") | |
| st.write("### Actions") | |
| # PDF generation | |
| flagged_df = st.session_state.df_scored[st.session_state.df_scored['is_flagged'] == 1].copy() | |
| pdf_bytes = build_pdf(st.session_state.ai_report, summ, flagged_df) | |
| date_str = datetime.now().strftime("%Y%m%d_%H%M") | |
| st.download_button("π Download PDF Report", data=pdf_bytes, file_name=f"AML_Shield_Report_{date_str}.pdf", mime="application/pdf") | |
| st.markdown("---") | |
| st.markdown(st.session_state.ai_report) | |
| elif page == "Global Analytics": | |
| st.title("π Global Analytics") | |
| with st.spinner("Fetching global stats from Supabase..."): | |
| try: | |
| stats = db.get_global_stats() | |
| uploads = db.get_all_uploads() | |
| uploads_df = pd.DataFrame(uploads) | |
| except Exception as e: | |
| st.error(f"Could not connect to Supabase: {e}") | |
| stats = None | |
| uploads_df = pd.DataFrame() | |
| if stats: | |
| c1, c2, c3, c4 = st.columns(4) | |
| c1.metric("All-time Transactions", stats['total_transactions_ever']) | |
| c2.metric("Total Uploads", stats['total_uploads']) | |
| c3.metric("All-time Flagged", stats['total_flagged_ever']) | |
| c4.metric("Global Avg Risk", f"{stats['avg_risk_score_global']:.1f}") | |
| st.markdown("---") | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| st.subheader("Global Trend: Flagged per Upload") | |
| if not uploads_df.empty: | |
| if 'uploaded_at' in uploads_df.columns: | |
| uploads_df['date'] = pd.to_datetime(uploads_df['uploaded_at']).dt.date | |
| trend_df = uploads_df.groupby('date')['flagged_count'].sum().reset_index() | |
| fig = px.line(trend_df, x='date', y='flagged_count', markers=True) | |
| fig.update_traces(line_color=viz.COLOR_MED) | |
| st.plotly_chart(viz.apply_theme(fig), use_container_width=True) | |
| else: | |
| st.write("No historical data available.") | |
| with col2: | |
| st.subheader("Most Common Rule Triggered") | |
| st.info(stats.get('most_common_rule_triggered', 'N/A')) | |
| st.write("*(Approximation based on available metric patterns)*") | |
| st.subheader("Uploads History") | |
| if not uploads_df.empty: | |
| st.dataframe(uploads_df[['filename', 'uploaded_at', 'total_transactions', 'flagged_count', 'high_risk_count', 'avg_risk_score']]) | |
| elif page == "About": | |
| st.title("βΉοΈ About AML Shield") | |
| st.write(""" | |
| ### AI-Powered Anti-Money Laundering Transaction Intelligence Platform | |
| AML Shield is built to demonstrate production-grade AML compliance analytics skills for financial services roles. | |
| #### How it works: | |
| 1. **Upload CSV** β ETL validation & pre-processing. | |
| 2. **Rule-based AML flags** β applied to all inputs. | |
| 3. **Isolation Forest ML** β anomaly detection logic. | |
| 4. **Risk scoring (0-100)** β deterministic algorithm based on flags+ML. | |
| 5. **KYC customer profiling** β KMeans clustering into tiers. | |
| 6. **LangChain + Bytez** β streams a formal regulatory compliance report utilizing meta-llama/Llama-3.1-8B-Instruct. | |
| 7. **ReportLab** β renders professional downloadable PDF. | |
| 8. **Supabase** β All data natively persisted. | |
| """) | |
| with st.expander("AML Rules Explained"): | |
| st.write(""" | |
| - **Structuring**: Transactions intentionally sizing just beneath the $10,000 CTR reporting requirement ($9000 - $9999). | |
| - **Rapid Fire Transactions**: Accounts showing an abnormally high transaction velocity. | |
| - **Large Cash Out**: Immediate cash liquidations above $50,000. | |
| - **Dormant Account Spike**: High amounts triggered by newly created or previously dormant accounts (< 30 days). | |
| - **International High Value**: Large wire transfers sent outside of the domestic region. | |
| - **Suspicious Round Amount**: High net round payments generally uncharacteristic of organic spending. | |
| """) | |
| st.write(""" | |
| **Regulatory Frameworks Considered:** | |
| - BSA (Bank Secrecy Act) | |
| - FinCEN SAR (Suspicious Activity Report) requirements | |
| - FATF Recommendation 16 (Wire transfers) | |
| """) | |
| st.markdown("---") | |
| st.write("**Tech Stack:** Streamlit | Pandas | Scikit-learn | Plotly | ReportLab | LangChain | Bytez | Supabase") | |
| st.write("**Deployments:** Live on Hugging Face Spaces") | |