MohitRajput45 commited on
Commit
2a1adf2
Β·
verified Β·
1 Parent(s): be4f339

Update app/streamlit_app.py

Browse files
Files changed (1) hide show
  1. app/streamlit_app.py +147 -85
app/streamlit_app.py CHANGED
@@ -19,136 +19,198 @@ from src.monitoring.db import save_to_db
19
 
20
  st.set_page_config(page_title="Fraud Guard", layout="wide")
21
 
22
- # --- Initialize Session State ---
23
  if "t_time" not in st.session_state: st.session_state.t_time = 10000.0
24
  if "t_amount" not in st.session_state: st.session_state.t_amount = 100.0
25
  for i in range(1, 29):
26
  if f"v_{i}" not in st.session_state: st.session_state[f"v_{i}"] = 0.0
27
 
28
  def generate_sample(is_fraud=False):
 
29
  if is_fraud:
30
  fraud_database = [
31
- {"Time": 406.0, "Amount": 0.00, "V": [-2.312,1.951,-1.609,3.997,-0.522,-1.426,-2.537,1.391,-2.770,-2.772,3.202,-2.899,-0.595,-4.289,0.389,-1.140,-2.830,-0.016,0.416,0.126,0.517,-0.035,-0.465,0.320,0.044,0.177,0.261,-0.143]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  ]
33
- chosen = random.choice(fraud_database)
34
- st.session_state.t_time = chosen["Time"]
35
- st.session_state.t_amount = chosen["Amount"]
36
- for i in range(1,29):
37
- st.session_state[f"v_{i}"] = chosen["V"][i-1]
 
 
 
38
  else:
39
- st.session_state.t_time = random.uniform(100,150000)
40
- st.session_state.t_amount = random.uniform(5,150)
41
- for i in range(1,29):
42
- st.session_state[f"v_{i}"] = random.uniform(-1,1)
43
 
44
- # --- Sidebar Navigation ---
45
  page = st.sidebar.selectbox("πŸ“Œ Choose Section", ["Prediction", "Explainability (SHAP)", "Drift Monitoring"])
46
 
47
- # =========================
48
- # πŸ”Ή PREDICTION PAGE (NO SHAP)
49
- # =========================
50
  if page == "Prediction":
 
 
 
 
 
 
51
 
52
- st.markdown("## πŸ’³ Fraud Guard Intelligence")
53
-
54
- col1, col2 = st.columns([1,2])
55
 
56
  with col1:
57
- if st.button("Simulate Normal"):
58
- generate_sample(False)
59
- if st.button("Simulate Fraud"):
60
- generate_sample(True)
61
-
62
- with st.form("form"):
63
- t_time = st.slider("Time",0.0,172800.0,key="t_time")
64
- t_amount = st.slider("Amount",0.0,5000.0,key="t_amount")
65
-
66
- v_data = {}
67
- for i in range(1,29):
68
- v_data[f"V{i}"] = st.number_input(f"V{i}",key=f"v_{i}")
69
-
70
- threshold = st.slider("Threshold",0.05,0.95,0.15)
71
- submit = st.form_submit_button("Predict")
 
 
 
 
 
 
 
 
 
72
 
73
  with col2:
74
- if submit:
75
- payload = {"Time":st.session_state.t_time,"Amount":st.session_state.t_amount,**v_data}
 
 
 
 
 
 
 
 
76
  try:
77
- res = requests.post(API_URL,json=payload).json()
78
- prob = res["fraud_probability"]
79
- pred = 1 if prob>threshold else 0
80
-
81
- if pred:
82
- st.error("🚨 FRAUD")
 
 
 
 
83
  else:
84
- st.success("βœ… SAFE")
85
-
86
- st.metric("Risk",f"{prob:.4%}")
 
 
 
87
  st.progress(float(prob))
88
-
89
- # πŸ”₯ Store payload for SHAP page
90
- st.session_state["last_payload"] = payload
91
 
92
  except Exception as e:
93
- st.error(e)
94
 
95
 
96
- # =========================
97
- # πŸ”Ή SHAP PAGE (SEPARATE)
98
- # =========================
99
  elif page == "Explainability (SHAP)":
100
-
101
  st.title("🧠 Explainable AI (SHAP)")
102
 
103
  if "last_payload" not in st.session_state:
104
- st.warning("Run prediction first")
105
  else:
106
  if st.button("Generate SHAP Explanation"):
107
  try:
108
  payload = st.session_state["last_payload"]
109
 
110
- input_df = pd.DataFrame([payload])
111
-
112
- pipeline = PredictPipeline()
113
- processed = pipeline.preprocess(input_df)
114
-
115
- explainer = ShapExplainer()
116
- shap_values = explainer.explain(processed)
117
-
118
- fig, ax = plt.subplots(figsize=(8,4))
119
- shap.plots.waterfall(shap_values[0], show=False)
120
- st.pyplot(fig)
 
121
 
122
  except Exception as e:
123
- st.error(e)
124
 
125
 
126
- # =========================
127
- # πŸ”Ή DRIFT PAGE (UNCHANGED)
128
- # =========================
129
  elif page == "Drift Monitoring":
130
-
131
- st.title("πŸ“‰ Drift Monitoring")
132
-
133
- if st.button("Inject Drift"):
134
- for _ in range(50):
135
- data={"Time":random.uniform(10,50000),"Amount":random.uniform(1000,5000)}
136
- for i in range(1,29):
137
- data[f"V{i}"]=random.uniform(-15,15)
138
- save_to_db(data,1,0.99)
139
- st.success("Injected")
 
 
 
 
 
140
 
141
  try:
142
  from src.monitoring.drift import detect_drift
143
  from src.pipeline.retrain_pipeline import retrain
144
  except:
145
- st.error("Not supported")
146
  st.stop()
147
 
148
- if st.button("Run Drift"):
149
- detect_drift("data/creditcard.csv")
150
- st.success("Done")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
- if st.button("Retrain"):
153
- retrain()
154
- st.success("Retrained")
 
 
 
 
 
19
 
20
  st.set_page_config(page_title="Fraud Guard", layout="wide")
21
 
22
+ # --- Initialize Session State for Auto-Fill Buttons ---
23
  if "t_time" not in st.session_state: st.session_state.t_time = 10000.0
24
  if "t_amount" not in st.session_state: st.session_state.t_amount = 100.0
25
  for i in range(1, 29):
26
  if f"v_{i}" not in st.session_state: st.session_state[f"v_{i}"] = 0.0
27
 
28
  def generate_sample(is_fraud=False):
29
+ """Fills the UI with either a normal transaction or a simulated fraud attack"""
30
  if is_fraud:
31
  fraud_database = [
32
+ {
33
+ "Time": 406.0, "Amount": 0.00,
34
+ "V": [-2.312, 1.951, -1.609, 3.997, -0.522, -1.426, -2.537, 1.391, -2.770, -2.772,
35
+ 3.202, -2.899, -0.595, -4.289, 0.389, -1.140, -2.830, -0.016, 0.416, 0.126,
36
+ 0.517, -0.035, -0.465, 0.320, 0.044, 0.177, 0.261, -0.143]
37
+ },
38
+ {
39
+ "Time": 12500.0, "Amount": 99.99,
40
+ "V": [-0.95, 0.52, -1.53, 0.85, -0.21, 0.11, -0.45, 0.22, -0.63, -1.05,
41
+ 1.20, -1.55, 0.30, -2.01, 0.10, -0.55, -1.22, 0.20, 0.45, -0.10,
42
+ 0.25, 0.15, -0.12, 0.05, 0.22, -0.15, 0.02, 0.05]
43
+ },
44
+ {
45
+ "Time": 4462.0, "Amount": 1.00,
46
+ "V": [-2.303, 1.759, -0.359, 2.330, -0.821, -0.075, -0.560, 1.214, -1.385, -2.776,
47
+ 3.231, -2.719, -1.059, -3.535, -1.583, -1.488, -2.573, -0.739, 0.380, -0.430,
48
+ -0.294, -0.932, 0.172, -0.087, -0.156, -0.542, 0.039, -0.153]
49
+ }
50
  ]
51
+
52
+ chosen_fraud = random.choice(fraud_database)
53
+
54
+ st.session_state.t_time = chosen_fraud["Time"]
55
+ st.session_state.t_amount = chosen_fraud["Amount"]
56
+ for i in range(1, 29):
57
+ st.session_state[f"v_{i}"] = chosen_fraud["V"][i-1]
58
+
59
  else:
60
+ st.session_state.t_time = random.uniform(100, 150000)
61
+ st.session_state.t_amount = random.uniform(5, 150)
62
+ for i in range(1, 29):
63
+ st.session_state[f"v_{i}"] = random.uniform(-1.0, 1.0)
64
 
65
+ # --- UI Sidebar & Navigation ---
66
  page = st.sidebar.selectbox("πŸ“Œ Choose Section", ["Prediction", "Explainability (SHAP)", "Drift Monitoring"])
67
 
 
 
 
68
  if page == "Prediction":
69
+ st.markdown("""
70
+ <div style='text-align: center; padding: 1rem 0;'>
71
+ <h1 style='color: #1E3A8A;'>πŸ’³ Fraud Guard Intelligence</h1>
72
+ <p style='color: #6B7280; font-size: 1.2rem;'>Real-Time Transaction Risk Analysis</p>
73
+ </div>
74
+ """, unsafe_allow_html=True)
75
 
76
+ col1, col2 = st.columns([1, 2])
 
 
77
 
78
  with col1:
79
+ st.markdown("### πŸ› οΈ Demo Controls")
80
+ demo_col1, demo_col2 = st.columns(2)
81
+ with demo_col1:
82
+ if st.button("βœ… Simulate Normal User", use_container_width=True):
83
+ generate_sample(is_fraud=False)
84
+ with demo_col2:
85
+ if st.button("🚨 Simulate Fraud Attack", type="primary", use_container_width=True):
86
+ generate_sample(is_fraud=True)
87
+
88
+ st.markdown("### πŸ“₯ Transaction Input")
89
+ with st.container(border=True):
90
+ with st.form("transaction_form"):
91
+
92
+ t_time = st.slider("Time (Sec)", 0.0, 172800.0, key="t_time")
93
+ t_amount = st.slider("Amount ($)", 0.0, 5000.0, key="t_amount")
94
+
95
+ with st.expander("PCA Feature Vectors (V1 - V28)", expanded=False):
96
+ v_data = {}
97
+ for i in range(1, 29):
98
+ v_data[f"V{i}"] = st.number_input(f"V{i}", key=f"v_{i}", format="%.4f")
99
+
100
+ st.markdown("---")
101
+ threshold = st.slider("AI Sensitivity (Threshold)", 0.05, 0.95, 0.15)
102
+ submit_btn = st.form_submit_button("πŸ” Run Analysis", use_container_width=True)
103
 
104
  with col2:
105
+ st.markdown("### πŸ“Š Live Telemetry & Assessment")
106
+ if not submit_btn:
107
+ st.info("Awaiting transaction payload. Click 'Simulate' then 'Run Analysis'.")
108
+
109
+ if submit_btn:
110
+ payload = {"Time": st.session_state.t_time, "Amount": st.session_state.t_amount, **v_data}
111
+
112
+ # βœ… Save payload for SHAP page
113
+ st.session_state["last_payload"] = payload
114
+
115
  try:
116
+ with st.spinner("Analyzing threat vectors..."):
117
+ response = requests.post(API_URL, json=payload, timeout=30)
118
+ result = response.json()
119
+
120
+ prob = result["fraud_probability"]
121
+ pred = 1 if prob > threshold else 0
122
+ action = "🚫 Block Transaction" if pred == 1 else "βœ… Allow Transaction"
123
+
124
+ if pred == 1:
125
+ st.error(f"🚨 FRAUD DETECTED: {action}")
126
  else:
127
+ st.success(f"βœ… TRANSACTION SAFE: {action}")
128
+
129
+ m_col1, m_col2 = st.columns(2)
130
+ m_col1.metric("Risk Level", f"{prob:.4%}")
131
+ m_col2.metric("Prediction Output", pred)
132
+
133
  st.progress(float(prob))
134
+ st.markdown("---")
 
 
135
 
136
  except Exception as e:
137
+ st.error(f"⏳ Error connecting to API: {e}")
138
 
139
 
140
+ # πŸ”₯ NEW SHAP PAGE
 
 
141
  elif page == "Explainability (SHAP)":
 
142
  st.title("🧠 Explainable AI (SHAP)")
143
 
144
  if "last_payload" not in st.session_state:
145
+ st.warning("⚠️ Please run a prediction first.")
146
  else:
147
  if st.button("Generate SHAP Explanation"):
148
  try:
149
  payload = st.session_state["last_payload"]
150
 
151
+ with st.spinner("Generating explanations..."):
152
+ input_df = pd.DataFrame([payload])
153
+
154
+ pipeline = PredictPipeline()
155
+ processed_df = pipeline.preprocess(input_df)
156
+
157
+ explainer = ShapExplainer()
158
+ shap_values = explainer.explain(processed_df)
159
+
160
+ fig, ax = plt.subplots(figsize=(8, 4))
161
+ shap.plots.waterfall(shap_values[0], show=False)
162
+ st.pyplot(fig)
163
 
164
  except Exception as e:
165
+ st.error(f"❌ Error: {e}")
166
 
167
 
 
 
 
168
  elif page == "Drift Monitoring":
169
+ st.title("πŸ“‰ Data Drift Monitoring")
170
+ st.markdown("### Monitor model health & trigger verified retraining")
171
+
172
+ with st.expander("πŸ› οΈ Demo Tools: Force Data Drift", expanded=True):
173
+ st.write("Inject 50 heavily skewed transactions into the database to trigger a statistical drift warning.")
174
+ if st.button("πŸ’‰ Inject Synthetic Drift Data", type="primary"):
175
+ with st.spinner("Injecting bad data into Cloud DB..."):
176
+ for _ in range(50):
177
+ skewed_data = {"Time": random.uniform(10, 50000), "Amount": random.uniform(1000, 5000)}
178
+ for i in range(1, 29):
179
+ skewed_data[f"V{i}"] = random.uniform(-15.0, 15.0)
180
+ save_to_db(skewed_data, pred=1, prob=0.99)
181
+ st.success("βœ… 50 Skewed rows injected!")
182
+
183
+ st.markdown("---")
184
 
185
  try:
186
  from src.monitoring.drift import detect_drift
187
  from src.pipeline.retrain_pipeline import retrain
188
  except:
189
+ st.error("⚠️ Drift feature not supported in this environment")
190
  st.stop()
191
 
192
+ if st.button("πŸš€ Run Drift Detection"):
193
+ with st.spinner("Running statistical drift analysis..."):
194
+ try:
195
+ report_path = detect_drift("data/creditcard.csv")
196
+ if report_path:
197
+ st.success("βœ… Drift report generated!")
198
+ st.session_state["drift_done"] = True
199
+ else:
200
+ st.warning("⚠️ Not enough data in live DB")
201
+ except Exception as e:
202
+ st.error(f"⚠️ Error running drift: {e}")
203
+
204
+ report_path = "reports/drift_report.html"
205
+ if os.path.exists(report_path):
206
+ with open(report_path, "r", encoding="utf-8") as f:
207
+ html = f.read()
208
+ components.html(html, height=800, scrolling=True)
209
 
210
+ if st.session_state.get("drift_done", False):
211
+ if st.button("⚑ Retrain Model"):
212
+ try:
213
+ retrain()
214
+ st.success("βœ… Model retrained successfully!")
215
+ except Exception as e:
216
+ st.error(f"❌ Error: {e}")