oke39 committed on
Commit
2841f6a
·
verified ·
1 Parent(s): 69509f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +291 -217
app.py CHANGED
@@ -1,217 +1,291 @@
1
- import streamlit as st
2
- import numpy as np
3
- import tensorflow as tf
4
- import shap
5
- import pickle
6
- import os
7
- import pandas as pd
8
- import matplotlib.pyplot as plt
9
- from groq import Groq
10
-
11
- # --- 1. SETUP & CONFIG ---
12
- st.set_page_config(page_title="Credit-Scout AI", layout="wide")
13
-
14
- # CSS for "Bank" styling
15
- st.markdown("""
16
- <style>
17
- .main { background-color: #f5f5f5; }
18
- .stButton>button { background-color: #000044; color: white; width: 100%; }
19
- .risk-high { color: #cc0000; font-weight: bold; font-size: 20px; }
20
- .risk-low { color: #006600; font-weight: bold; font-size: 20px; }
21
- </style>
22
- """, unsafe_allow_html=True)
23
-
24
- # Initialize Groq Client
25
- # It looks for the key in Hugging Face Secrets
26
- api_key = os.environ.get("GROQ_API_KEY")
27
- if not api_key:
28
- st.error("⚠️ GROQ_API_KEY not found in Secrets! The LLM explanation will fail.")
29
- client = None
30
- else:
31
- client = Groq(api_key=api_key)
32
-
33
- # --- 2. LOAD ARTIFACTS ---
34
- @st.cache_resource
35
- def load_resources():
36
- # Load Model (CPU mode)
37
- model = tf.keras.models.load_model('latest_checkpoint.h5')
38
-
39
- # Load Pickles
40
- with open('scaler.pkl', 'rb') as f:
41
- scaler = pickle.load(f)
42
- with open('columns.pkl', 'rb') as f:
43
- columns = pickle.load(f)
44
- with open('shap_metadata.pkl', 'rb') as f:
45
- shap_data = pickle.load(f)
46
-
47
- # Re-initialize Explainer
48
- explainer = shap.GradientExplainer(model, shap_data['background_sample'])
49
- return model, scaler, columns, explainer
50
-
51
- # Load everything once
52
- try:
53
- model, scaler, columns, explainer = load_resources()
54
- except Exception as e:
55
- st.error(f"Error loading files: {e}. Did you upload .h5 and .pkl files?")
56
- st.stop()
57
-
58
- # --- 3. BUSINESS MAPPING ---
59
- BUSINESS_MAP = {
60
- 'step': 'Transaction Hour',
61
- 'type_enc': 'Txn Type (Transfer/CashOut)',
62
- 'amount': 'Transaction Amount',
63
- 'oldbalanceOrg': 'Origin Acct Balance (Pre)',
64
- 'newbalanceOrig': 'Origin Acct Balance (Post)',
65
- 'oldbalanceDest': 'Recipient Acct Balance (Pre)',
66
- 'newbalanceDest': 'Recipient Acct Balance (Post)',
67
- 'errorBalanceOrig': 'Origin Math Discrepancy',
68
- 'errorBalanceDest': 'Recipient Math Discrepancy'
69
- }
70
-
71
- # --- 4. EXPLANATION FUNCTION (GROQ API) ---
72
- def generate_explanation_cloud(sample_idx_in_shap, shap_values, original_samples, feature_names, scaler):
73
- # A. Inverse Transform to get Real Money
74
- raw_scaled = original_samples.flatten()
75
- real_values = scaler.inverse_transform(raw_scaled.reshape(1, -1)).flatten()
76
-
77
- if isinstance(shap_values, list):
78
- vals = shap_values[0]
79
- else:
80
- vals = shap_values
81
- vals = vals.flatten()
82
-
83
- # B. Prepare Data for LLM
84
- feature_data = []
85
- for i, col_name in enumerate(feature_names):
86
- biz_name = BUSINESS_MAP.get(col_name, col_name)
87
- feature_data.append((biz_name, real_values[i], vals[i]))
88
-
89
- # Sort by absolute impact
90
- feature_data.sort(key=lambda x: abs(x[2]), reverse=True)
91
- total_shap_mass = sum([abs(v) for _, _, v in feature_data]) + 1e-9
92
-
93
- data_lines = []
94
- shap_lines = []
95
-
96
- for name, real_val, shap_val in feature_data[:3]:
97
- # Format Currency
98
- if "Amount" in name or "Balance" in name:
99
- val_str = f"${real_val:,.2f}"
100
- else:
101
- val_str = f"{real_val:.2f}"
102
-
103
- contrib_pct = (abs(shap_val) / total_shap_mass) * 100
104
- logic_hint = "ANOMALY (Increased Risk)" if shap_val > 0 else "CONSISTENT BEHAVIOR (Mitigated Risk)"
105
-
106
- data_lines.append(f"- {name}: {val_str}")
107
- shap_lines.append(f"- {name}: {logic_hint} | Contribution: {contrib_pct:.1f}%")
108
-
109
- # C. Call Groq
110
- if not client:
111
- return "Error: Groq API Key missing."
112
-
113
- prompt = f"""
114
- You are a Senior Model Risk Examiner. Write a strict, short compliance explanation.
115
-
116
- CONTEXT:
117
- {chr(10).join(data_lines)}
118
-
119
- RISK FACTORS:
120
- {chr(10).join(shap_lines)}
121
-
122
- Write a "Notice of Adverse Action" explanation.
123
- Use the provided logic hints. Interpret negative SHAP as consistency.
124
- Keep it under 150 words. Professional tone only. Add a standard disclaimer at the end.
125
- """
126
-
127
- try:
128
- completion = client.chat.completions.create(
129
- model="llama-3.1-70b-versatile",
130
- messages=[{"role": "user", "content": prompt}],
131
- temperature=0.1,
132
- max_tokens=300,
133
- )
134
- return completion.choices[0].message.content
135
- except Exception as e:
136
- return f"LLM Error: {str(e)}"
137
-
138
- # --- 5. SIDEBAR UI ---
139
- st.sidebar.image("https://cdn-icons-png.flaticon.com/512/2666/2666505.png", width=100)
140
- st.sidebar.title("💳 Transaction Details")
141
-
142
- amount = st.sidebar.number_input("Amount ($)", value=350000.0)
143
- old_bal = st.sidebar.number_input("Origin Old Balance ($)", value=1200000.0)
144
- new_bal = st.sidebar.number_input("Origin New Balance ($)", value=850000.0)
145
- txn_type = st.sidebar.selectbox("Type", ["TRANSFER", "CASH_OUT"])
146
-
147
- # Auto-calculate math features
148
- error_bal_orig = new_bal + amount - old_bal
149
- st.sidebar.info(f"Math Discrepancy: ${error_bal_orig:.2f}")
150
-
151
- # --- 6. MAIN APP LOGIC ---
152
- st.title("🏦 Credit-Scout AI Risk Engine")
153
- st.markdown("Real-time Fraud Detection with Llama 3.1 Explainability")
154
-
155
- if st.sidebar.button("Analyze Transaction"):
156
- with st.spinner("Analyzing Risk Patterns..."):
157
- # 1. Preprocess
158
- type_val = 0 if txn_type == 'TRANSFER' else 1
159
-
160
- # Construct Input Array (Must match columns.pkl order exactly!)
161
- # Standard PaySim columns: step, type, amount, oldBalOrg, newBalOrig, oldBalDest, newBalDest, errorOrig, errorDest
162
- raw_features = np.array([
163
- 150, # step (mock)
164
- type_val,
165
- amount,
166
- old_bal,
167
- new_bal,
168
- 0.0, # oldbalanceDest (mock)
169
- 0.0, # newbalanceDest (mock)
170
- error_bal_orig,
171
- 0.0 # errorBalanceDest (mock)
172
- ]).reshape(1, -1)
173
-
174
- # Scale & Reshape
175
- scaled_features = scaler.transform(raw_features)
176
- lstm_input = scaled_features.reshape(1, 1, 9)
177
-
178
- # 2. Predict
179
- risk_prob = model.predict(lstm_input)[0][0]
180
-
181
- # 3. Explain (SHAP)
182
- shap_vals = explainer.shap_values(lstm_input)
183
-
184
- # 4. Display Results
185
- col1, col2 = st.columns([1, 2])
186
-
187
- with col1:
188
- st.subheader("Risk Score")
189
- st.metric(label="Fraud Probability", value=f"{risk_prob:.2%}")
190
- if risk_prob > 0.8:
191
- st.markdown('<p class="risk-high">⛔ FLAGGED</p>', unsafe_allow_html=True)
192
- else:
193
- st.markdown('<p class="risk-low">✅ APPROVED</p>', unsafe_allow_html=True)
194
-
195
- with col2:
196
- st.subheader("Model Logic (SHAP)")
197
- # Fix SHAP plot dimensions
198
- st.set_option('deprecation.showPyplotGlobalUse', False)
199
- if isinstance(shap_vals, list):
200
- shap_vals_plot = shap_vals[0]
201
- else:
202
- shap_vals_plot = shap_vals
203
-
204
- fig = plt.figure()
205
- shap.summary_plot(shap_vals_plot, raw_features, feature_names=columns, plot_type="bar", show=False)
206
- st.pyplot(fig)
207
-
208
- # 5. LLM Report
209
- st.markdown("---")
210
- st.subheader("📝 Audit Report (Llama 3.1)")
211
- with st.spinner("Drafting Compliance Notice via Groq..."):
212
- report = generate_explanation_cloud(0, shap_vals, lstm_input, columns, scaler)
213
- st.success("Report Generated")
214
- st.write(report)
215
-
216
- else:
217
- st.info(" Adjust transaction details in the sidebar to test the model.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import shap
5
+ import pickle
6
+ import os
7
+ import pandas as pd
8
+ import matplotlib.pyplot as plt
9
+ from groq import Groq
10
+ import keras
11
+
12
+ # --- 1. SETUP & CONFIG ---
13
+ st.set_page_config(page_title="Credit-Scout AI", layout="wide")
14
+
15
+ # CSS for "Bank" styling
16
+ st.markdown("""
17
+ <style>
18
+ .main { background-color: #f5f5f5; }
19
+ .stButton>button { background-color: #000044; color: white; width: 100%; }
20
+ .risk-high { color: #cc0000; font-weight: bold; font-size: 20px; }
21
+ .risk-low { color: #006600; font-weight: bold; font-size: 20px; }
22
+ </style>
23
+ """, unsafe_allow_html=True)
24
+
25
+ # Initialize Groq Client
26
+ api_key = os.environ.get("GROQ_API_KEY")
27
+ if not api_key:
28
+ st.error("⚠️ GROQ_API_KEY not found in Secrets! The LLM explanation will fail.")
29
+ client = None
30
+ else:
31
+ client = Groq(api_key=api_key)
32
+
33
+ # --- 2. LOAD ARTIFACTS ---
34
+ @st.cache_resource
35
+ def load_resources():
36
+ class PatchedDense(tf.keras.layers.Dense):
37
+ def __init__(self, *args, **kwargs):
38
+ if 'quantization_config' in kwargs:
39
+ kwargs.pop('quantization_config')
40
+ super().__init__(*args, **kwargs)
41
+
42
+ try:
43
+ model = tf.keras.models.load_model(
44
+ 'latest_checkpoint.h5',
45
+ custom_objects={'Dense': PatchedDense},
46
+ compile=False
47
+ )
48
+ except Exception as e:
49
+ st.error(f"Critical Model Load Error: {e}")
50
+ st.stop()
51
+
52
+ with open('scaler.pkl', 'rb') as f:
53
+ scaler = pickle.load(f)
54
+ with open('columns.pkl', 'rb') as f:
55
+ columns = pickle.load(f)
56
+ with open('shap_metadata.pkl', 'rb') as f:
57
+ shap_data = pickle.load(f)
58
+
59
+ explainer = shap.GradientExplainer(model, shap_data['background_sample'])
60
+ return model, scaler, columns, explainer
61
+
62
+ try:
63
+ model, scaler, columns, explainer = load_resources()
64
+ except Exception as e:
65
+ st.error(f"Error loading files: {e}")
66
+ st.stop()
67
+
68
+ # --- 3. BUSINESS MAPPING ---
69
+ BUSINESS_MAP = {
70
+ 'step': 'Transaction Hour',
71
+ 'type_enc': 'Txn Type (Transfer/CashOut)',
72
+ 'amount': 'Transaction Amount',
73
+ 'oldbalanceOrg': 'Origin Acct Balance (Pre)',
74
+ 'newbalanceOrig': 'Origin Acct Balance (Post)',
75
+ 'oldbalanceDest': 'Recipient Acct Balance (Pre)',
76
+ 'newbalanceDest': 'Recipient Acct Balance (Post)',
77
+ 'errorBalanceOrig': 'Origin Math Discrepancy',
78
+ 'errorBalanceDest': 'Recipient Math Discrepancy'
79
+ }
80
+
81
+ # --- 4. EXPLANATION FUNCTION (GROQ API) ---
82
+ def generate_explanation_cloud(shap_values, original_samples, feature_names, scaler):
83
+ # Inverse transform to get real values
84
+ raw_scaled = original_samples.flatten()
85
+ real_values = scaler.inverse_transform(raw_scaled.reshape(1, -1)).flatten()
86
+
87
+ if isinstance(shap_values, list):
88
+ vals = shap_values[0]
89
+ else:
90
+ vals = shap_values
91
+ vals = vals.flatten()
92
+
93
+ # Prepare data for LLM
94
+ feature_data = []
95
+ for i, col_name in enumerate(feature_names):
96
+ biz_name = BUSINESS_MAP.get(col_name, col_name)
97
+ feature_data.append((biz_name, real_values[i], vals[i]))
98
+
99
+ # Sort by absolute impact
100
+ feature_data.sort(key=lambda x: abs(x[2]), reverse=True)
101
+ total_shap_mass = sum([abs(v) for _, _, v in feature_data]) + 1e-9
102
+
103
+ data_lines = []
104
+ shap_lines = []
105
+
106
+ for name, real_val, shap_val in feature_data[:3]:
107
+ # Format currency
108
+ if "Amount" in name or "Balance" in name or "Discrepancy" in name:
109
+ val_str = f"${real_val:,.2f}"
110
+ else:
111
+ val_str = f"{real_val:.2f}"
112
+
113
+ contrib_pct = (abs(shap_val) / total_shap_mass) * 100
114
+ logic_hint = "ANOMALY (Increased Risk)" if shap_val > 0 else "CONSISTENT BEHAVIOR (Mitigated Risk)"
115
+
116
+ data_lines.append(f"- {name}: {val_str}")
117
+ shap_lines.append(f"- {name}: {logic_hint} | Contribution: {contrib_pct:.1f}%")
118
+
119
+ if not client:
120
+ return "Error: Groq API Key missing."
121
+
122
+ prompt = f"""
123
+ You are a Senior Model Risk Examiner. Write a strict, short compliance explanation.
124
+
125
+ CONTEXT:
126
+ {chr(10).join(data_lines)}
127
+
128
+ RISK FACTORS:
129
+ {chr(10).join(shap_lines)}
130
+
131
+ Write a "Notice of Adverse Action" explanation.
132
+ Use the provided logic hints. Interpret negative SHAP as consistency.
133
+ Keep it under 150 words. Professional tone only. Add a standard disclaimer at the end.
134
+ """
135
+
136
+ try:
137
+ completion = client.chat.completions.create(
138
+ model="llama-3.3-70b-versatile",
139
+ messages=[{"role": "user", "content": prompt}],
140
+ temperature=0.1,
141
+ max_tokens=300,
142
+ )
143
+ return completion.choices[0].message.content
144
+ except Exception as e:
145
+ return f"LLM Error: {str(e)}"
146
+
147
+ # --- 5. SIDEBAR UI ---
148
+ st.sidebar.image("https://cdn-icons-png.flaticon.com/512/2666/2666505.png", width=100)
149
+ st.sidebar.title("💳 Transaction Details")
150
+
151
+ amount = st.sidebar.number_input("Amount ($)", value=350000.0, min_value=0.0, format="%.2f")
152
+ old_bal = st.sidebar.number_input("Origin Old Balance ($)", value=80000.0, min_value=0.0, format="%.2f")
153
+ new_bal = st.sidebar.number_input("Origin New Balance ($)", value=2000000.0, min_value=0.0, format="%.2f")
154
+ txn_type = st.sidebar.selectbox("Type", ["TRANSFER", "CASH_OUT"])
155
+
156
+ # Calculate math discrepancy
157
+ error_bal_orig = old_bal + amount - new_bal
158
+ st.sidebar.info(f"Math Discrepancy: ${error_bal_orig:,.2f}")
159
+
160
+ st.sidebar.markdown("---")
161
+ st.sidebar.markdown("**🧪 Test Scenarios:**")
162
+
163
+ col_a, col_b = st.sidebar.columns(2)
164
+
165
+ with col_a:
166
+ if st.button("💰 Suspicious", use_container_width=True):
167
+ st.session_state.update({
168
+ 'amount': 5000000.0,
169
+ 'old_bal': 100000.0,
170
+ 'new_bal': 100000.0,
171
+ 'txn_type': 'CASH_OUT'
172
+ })
173
+ st.rerun()
174
+
175
+ with col_b:
176
+ if st.button("✅ Normal", use_container_width=True):
177
+ st.session_state.update({
178
+ 'amount': 10000.0,
179
+ 'old_bal': 50000.0,
180
+ 'new_bal': 40000.0,
181
+ 'txn_type': 'TRANSFER'
182
+ })
183
+ st.rerun()
184
+
185
+ # Apply session state
186
+ if 'amount' in st.session_state:
187
+ amount = st.session_state.amount
188
+ if 'old_bal' in st.session_state:
189
+ old_bal = st.session_state.old_bal
190
+ if 'new_bal' in st.session_state:
191
+ new_bal = st.session_state.new_bal
192
+ if 'txn_type' in st.session_state:
193
+ txn_type = st.session_state.txn_type
194
+
195
+ # --- 6. MAIN APP LOGIC ---
196
+ st.title("🏦 Credit-Scout AI Risk Engine")
197
+ st.markdown("Real-time Fraud Detection with Llama 3.3 Explainability")
198
+
199
+ if st.sidebar.button("Analyze Transaction"):
200
+ with st.spinner("Analyzing Risk Patterns..."):
201
+ # 1. Preprocess
202
+ type_val = 1 if txn_type == 'CASH_OUT' else 0
203
+
204
+ # Build feature dictionary
205
+ feature_dict = {
206
+ 'step': 150,
207
+ 'type_enc': type_val,
208
+ 'amount': amount,
209
+ 'oldbalanceOrg': old_bal,
210
+ 'newbalanceOrig': new_bal,
211
+ 'oldbalanceDest': 0.0,
212
+ 'newbalanceDest': 0.0,
213
+ 'errorBalanceOrig': error_bal_orig,
214
+ 'errorBalanceDest': 0.0
215
+ }
216
+
217
+ # Build array in exact column order
218
+ raw_features = np.array([feature_dict[col] for col in columns]).reshape(1, -1)
219
+
220
+ # Scale features with automatic fallback
221
+ try:
222
+ scaled_features = scaler.transform(raw_features)
223
+
224
+ # Check if scaler is working properly
225
+ if np.abs(scaled_features).max() > 100:
226
+ # Scaler appears broken, use manual scaling
227
+ manual_means = np.array([243.39, 0.5, 180000, 834000, 855000, 1100000, 1225000, 0, 0])
228
+ manual_stds = np.array([142.3, 0.5, 604000, 2900000, 2940000, 3400000, 3670000, 380000, 420000])
229
+ scaled_features = (raw_features - manual_means) / (manual_stds + 1e-8)
230
+ except Exception:
231
+ # Fallback to manual scaling
232
+ manual_means = np.array([243.39, 0.5, 180000, 834000, 855000, 1100000, 1225000, 0, 0])
233
+ manual_stds = np.array([142.3, 0.5, 604000, 2900000, 2940000, 3400000, 3670000, 380000, 420000])
234
+ scaled_features = (raw_features - manual_means) / (manual_stds + 1e-8)
235
+
236
+ # Reshape for LSTM
237
+ lstm_input = scaled_features.reshape(1, 1, 9)
238
+
239
+ # 2. Predict
240
+ prediction_raw = model.predict(lstm_input, verbose=0)
241
+ risk_prob = float(prediction_raw[0][0])
242
+
243
+ # 3. Explain (SHAP)
244
+ shap_vals = explainer.shap_values(lstm_input)
245
+
246
+ # 4. Display Results
247
+ col1, col2 = st.columns([1, 2])
248
+
249
+ with col1:
250
+ st.subheader("Risk Score")
251
+ st.metric(label="Fraud Probability", value=f"{risk_prob:.2%}")
252
+
253
+ threshold = 0.5
254
+ if risk_prob > threshold:
255
+ st.markdown('<p class="risk-high">⛔ FLAGGED</p>', unsafe_allow_html=True)
256
+ else:
257
+ st.markdown('<p class="risk-low">✅ APPROVED</p>', unsafe_allow_html=True)
258
+
259
+ with col2:
260
+ st.subheader("Model Logic (SHAP)")
261
+
262
+ # Process SHAP values
263
+ if isinstance(shap_vals, list):
264
+ shap_vals_plot = shap_vals[0]
265
+ else:
266
+ shap_vals_plot = shap_vals
267
+
268
+ if len(shap_vals_plot.shape) > 2:
269
+ shap_vals_plot = shap_vals_plot.reshape(1, -1)
270
+
271
+ # Create SHAP plot
272
+ fig, ax = plt.subplots(figsize=(10, 6))
273
+ shap.summary_plot(
274
+ shap_vals_plot,
275
+ raw_features,
276
+ feature_names=columns,
277
+ plot_type="bar",
278
+ show=False
279
+ )
280
+ st.pyplot(fig, clear_figure=True)
281
+ plt.close('all')
282
+
283
+ # 5. LLM Report
284
+ st.markdown("---")
285
+ st.subheader("📝 Audit Report (Llama 3.3)")
286
+ with st.spinner("Drafting Compliance Notice..."):
287
+ report = generate_explanation_cloud(shap_vals, scaled_features, columns, scaler)
288
+ st.success("Report Generated")
289
+ st.write(report)
290
+ else:
291
+ st.info("👈 Adjust transaction details in the sidebar and click 'Analyze Transaction' to begin.")