# Author: oke39
# Commit: 2841f6a "Update app.py" (verified)
import streamlit as st
import numpy as np
import tensorflow as tf
import shap
import pickle
import os
import pandas as pd
import matplotlib.pyplot as plt
from groq import Groq
import keras
# --- 1. SETUP & CONFIG ---
# Page title and wide layout for the whole app.
st.set_page_config(page_title="Credit-Scout AI", layout="wide")

# CSS for "Bank" styling: light page background, full-width navy buttons,
# and the red/green verdict labels (.risk-high / .risk-low) used for results.
_BANK_THEME_CSS = """
<style>
.main { background-color: #f5f5f5; }
.stButton>button { background-color: #000044; color: white; width: 100%; }
.risk-high { color: #cc0000; font-weight: bold; font-size: 20px; }
.risk-low { color: #006600; font-weight: bold; font-size: 20px; }
</style>
"""
st.markdown(_BANK_THEME_CSS, unsafe_allow_html=True)

# Groq client for the LLM audit report. It stays None when no key is
# configured, which the report step checks before calling the API.
api_key = os.environ.get("GROQ_API_KEY")
if api_key:
    client = Groq(api_key=api_key)
else:
    st.error("⚠️ GROQ_API_KEY not found in Secrets! The LLM explanation will fail.")
    client = None
# --- 2. LOAD ARTIFACTS ---
@st.cache_resource
def load_resources():
    """Load the model, scaler, feature list and SHAP explainer from disk.

    Wrapped in ``st.cache_resource`` so the artifacts are reused across
    reruns instead of being re-read on every interaction.

    Returns:
        ``(model, scaler, columns, explainer)``:
        - the Keras model from ``latest_checkpoint.h5``;
        - the objects unpickled from ``scaler.pkl`` and ``columns.pkl``;
        - a ``shap.GradientExplainer`` built on the ``'background_sample'``
          entry of ``shap_metadata.pkl``.
    """

    class PatchedDense(tf.keras.layers.Dense):
        # Drops the ``quantization_config`` kwarg before building the layer.
        # NOTE(review): presumably the checkpoint was written by a Keras
        # version that serializes this argument and the installed one
        # rejects it -- confirm before removing this shim.
        def __init__(self, *args, **kwargs):
            kwargs.pop('quantization_config', None)
            super().__init__(*args, **kwargs)

    def _unpickle(path):
        # SECURITY: pickle.load can execute arbitrary code. These files are
        # assumed to be trusted artifacts deployed with the app -- never
        # point this at user-supplied data.
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    try:
        model = tf.keras.models.load_model(
            'latest_checkpoint.h5',
            custom_objects={'Dense': PatchedDense},
            compile=False,  # only model.predict is used; no optimizer/loss needed
        )
    except Exception as e:
        st.error(f"Critical Model Load Error: {e}")
        st.stop()

    scaler = _unpickle('scaler.pkl')
    columns = _unpickle('columns.pkl')
    shap_data = _unpickle('shap_metadata.pkl')
    explainer = shap.GradientExplainer(model, shap_data['background_sample'])
    return model, scaler, columns, explainer


# Errors raised while loading (e.g. a missing .pkl file or a missing key in
# shap_metadata) are shown on the page and halt the script.
try:
    model, scaler, columns, explainer = load_resources()
except Exception as e:
    st.error(f"Error loading files: {e}")
    st.stop()
# --- 3. BUSINESS MAPPING ---
# Human-readable labels for the model's raw feature names. They are used when
# building the LLM prompt, where unmapped names fall back to the raw name.
BUSINESS_MAP = dict(
    step='Transaction Hour',
    type_enc='Txn Type (Transfer/CashOut)',
    amount='Transaction Amount',
    oldbalanceOrg='Origin Acct Balance (Pre)',
    newbalanceOrig='Origin Acct Balance (Post)',
    oldbalanceDest='Recipient Acct Balance (Pre)',
    newbalanceDest='Recipient Acct Balance (Post)',
    errorBalanceOrig='Origin Math Discrepancy',
    errorBalanceDest='Recipient Math Discrepancy',
)
# --- 4. EXPLANATION FUNCTION (GROQ API) ---
def generate_explanation_cloud(shap_values, original_samples, feature_names, scaler, top_k=3):
    """Ask the Groq LLM for a short adverse-action notice for one transaction.

    Steps:
    1. Map the scaled feature vector back to real units with
       ``scaler.inverse_transform``.
    2. Rank the features by absolute SHAP value.
    3. Write the ``top_k`` strongest into the prompt: real value, direction
       and share of the total |SHAP| mass.

    Args:
        shap_values: SHAP attributions for a single sample. Either an array,
            or a list whose first element is that array; flattened before use.
        original_samples: The scaled feature vector that was fed to the model.
        feature_names: Raw feature names, in the same order as the values.
        scaler: Object with ``inverse_transform`` that undoes the scaling of
            ``original_samples``.
        top_k: How many of the highest-impact features to describe
            (default 3).

    Returns:
        One of:
        - the model's reply text;
        - ``"Error: Groq API Key missing."`` when no client is configured;
        - ``"LLM Error: ..."`` if the API call fails.
    """
    # Nothing to ask without a client; skip the numeric work entirely.
    if not client:
        return "Error: Groq API Key missing."

    # Inverse transform to get real (unscaled) values for the prompt.
    real_values = scaler.inverse_transform(
        original_samples.flatten().reshape(1, -1)
    ).flatten()
    vals = shap_values[0] if isinstance(shap_values, list) else shap_values
    vals = vals.flatten()

    # (business label, real value, SHAP value) per feature. Assumes the three
    # sequences are aligned and equally long -- zip() stops at the shortest.
    feature_data = [
        (BUSINESS_MAP.get(col_name, col_name), real_val, shap_val)
        for col_name, real_val, shap_val in zip(feature_names, real_values, vals)
    ]
    # Strongest absolute impact first.
    feature_data.sort(key=lambda item: abs(item[2]), reverse=True)
    # The epsilon keeps the percentage finite when every attribution is zero.
    total_shap_mass = sum(abs(v) for _, _, v in feature_data) + 1e-9

    data_lines = []
    shap_lines = []
    for name, real_val, shap_val in feature_data[:top_k]:
        if any(tag in name for tag in ("Amount", "Balance", "Discrepancy")):
            # Currency: put the sign before the symbol ("-$1,234.00").
            sign = "-" if real_val < 0 else ""
            val_str = f"{sign}${abs(real_val):,.2f}"
        else:
            val_str = f"{real_val:.2f}"
        contrib_pct = (abs(shap_val) / total_shap_mass) * 100
        logic_hint = "ANOMALY (Increased Risk)" if shap_val > 0 else "CONSISTENT BEHAVIOR (Mitigated Risk)"
        data_lines.append(f"- {name}: {val_str}")
        shap_lines.append(f"- {name}: {logic_hint} | Contribution: {contrib_pct:.1f}%")

    # Joined outside the f-string: backslashes are not allowed inside
    # f-string expressions before Python 3.12.
    context_text = "\n".join(data_lines)
    factors_text = "\n".join(shap_lines)
    prompt = f"""
You are a Senior Model Risk Examiner. Write a strict, short compliance explanation.
CONTEXT:
{context_text}
RISK FACTORS:
{factors_text}
Write a "Notice of Adverse Action" explanation.
Use the provided logic hints. Interpret negative SHAP as consistency.
Keep it under 150 words. Professional tone only. Add a standard disclaimer at the end.
"""
    try:
        completion = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,
            max_tokens=300,
        )
        return completion.choices[0].message.content
    except Exception as e:
        # API failures are reported as text in the UI rather than raised.
        return f"LLM Error: {str(e)}"
# --- 5. SIDEBAR UI ---
# Starting values for the transaction inputs. Each is stored in
# st.session_state under the same key as its widget, so the widget is the
# single source of truth and the scenario buttons can overwrite it.
_SIDEBAR_DEFAULTS = {
    'amount': 350000.0,
    'old_bal': 80000.0,
    'new_bal': 2000000.0,
    'txn_type': 'TRANSFER',
}
# Preset transactions loaded by the "Test Scenarios" buttons.
_SCENARIO_SUSPICIOUS = {
    'amount': 5000000.0,
    'old_bal': 100000.0,
    'new_bal': 100000.0,
    'txn_type': 'CASH_OUT',
}
_SCENARIO_NORMAL = {
    'amount': 10000.0,
    'old_bal': 50000.0,
    'new_bal': 40000.0,
    'txn_type': 'TRANSFER',
}

# Seed session state once; later reruns keep whatever the user entered.
for _state_key, _default in _SIDEBAR_DEFAULTS.items():
    st.session_state.setdefault(_state_key, _default)


def _load_scenario(values):
    """Copy a preset transaction into the sidebar widgets' session state.

    Used as an ``on_click`` callback. Callbacks run before the script
    re-executes, so the keyed widgets below show the new values on that
    rerun. Streamlit rejects writes to a widget's key after that widget has
    been drawn in the same run.
    """
    st.session_state.update(values)


st.sidebar.image("https://cdn-icons-png.flaticon.com/512/2666/2666505.png", width=100)
st.sidebar.title("💳 Transaction Details")
amount = st.sidebar.number_input("Amount ($)", min_value=0.0, format="%.2f", key='amount')
old_bal = st.sidebar.number_input("Origin Old Balance ($)", min_value=0.0, format="%.2f", key='old_bal')
new_bal = st.sidebar.number_input("Origin New Balance ($)", min_value=0.0, format="%.2f", key='new_bal')
txn_type = st.sidebar.selectbox("Type", ["TRANSFER", "CASH_OUT"], key='txn_type')

# Calculate math discrepancy from the final widget values, so the number
# shown here matches the inputs that will be scored.
error_bal_orig = old_bal + amount - new_bal
# Put the sign before the currency symbol ("-$1.00", not "$-1.00").
_disc_sign = "-" if error_bal_orig < 0 else ""
st.sidebar.info(f"Math Discrepancy: {_disc_sign}${abs(error_bal_orig):,.2f}")
st.sidebar.markdown("---")
st.sidebar.markdown("**🧪 Test Scenarios:**")
col_a, col_b = st.sidebar.columns(2)
with col_a:
    st.button(
        "💰 Suspicious",
        use_container_width=True,
        on_click=_load_scenario,
        args=(_SCENARIO_SUSPICIOUS,),
    )
with col_b:
    st.button(
        "✅ Normal",
        use_container_width=True,
        on_click=_load_scenario,
        args=(_SCENARIO_NORMAL,),
    )
# --- 6. MAIN APP LOGIC ---
# Fallback (mean, std) per feature. Used when the pickled scaler raises or
# yields implausible output. Keyed by column name so the vectors are built in
# whatever order columns.pkl lists the features.
_FALLBACK_SCALING = {
    'step': (243.39, 142.3),
    'type_enc': (0.5, 0.5),
    'amount': (180000, 604000),
    'oldbalanceOrg': (834000, 2900000),
    'newbalanceOrig': (855000, 2940000),
    'oldbalanceDest': (1100000, 3400000),
    'newbalanceDest': (1225000, 3670000),
    'errorBalanceOrig': (0, 380000),
    'errorBalanceDest': (0, 420000),
}


class _FallbackScaler:
    """Fixed-statistics z-score scaler built from ``_FALLBACK_SCALING``.

    Provides ``transform`` and ``inverse_transform``, so the LLM report can
    decode features with the same scaling that was fed to the model.
    """

    def __init__(self, feature_names):
        stats = np.array([_FALLBACK_SCALING[name] for name in feature_names], dtype=float)
        self.mean_ = stats[:, 0]
        # The epsilon guards against division by a zero std.
        self.scale_ = stats[:, 1] + 1e-8

    def transform(self, features):
        return (np.asarray(features, dtype=float) - self.mean_) / self.scale_

    def inverse_transform(self, features):
        return np.asarray(features, dtype=float) * self.scale_ + self.mean_


st.title("🏦 Credit-Scout AI Risk Engine")
st.markdown("Real-time Fraud Detection with Llama 3.3 Explainability")

if st.sidebar.button("Analyze Transaction"):
    with st.spinner("Analyzing Risk Patterns..."):
        # 1. Preprocess
        type_val = 1 if txn_type == 'CASH_OUT' else 0
        # Recomputed from the final inputs, so the discrepancy feature always
        # agrees with the amount and balances being scored.
        orig_discrepancy = old_bal + amount - new_bal
        # The hour and recipient-side fields are not collected in the sidebar
        # and are fixed placeholder values here.
        feature_dict = {
            'step': 150,
            'type_enc': type_val,
            'amount': amount,
            'oldbalanceOrg': old_bal,
            'newbalanceOrig': new_bal,
            'oldbalanceDest': 0.0,
            'newbalanceDest': 0.0,
            'errorBalanceOrig': orig_discrepancy,
            'errorBalanceDest': 0.0,
        }
        # Build the array in the exact column order the model was trained on.
        raw_features = np.array([feature_dict[col] for col in columns], dtype=float).reshape(1, -1)

        # Scale with the pickled scaler. Fall back to fixed statistics if it
        # raises or produces implausible z-scores (|z| > 100).
        active_scaler = scaler
        try:
            scaled_features = scaler.transform(raw_features)
            scaler_ok = bool(np.abs(scaled_features).max() <= 100)
        except Exception:
            scaler_ok = False
        if not scaler_ok:
            active_scaler = _FallbackScaler(columns)
            scaled_features = active_scaler.transform(raw_features)
            st.warning("Saved scaler failed a sanity check; using fallback feature scaling.")

        # Reshape for LSTM: (batch=1, timesteps=1, n_features).
        lstm_input = scaled_features.reshape(1, 1, -1)
        # 2. Predict
        prediction_raw = model.predict(lstm_input, verbose=0)
        risk_prob = float(prediction_raw[0][0])
        # 3. Explain (SHAP)
        shap_vals = explainer.shap_values(lstm_input)

    # 4. Display Results
    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("Risk Score")
        st.metric(label="Fraud Probability", value=f"{risk_prob:.2%}")
        threshold = 0.5
        if risk_prob > threshold:
            st.markdown('<p class="risk-high">⛔ FLAGGED</p>', unsafe_allow_html=True)
        else:
            st.markdown('<p class="risk-low">✅ APPROVED</p>', unsafe_allow_html=True)
    with col2:
        st.subheader("Model Logic (SHAP)")
        # The explainer may return a one-element list for a single output.
        shap_vals_plot = shap_vals[0] if isinstance(shap_vals, list) else shap_vals
        # Flatten any extra axes (presumably the timestep axis) into one row.
        if len(shap_vals_plot.shape) > 2:
            shap_vals_plot = shap_vals_plot.reshape(1, -1)
        fig, _ax = plt.subplots(figsize=(10, 6))
        shap.summary_plot(
            shap_vals_plot,
            raw_features,
            feature_names=columns,
            plot_type="bar",
            show=False,  # let Streamlit render the figure instead of shap
        )
        st.pyplot(fig, clear_figure=True)
        plt.close('all')

    # 5. LLM Report
    st.markdown("---")
    st.subheader("📝 Audit Report (Llama 3.3)")
    with st.spinner("Drafting Compliance Notice..."):
        # Pass the scaler that actually produced scaled_features, so the
        # report's inverse transform recovers the submitted values.
        report = generate_explanation_cloud(shap_vals, scaled_features, columns, active_scaler)
    # The report function returns failures as text with these prefixes
    # instead of raising, so don't label them as a success.
    report_failed = isinstance(report, str) and report.startswith(("Error:", "LLM Error:"))
    if report_failed:
        st.error(report)
    else:
        st.success("Report Generated")
        st.write(report)
else:
    st.info("👈 Adjust transaction details in the sidebar and click 'Analyze Transaction' to begin.")