"""Credit-Scout AI: Streamlit fraud-risk demo with SHAP attributions and an LLM audit report."""

import os
import pickle

import streamlit as st
import numpy as np
import tensorflow as tf
import shap
import pandas as pd
import matplotlib.pyplot as plt
from groq import Groq
import keras

# --- 1. SETUP & CONFIG ---
st.set_page_config(page_title="Credit-Scout AI", layout="wide")

# CSS for "Bank" styling
# NOTE(review): the stylesheet body is empty in this copy of the file; restore it
# if custom styling is expected.
st.markdown(""" """, unsafe_allow_html=True)

# Initialize Groq Client (None when no key is configured; callers must check).
api_key = os.environ.get("GROQ_API_KEY")
if not api_key:
    st.error("⚠️ GROQ_API_KEY not found in Secrets! The LLM explanation will fail.")
    client = None
else:
    client = Groq(api_key=api_key)


# --- 2. LOAD ARTIFACTS ---
def _load_pickle(path):
    """Return the object unpickled from *path*.

    SECURITY: pickle.load can execute arbitrary code; only use it on artifact
    files shipped with the app, never on user-supplied data.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


@st.cache_resource
def load_resources():
    """Load the model, preprocessing artifacts and SHAP explainer (cached per process).

    Returns:
        tuple: ``(model, scaler, columns, explainer)`` where ``scaler`` and
        ``columns`` are whatever was pickled in ``scaler.pkl`` / ``columns.pkl``.

    If the Keras model cannot be loaded, an error is shown and the script is
    halted with ``st.stop()``. Failures loading the pickles or building the
    explainer propagate to the caller.
    """

    class PatchedDense(tf.keras.layers.Dense):
        """Dense layer that ignores a ``quantization_config`` constructor kwarg.

        NOTE(review): presumably the checkpoint was saved by a Keras version that
        serializes this kwarg while the loading version rejects it — confirm.
        """

        def __init__(self, *args, **kwargs):
            kwargs.pop("quantization_config", None)
            super().__init__(*args, **kwargs)

    try:
        model = tf.keras.models.load_model(
            "latest_checkpoint.h5",
            custom_objects={"Dense": PatchedDense},
            compile=False,
        )
    except Exception as e:
        st.error(f"Critical Model Load Error: {e}")
        st.stop()

    scaler = _load_pickle("scaler.pkl")
    columns = _load_pickle("columns.pkl")
    shap_data = _load_pickle("shap_metadata.pkl")

    explainer = shap.GradientExplainer(model, shap_data["background_sample"])
    return model, scaler, columns, explainer


try:
    model, scaler, columns, explainer = load_resources()
except Exception as e:
    st.error(f"Error loading files: {e}")
    st.stop()

# --- 3. BUSINESS MAPPING ---
# Model feature name -> human-readable label used in the LLM report.
BUSINESS_MAP = {
    "step": "Transaction Hour",
    "type_enc": "Txn Type (Transfer/CashOut)",
    "amount": "Transaction Amount",
    "oldbalanceOrg": "Origin Acct Balance (Pre)",
    "newbalanceOrig": "Origin Acct Balance (Post)",
    "oldbalanceDest": "Recipient Acct Balance (Pre)",
    "newbalanceDest": "Recipient Acct Balance (Post)",
    "errorBalanceOrig": "Origin Math Discrepancy",
    "errorBalanceDest": "Recipient Math Discrepancy",
}
# --- 4. EXPLANATION FUNCTION (GROQ API) ---
def generate_explanation_cloud(shap_values, original_samples, feature_names, scaler):
    """Ask the Groq-hosted LLM for a short compliance explanation of one prediction.

    The three features with the largest absolute SHAP value are summarised
    (real value, direction of effect, share of total absolute attribution) and
    sent to ``llama-3.3-70b-versatile``.

    Args:
        shap_values: SHAP output for one sample, either an array or a list whose
            first element is used. It is flattened, so its element count is
            assumed to line up with ``feature_names`` (TODO confirm for
            multi-output models).
        original_samples: The feature row that was fed to the model (scaled).
        feature_names: Model column names, in feature order.
        scaler: Object whose ``inverse_transform`` maps ``original_samples``
            back to real-world units.

    Returns:
        str: The LLM's text, ``"Error: Groq API Key missing."`` when no client
        is configured, or ``"LLM Error: ..."`` when the API call fails (API
        exceptions are converted to text, not raised).
    """
    # Fail fast: none of the work below is useful without an API client.
    if not client:
        return "Error: Groq API Key missing."

    # Undo scaling so the report quotes real-world values.
    flat_samples = original_samples.flatten()
    real_values = scaler.inverse_transform(flat_samples.reshape(1, -1)).flatten()

    vals = shap_values[0] if isinstance(shap_values, list) else shap_values
    vals = np.asarray(vals).flatten()

    feature_data = [
        (BUSINESS_MAP.get(col_name, col_name), real_values[i], vals[i])
        for i, col_name in enumerate(feature_names)
    ]
    # Rank by magnitude of attribution, largest first.
    feature_data.sort(key=lambda item: abs(item[2]), reverse=True)
    # Epsilon avoids division by zero when every attribution is exactly 0.
    total_shap_mass = sum(abs(v) for _, _, v in feature_data) + 1e-9

    data_lines = []
    shap_lines = []
    for name, real_val, shap_val in feature_data[:3]:
        # Monetary features (identified by their business label) get currency formatting.
        if "Amount" in name or "Balance" in name or "Discrepancy" in name:
            val_str = f"${real_val:,.2f}"
        else:
            val_str = f"{real_val:.2f}"
        contrib_pct = abs(shap_val) / total_shap_mass * 100
        # Positive SHAP pushes toward the fraud class; zero/negative is reported as mitigating.
        logic_hint = "ANOMALY (Increased Risk)" if shap_val > 0 else "CONSISTENT BEHAVIOR (Mitigated Risk)"
        data_lines.append(f"- {name}: {val_str}")
        shap_lines.append(f"- {name}: {logic_hint} | Contribution: {contrib_pct:.1f}%")

    context_block = "\n".join(data_lines)
    risk_block = "\n".join(shap_lines)
    prompt = f"""
You are a Senior Model Risk Examiner. Write a strict, short compliance explanation.

CONTEXT:
{context_block}

RISK FACTORS:
{risk_block}

Write a "Notice of Adverse Action" explanation. Use the provided logic hints.
Interpret negative SHAP as consistency. Keep it under 150 words.
Professional tone only. Add a standard disclaimer at the end.
"""

    try:
        completion = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,
            max_tokens=300,
        )
        return completion.choices[0].message.content
    except Exception as e:
        # Best-effort: surface the failure as text in the UI instead of crashing the app.
        return f"LLM Error: {e}"
# --- 5. SIDEBAR UI ---
# Initial sidebar values; seeded into st.session_state on the first run.
_DEFAULT_INPUTS = {
    "amount": 350000.0,
    "old_bal": 80000.0,
    "new_bal": 2000000.0,
    "txn_type": "TRANSFER",
}
SUSPICIOUS_SCENARIO = {
    "amount": 5000000.0,
    "old_bal": 100000.0,
    "new_bal": 100000.0,
    "txn_type": "CASH_OUT",
}
NORMAL_SCENARIO = {
    "amount": 10000.0,
    "old_bal": 50000.0,
    "new_bal": 40000.0,
    "txn_type": "TRANSFER",
}

# The widgets below are keyed, so session state is the single source of truth.
# Seeding it here (instead of passing value=) avoids Streamlit's warning about a
# widget having both a default value and a Session State value.
for _key, _default in _DEFAULT_INPUTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default


def _load_scenario(values):
    """Button callback: overwrite the keyed sidebar widgets with a preset scenario.

    Callbacks run before the next script execution (i.e. before the widgets are
    created), which is when Streamlit allows widget-keyed state to be changed.
    The user can still edit the values afterwards.
    """
    st.session_state.update(values)


st.sidebar.image("https://cdn-icons-png.flaticon.com/512/2666/2666505.png", width=100)
st.sidebar.title("💳 Transaction Details")

amount = st.sidebar.number_input("Amount ($)", min_value=0.0, format="%.2f", key="amount")
old_bal = st.sidebar.number_input("Origin Old Balance ($)", min_value=0.0, format="%.2f", key="old_bal")
new_bal = st.sidebar.number_input("Origin New Balance ($)", min_value=0.0, format="%.2f", key="new_bal")
txn_type = st.sidebar.selectbox("Type", ["TRANSFER", "CASH_OUT"], key="txn_type")

# Calculate math discrepancy from the same values that will be analysed.
error_bal_orig = old_bal + amount - new_bal
st.sidebar.info(f"Math Discrepancy: ${error_bal_orig:,.2f}")

st.sidebar.markdown("---")
st.sidebar.markdown("**🧪 Test Scenarios:**")
col_a, col_b = st.sidebar.columns(2)
with col_a:
    st.button(
        "💰 Suspicious",
        use_container_width=True,
        on_click=_load_scenario,
        args=(SUSPICIOUS_SCENARIO,),
    )
with col_b:
    st.button(
        "✅ Normal",
        use_container_width=True,
        on_click=_load_scenario,
        args=(NORMAL_SCENARIO,),
    )

# --- 6. MAIN APP LOGIC ---
# Fraud probability above which a transaction is flagged.
RISK_THRESHOLD = 0.5

# Fallback standardisation statistics as (mean, std), keyed by feature name so
# they stay aligned with `columns` whatever its order.
# NOTE(review): hard-coded approximations of the training distribution — confirm.
_MANUAL_SCALING = {
    "step": (243.39, 142.3),
    "type_enc": (0.5, 0.5),
    "amount": (180000, 604000),
    "oldbalanceOrg": (834000, 2900000),
    "newbalanceOrig": (855000, 2940000),
    "oldbalanceDest": (1100000, 3400000),
    "newbalanceDest": (1225000, 3670000),
    "errorBalanceOrig": (0, 380000),
    "errorBalanceDest": (0, 420000),
}


def _manual_scale(raw, feature_names):
    """Standardise *raw* (shape (1, n)) with the hard-coded fallback statistics."""
    stats = np.array([_MANUAL_SCALING[name] for name in feature_names], dtype=float)
    means, stds = stats[:, 0], stats[:, 1]
    return (raw - means) / (stds + 1e-8)


def _scale_features(raw, fitted_scaler, feature_names):
    """Scale *raw* with the fitted scaler, falling back to manual statistics.

    Returns:
        tuple: (scaled row, True if the manual fallback was used).
    """
    try:
        scaled = fitted_scaler.transform(raw)
    except Exception:
        return _manual_scale(raw, feature_names), True
    # Heuristic: standardised values beyond |100| suggest a mismatched scaler.
    if np.abs(scaled).max() > 100:
        return _manual_scale(raw, feature_names), True
    return scaled, False


class _PassthroughScaler:
    """Scaler stand-in whose inverse_transform returns its input unchanged.

    Lets the report be built from the raw feature row directly, so the values it
    quotes are correct whichever scaling path was taken.
    """

    @staticmethod
    def inverse_transform(values):
        return np.asarray(values)


st.title("🏦 Credit-Scout AI Risk Engine")
st.markdown("Real-time Fraud Detection with Llama 3.3 Explainability")

if st.sidebar.button("Analyze Transaction"):
    with st.spinner("Analyzing Risk Patterns..."):
        # 1. Preprocess
        type_val = 1 if txn_type == "CASH_OUT" else 0

        # NOTE(review): 'step' and the recipient-side features are fixed
        # placeholders, not user inputs.
        feature_dict = {
            "step": 150,
            "type_enc": type_val,
            "amount": amount,
            "oldbalanceOrg": old_bal,
            "newbalanceOrig": new_bal,
            "oldbalanceDest": 0.0,
            "newbalanceDest": 0.0,
            "errorBalanceOrig": error_bal_orig,
            "errorBalanceDest": 0.0,
        }

        # Build the row in the exact column order the model was trained on.
        raw_features = np.array([feature_dict[col] for col in columns], dtype=float).reshape(1, -1)

        scaled_features, used_fallback = _scale_features(raw_features, scaler, columns)
        if used_fallback:
            st.warning("Saved scaler unavailable or inconsistent; using built-in scaling statistics.")

        # Reshape for LSTM: (batch=1, timesteps=1, n_features).
        lstm_input = scaled_features.reshape(1, 1, -1)

        # 2. Predict
        prediction_raw = model.predict(lstm_input, verbose=0)
        risk_prob = float(prediction_raw[0][0])

        # 3. Explain (SHAP)
        shap_vals = explainer.shap_values(lstm_input)

        # 4. Display Results
        col1, col2 = st.columns([1, 2])

        with col1:
            st.subheader("Risk Score")
            st.metric(label="Fraud Probability", value=f"{risk_prob:.2%}")
            if risk_prob > RISK_THRESHOLD:
                st.error("⛔ FLAGGED")
            else:
                st.success("✅ APPROVED")

        with col2:
            st.subheader("Model Logic (SHAP)")
            shap_vals_plot = shap_vals[0] if isinstance(shap_vals, list) else shap_vals
            # Collapse (1, timesteps, features[, outputs]) to (1, n) for the bar plot.
            if shap_vals_plot.ndim > 2:
                shap_vals_plot = shap_vals_plot.reshape(1, -1)

            fig, _ = plt.subplots(figsize=(10, 6))
            shap.summary_plot(
                shap_vals_plot,
                raw_features,
                feature_names=columns,
                plot_type="bar",
                show=False,
            )
            st.pyplot(fig, clear_figure=True)
            plt.close("all")

        # 5. LLM Report
        st.markdown("---")
        st.subheader("📋 Audit Report (Llama 3.3)")
        with st.spinner("Drafting Compliance Notice..."):
            # Pass the raw row with a passthrough scaler: inverse-transforming
            # manually scaled data through the fitted scaler would misreport values.
            report = generate_explanation_cloud(shap_vals, raw_features, columns, _PassthroughScaler())
        if isinstance(report, str) and report.startswith(("Error:", "LLM Error:")):
            st.warning(report)
        else:
            st.success("Report Generated")
            st.write(report)
else:
    st.info("👈 Adjust transaction details in the sidebar and click 'Analyze Transaction' to begin.")