Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import tensorflow as tf | |
| import shap | |
| import pickle | |
| import os | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| from groq import Groq | |
| import keras | |
# --- 1. SETUP & CONFIG ---
st.set_page_config(page_title="Credit-Scout AI", layout="wide")

# "Bank" styling: light background, navy buttons, colored risk verdict text.
_BANK_CSS = """
<style>
.main { background-color: #f5f5f5; }
.stButton>button { background-color: #000044; color: white; width: 100%; }
.risk-high { color: #cc0000; font-weight: bold; font-size: 20px; }
.risk-low { color: #006600; font-weight: bold; font-size: 20px; }
</style>
"""
st.markdown(_BANK_CSS, unsafe_allow_html=True)
# Initialize the Groq client from the environment. The app degrades
# gracefully when the key is absent: `client` stays None and the LLM
# explanation step reports an error instead of crashing.
api_key = os.environ.get("GROQ_API_KEY")
client = Groq(api_key=api_key) if api_key else None
if client is None:
    st.error("β οΈ GROQ_API_KEY not found in Secrets! The LLM explanation will fail.")
# --- 2. LOAD ARTIFACTS ---
@st.cache_resource
def load_resources():
    """Load the model, scaler, column order, and SHAP explainer.

    Decorated with ``st.cache_resource`` so the heavy artifacts (Keras
    checkpoint, pickles, SHAP background sample) are loaded once per server
    process instead of on every Streamlit rerun — the original reloaded
    everything on each interaction.

    Returns:
        tuple: (model, scaler, columns, explainer)

    Side effects:
        Renders an error and stops the app if the model checkpoint fails
        to load. Missing pickle files raise to the caller.
    """

    class PatchedDense(tf.keras.layers.Dense):
        # Shim: drops the 'quantization_config' kwarg that some checkpoints
        # carry but this TF/Keras version's Dense does not accept.
        def __init__(self, *args, **kwargs):
            kwargs.pop('quantization_config', None)
            super().__init__(*args, **kwargs)

    try:
        model = tf.keras.models.load_model(
            'latest_checkpoint.h5',
            custom_objects={'Dense': PatchedDense},
            compile=False,
        )
    except Exception as e:
        st.error(f"Critical Model Load Error: {e}")
        st.stop()

    # NOTE: pickle.load is safe here only because these are locally
    # produced, trusted artifact files.
    with open('scaler.pkl', 'rb') as f:
        scaler = pickle.load(f)
    with open('columns.pkl', 'rb') as f:
        columns = pickle.load(f)
    with open('shap_metadata.pkl', 'rb') as f:
        shap_data = pickle.load(f)

    # GradientExplainer needs a background sample to estimate expectations.
    explainer = shap.GradientExplainer(model, shap_data['background_sample'])
    return model, scaler, columns, explainer
# Load all artifacts once at startup; abort the whole app if anything
# (model checkpoint, pickles, SHAP metadata) is missing or unreadable.
try:
    model, scaler, columns, explainer = load_resources()
except Exception as e:
    st.error(f"Error loading files: {e}")
    st.stop()
# --- 3. BUSINESS MAPPING ---
# Maps raw model feature names (PaySim-style columns — TODO confirm) to
# analyst-friendly labels used in the LLM prompt and report.
BUSINESS_MAP = {
    'step': 'Transaction Hour',
    'type_enc': 'Txn Type (Transfer/CashOut)',
    'amount': 'Transaction Amount',
    'oldbalanceOrg': 'Origin Acct Balance (Pre)',
    'newbalanceOrig': 'Origin Acct Balance (Post)',
    'oldbalanceDest': 'Recipient Acct Balance (Pre)',
    'newbalanceDest': 'Recipient Acct Balance (Post)',
    'errorBalanceOrig': 'Origin Math Discrepancy',
    'errorBalanceDest': 'Recipient Math Discrepancy'
}
# --- 4. EXPLANATION FUNCTION (GROQ API) ---
def generate_explanation_cloud(shap_values, original_samples, feature_names, scaler):
    """Draft a compliance-style "Notice of Adverse Action" via the Groq LLM.

    Args:
        shap_values: SHAP attributions for one sample; GradientExplainer
            may return a list of arrays, in which case the first is used.
        original_samples: the *scaled* feature vector for the sample; it is
            inverse-transformed back to real units for display.
        feature_names: model column names, in feature-vector order.
        scaler: fitted scaler exposing ``inverse_transform``.

    Returns:
        str: the LLM-generated notice, or an error-message string.
    """
    # FIX: guard first — the original only checked `client` after doing all
    # the feature math, wasting work when the API key is missing.
    if not client:
        return "Error: Groq API Key missing."

    # Map scaled features back to real-world units for the report.
    raw_scaled = original_samples.flatten()
    real_values = scaler.inverse_transform(raw_scaled.reshape(1, -1)).flatten()

    vals = shap_values[0] if isinstance(shap_values, list) else shap_values
    vals = vals.flatten()

    # Pair each feature with its business label, real value, and attribution.
    feature_data = [
        (BUSINESS_MAP.get(col_name, col_name), real_values[i], vals[i])
        for i, col_name in enumerate(feature_names)
    ]

    # Rank by absolute impact so the report leads with the strongest drivers.
    feature_data.sort(key=lambda x: abs(x[2]), reverse=True)
    # Epsilon avoids division by zero when all attributions are ~0.
    total_shap_mass = sum(abs(v) for _, _, v in feature_data) + 1e-9

    data_lines = []
    shap_lines = []
    for name, real_val, shap_val in feature_data[:3]:
        # Dollar-format monetary features; plain 2-decimal otherwise.
        if "Amount" in name or "Balance" in name or "Discrepancy" in name:
            val_str = f"${real_val:,.2f}"
        else:
            val_str = f"{real_val:.2f}"
        contrib_pct = (abs(shap_val) / total_shap_mass) * 100
        # Positive SHAP pushes toward the fraud class; negative mitigates.
        logic_hint = "ANOMALY (Increased Risk)" if shap_val > 0 else "CONSISTENT BEHAVIOR (Mitigated Risk)"
        data_lines.append(f"- {name}: {val_str}")
        shap_lines.append(f"- {name}: {logic_hint} | Contribution: {contrib_pct:.1f}%")

    # chr(10) == "\n": backslashes are not allowed inside f-string
    # expressions before Python 3.12.
    prompt = f"""
You are a Senior Model Risk Examiner. Write a strict, short compliance explanation.
CONTEXT:
{chr(10).join(data_lines)}
RISK FACTORS:
{chr(10).join(shap_lines)}
Write a "Notice of Adverse Action" explanation.
Use the provided logic hints. Interpret negative SHAP as consistency.
Keep it under 150 words. Professional tone only. Add a standard disclaimer at the end.
"""
    try:
        completion = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,  # low temperature: deterministic, compliance-safe tone
            max_tokens=300,
        )
        return completion.choices[0].message.content
    except Exception as e:
        return f"LLM Error: {str(e)}"
# --- 5. SIDEBAR UI ---
# Transaction entry widgets. Declaration order is layout order in Streamlit.
st.sidebar.image("https://cdn-icons-png.flaticon.com/512/2666/2666505.png", width=100)
st.sidebar.title("π³ Transaction Details")
amount = st.sidebar.number_input("Amount ($)", value=350000.0, min_value=0.0, format="%.2f")
old_bal = st.sidebar.number_input("Origin Old Balance ($)", value=80000.0, min_value=0.0, format="%.2f")
new_bal = st.sidebar.number_input("Origin New Balance ($)", value=2000000.0, min_value=0.0, format="%.2f")
txn_type = st.sidebar.selectbox("Type", ["TRANSFER", "CASH_OUT"])
# Accounting identity residual: old + amount - new should be 0 for a clean
# transfer out of the origin account; a nonzero residual is a fraud signal.
error_bal_orig = old_bal + amount - new_bal
st.sidebar.info(f"Math Discrepancy: ${error_bal_orig:,.2f}")
st.sidebar.markdown("---")
st.sidebar.markdown("**π§ͺ Test Scenarios:**")
# Two one-click demo scenarios, rendered side by side. Clicking stores the
# preset values in session_state and triggers a rerun.
col_a, col_b = st.sidebar.columns(2)
_SCENARIOS = (
    (col_a, "π° Suspicious", {
        'amount': 5000000.0,
        'old_bal': 100000.0,
        'new_bal': 100000.0,
        'txn_type': 'CASH_OUT',
    }),
    (col_b, "β Normal", {
        'amount': 10000.0,
        'old_bal': 50000.0,
        'new_bal': 40000.0,
        'txn_type': 'TRANSFER',
    }),
)
for _col, _label, _preset in _SCENARIOS:
    with _col:
        if st.button(_label, use_container_width=True):
            st.session_state.update(_preset)
            st.rerun()
# Apply — and then clear — any one-shot scenario preset.
# BUG FIX 1: the original left the preset values in session_state forever,
# so after a single scenario click the sidebar widgets could never change
# the analyzed values again. `pop` makes the override one-shot.
_preset_applied = False
if 'amount' in st.session_state:
    amount = st.session_state.pop('amount')
    _preset_applied = True
if 'old_bal' in st.session_state:
    old_bal = st.session_state.pop('old_bal')
    _preset_applied = True
if 'new_bal' in st.session_state:
    new_bal = st.session_state.pop('new_bal')
    _preset_applied = True
if 'txn_type' in st.session_state:
    txn_type = st.session_state.pop('txn_type')

# BUG FIX 2: the discrepancy was computed in the sidebar *before* this
# override, so preset runs used a stale value. Recompute it here.
if _preset_applied:
    error_bal_orig = old_bal + amount - new_bal
# --- 6. MAIN APP LOGIC ---
# NOTE(review): the emoji in these strings appear mojibake'd by a
# transcoding step upstream — confirm intended characters before release.
st.title("π¦ Credit-Scout AI Risk Engine")
st.markdown("Real-time Fraud Detection with Llama 3.3 Explainability")
# Manual fallback scaling statistics, used when the pickled scaler errors
# out or produces implausible output. Hoisted to module level (the original
# duplicated these arrays in both the try and except branches).
# NOTE(review): element order assumed to match `columns` — confirm against
# the scaler's training pipeline.
_FALLBACK_MEANS = np.array([243.39, 0.5, 180000, 834000, 855000, 1100000, 1225000, 0, 0])
_FALLBACK_STDS = np.array([142.3, 0.5, 604000, 2900000, 2940000, 3400000, 3670000, 380000, 420000])


def _manual_scale(raw):
    """Standard-score `raw` with the hard-coded fallback statistics."""
    return (raw - _FALLBACK_MEANS) / (_FALLBACK_STDS + 1e-8)


if st.sidebar.button("Analyze Transaction"):
    with st.spinner("Analyzing Risk Patterns..."):
        # 1. Preprocess: encode the transaction type and assemble features.
        type_val = 1 if txn_type == 'CASH_OUT' else 0
        feature_dict = {
            'step': 150,              # fixed placeholder hour; not user-adjustable
            'type_enc': type_val,
            'amount': amount,
            'oldbalanceOrg': old_bal,
            'newbalanceOrig': new_bal,
            'oldbalanceDest': 0.0,    # recipient-side fields are not collected in the UI
            'newbalanceDest': 0.0,
            'errorBalanceOrig': error_bal_orig,
            'errorBalanceDest': 0.0
        }
        # Build the vector in the exact column order the model was trained on.
        raw_features = np.array([feature_dict[col] for col in columns]).reshape(1, -1)

        # Scale with the pickled scaler; fall back to manual statistics if it
        # raises or yields implausibly large values (|z| > 100 heuristic).
        try:
            scaled_features = scaler.transform(raw_features)
            if np.abs(scaled_features).max() > 100:
                scaled_features = _manual_scale(raw_features)
        except Exception:
            scaled_features = _manual_scale(raw_features)

        # Reshape for the LSTM: (batch=1, timesteps=1, n_features).
        # FIX: uses len(columns) instead of a hard-coded 9 so the reshape
        # stays correct if the artifact's feature set ever changes.
        lstm_input = scaled_features.reshape(1, 1, len(columns))

        # 2. Predict fraud probability.
        prediction_raw = model.predict(lstm_input, verbose=0)
        risk_prob = float(prediction_raw[0][0])

        # 3. SHAP attributions for this single transaction.
        shap_vals = explainer.shap_values(lstm_input)

        # 4. Display results: score on the left, attribution plot on the right.
        col1, col2 = st.columns([1, 2])
        with col1:
            st.subheader("Risk Score")
            st.metric(label="Fraud Probability", value=f"{risk_prob:.2%}")
            threshold = 0.5
            if risk_prob > threshold:
                st.markdown('<p class="risk-high">β FLAGGED</p>', unsafe_allow_html=True)
            else:
                st.markdown('<p class="risk-low">β APPROVED</p>', unsafe_allow_html=True)
        with col2:
            st.subheader("Model Logic (SHAP)")
            # GradientExplainer may return a list (one array per model output).
            if isinstance(shap_vals, list):
                shap_vals_plot = shap_vals[0]
            else:
                shap_vals_plot = shap_vals
            # Collapse the (1, 1, n) LSTM shape down to (1, n) for plotting.
            if len(shap_vals_plot.shape) > 2:
                shap_vals_plot = shap_vals_plot.reshape(1, -1)
            fig, ax = plt.subplots(figsize=(10, 6))
            shap.summary_plot(
                shap_vals_plot,
                raw_features,
                feature_names=columns,
                plot_type="bar",
                show=False
            )
            st.pyplot(fig, clear_figure=True)
            plt.close('all')  # guard against matplotlib figure leaks across reruns

        # 5. LLM compliance report.
        st.markdown("---")
        st.subheader("π Audit Report (Llama 3.3)")
        with st.spinner("Drafting Compliance Notice..."):
            # NOTE(review): scaled_features are inverse-transformed inside
            # the explanation function with the pickled scaler — values will
            # be inconsistent when the manual fallback path was taken above.
            report = generate_explanation_cloud(shap_vals, scaled_features, columns, scaler)
            st.success("Report Generated")
            st.write(report)
else:
    st.info("π Adjust transaction details in the sidebar and click 'Analyze Transaction' to begin.")