Update app/streamlit_app.py
Browse files- app/streamlit_app.py +147 -85
app/streamlit_app.py
CHANGED
|
@@ -19,136 +19,198 @@ from src.monitoring.db import save_to_db
|
|
| 19 |
|
| 20 |
st.set_page_config(page_title="Fraud Guard", layout="wide")
|
| 21 |
|
| 22 |
-
# --- Initialize Session State ---
|
| 23 |
if "t_time" not in st.session_state: st.session_state.t_time = 10000.0
|
| 24 |
if "t_amount" not in st.session_state: st.session_state.t_amount = 100.0
|
| 25 |
for i in range(1, 29):
|
| 26 |
if f"v_{i}" not in st.session_state: st.session_state[f"v_{i}"] = 0.0
|
| 27 |
|
| 28 |
def generate_sample(is_fraud=False):
|
|
|
|
| 29 |
if is_fraud:
|
| 30 |
fraud_database = [
|
| 31 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
]
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
| 38 |
else:
|
| 39 |
-
st.session_state.t_time = random.uniform(100,150000)
|
| 40 |
-
st.session_state.t_amount = random.uniform(5,150)
|
| 41 |
-
for i in range(1,29):
|
| 42 |
-
st.session_state[f"v_{i}"] = random.uniform(-1,1)
|
| 43 |
|
| 44 |
-
# --- Sidebar Navigation ---
|
| 45 |
page = st.sidebar.selectbox("π Choose Section", ["Prediction", "Explainability (SHAP)", "Drift Monitoring"])
|
| 46 |
|
| 47 |
-
# =========================
|
| 48 |
-
# πΉ PREDICTION PAGE (NO SHAP)
|
| 49 |
-
# =========================
|
| 50 |
if page == "Prediction":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
-
st.
|
| 53 |
-
|
| 54 |
-
col1, col2 = st.columns([1,2])
|
| 55 |
|
| 56 |
with col1:
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
with
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
with col2:
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
try:
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
else:
|
| 84 |
-
st.success("β
SAFE")
|
| 85 |
-
|
| 86 |
-
st.
|
|
|
|
|
|
|
|
|
|
| 87 |
st.progress(float(prob))
|
| 88 |
-
|
| 89 |
-
# π₯ Store payload for SHAP page
|
| 90 |
-
st.session_state["last_payload"] = payload
|
| 91 |
|
| 92 |
except Exception as e:
|
| 93 |
-
st.error(e)
|
| 94 |
|
| 95 |
|
| 96 |
-
#
|
| 97 |
-
# πΉ SHAP PAGE (SEPARATE)
|
| 98 |
-
# =========================
|
| 99 |
elif page == "Explainability (SHAP)":
|
| 100 |
-
|
| 101 |
st.title("π§ Explainable AI (SHAP)")
|
| 102 |
|
| 103 |
if "last_payload" not in st.session_state:
|
| 104 |
-
st.warning("
|
| 105 |
else:
|
| 106 |
if st.button("Generate SHAP Explanation"):
|
| 107 |
try:
|
| 108 |
payload = st.session_state["last_payload"]
|
| 109 |
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
|
|
|
| 121 |
|
| 122 |
except Exception as e:
|
| 123 |
-
st.error(e)
|
| 124 |
|
| 125 |
|
| 126 |
-
# =========================
|
| 127 |
-
# πΉ DRIFT PAGE (UNCHANGED)
|
| 128 |
-
# =========================
|
| 129 |
elif page == "Drift Monitoring":
|
| 130 |
-
|
| 131 |
-
st.
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
try:
|
| 142 |
from src.monitoring.drift import detect_drift
|
| 143 |
from src.pipeline.retrain_pipeline import retrain
|
| 144 |
except:
|
| 145 |
-
st.error("
|
| 146 |
st.stop()
|
| 147 |
|
| 148 |
-
if st.button("Run Drift"):
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
-
if st.
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
st.set_page_config(page_title="Fraud Guard", layout="wide")
|
| 21 |
|
| 22 |
+
# --- Initialize Session State for Auto-Fill Buttons ---
|
| 23 |
if "t_time" not in st.session_state: st.session_state.t_time = 10000.0
|
| 24 |
if "t_amount" not in st.session_state: st.session_state.t_amount = 100.0
|
| 25 |
for i in range(1, 29):
|
| 26 |
if f"v_{i}" not in st.session_state: st.session_state[f"v_{i}"] = 0.0
|
| 27 |
|
| 28 |
def generate_sample(is_fraud=False):
|
| 29 |
+
"""Fills the UI with either a normal transaction or a simulated fraud attack"""
|
| 30 |
if is_fraud:
|
| 31 |
fraud_database = [
|
| 32 |
+
{
|
| 33 |
+
"Time": 406.0, "Amount": 0.00,
|
| 34 |
+
"V": [-2.312, 1.951, -1.609, 3.997, -0.522, -1.426, -2.537, 1.391, -2.770, -2.772,
|
| 35 |
+
3.202, -2.899, -0.595, -4.289, 0.389, -1.140, -2.830, -0.016, 0.416, 0.126,
|
| 36 |
+
0.517, -0.035, -0.465, 0.320, 0.044, 0.177, 0.261, -0.143]
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"Time": 12500.0, "Amount": 99.99,
|
| 40 |
+
"V": [-0.95, 0.52, -1.53, 0.85, -0.21, 0.11, -0.45, 0.22, -0.63, -1.05,
|
| 41 |
+
1.20, -1.55, 0.30, -2.01, 0.10, -0.55, -1.22, 0.20, 0.45, -0.10,
|
| 42 |
+
0.25, 0.15, -0.12, 0.05, 0.22, -0.15, 0.02, 0.05]
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"Time": 4462.0, "Amount": 1.00,
|
| 46 |
+
"V": [-2.303, 1.759, -0.359, 2.330, -0.821, -0.075, -0.560, 1.214, -1.385, -2.776,
|
| 47 |
+
3.231, -2.719, -1.059, -3.535, -1.583, -1.488, -2.573, -0.739, 0.380, -0.430,
|
| 48 |
+
-0.294, -0.932, 0.172, -0.087, -0.156, -0.542, 0.039, -0.153]
|
| 49 |
+
}
|
| 50 |
]
|
| 51 |
+
|
| 52 |
+
chosen_fraud = random.choice(fraud_database)
|
| 53 |
+
|
| 54 |
+
st.session_state.t_time = chosen_fraud["Time"]
|
| 55 |
+
st.session_state.t_amount = chosen_fraud["Amount"]
|
| 56 |
+
for i in range(1, 29):
|
| 57 |
+
st.session_state[f"v_{i}"] = chosen_fraud["V"][i-1]
|
| 58 |
+
|
| 59 |
else:
|
| 60 |
+
st.session_state.t_time = random.uniform(100, 150000)
|
| 61 |
+
st.session_state.t_amount = random.uniform(5, 150)
|
| 62 |
+
for i in range(1, 29):
|
| 63 |
+
st.session_state[f"v_{i}"] = random.uniform(-1.0, 1.0)
|
| 64 |
|
| 65 |
+
# --- UI Sidebar & Navigation ---
|
| 66 |
page = st.sidebar.selectbox("π Choose Section", ["Prediction", "Explainability (SHAP)", "Drift Monitoring"])
|
| 67 |
|
|
|
|
|
|
|
|
|
|
| 68 |
if page == "Prediction":
|
| 69 |
+
st.markdown("""
|
| 70 |
+
<div style='text-align: center; padding: 1rem 0;'>
|
| 71 |
+
<h1 style='color: #1E3A8A;'>π³ Fraud Guard Intelligence</h1>
|
| 72 |
+
<p style='color: #6B7280; font-size: 1.2rem;'>Real-Time Transaction Risk Analysis</p>
|
| 73 |
+
</div>
|
| 74 |
+
""", unsafe_allow_html=True)
|
| 75 |
|
| 76 |
+
col1, col2 = st.columns([1, 2])
|
|
|
|
|
|
|
| 77 |
|
| 78 |
with col1:
|
| 79 |
+
st.markdown("### π οΈ Demo Controls")
|
| 80 |
+
demo_col1, demo_col2 = st.columns(2)
|
| 81 |
+
with demo_col1:
|
| 82 |
+
if st.button("β
Simulate Normal User", use_container_width=True):
|
| 83 |
+
generate_sample(is_fraud=False)
|
| 84 |
+
with demo_col2:
|
| 85 |
+
if st.button("π¨ Simulate Fraud Attack", type="primary", use_container_width=True):
|
| 86 |
+
generate_sample(is_fraud=True)
|
| 87 |
+
|
| 88 |
+
st.markdown("### π₯ Transaction Input")
|
| 89 |
+
with st.container(border=True):
|
| 90 |
+
with st.form("transaction_form"):
|
| 91 |
+
|
| 92 |
+
t_time = st.slider("Time (Sec)", 0.0, 172800.0, key="t_time")
|
| 93 |
+
t_amount = st.slider("Amount ($)", 0.0, 5000.0, key="t_amount")
|
| 94 |
+
|
| 95 |
+
with st.expander("PCA Feature Vectors (V1 - V28)", expanded=False):
|
| 96 |
+
v_data = {}
|
| 97 |
+
for i in range(1, 29):
|
| 98 |
+
v_data[f"V{i}"] = st.number_input(f"V{i}", key=f"v_{i}", format="%.4f")
|
| 99 |
+
|
| 100 |
+
st.markdown("---")
|
| 101 |
+
threshold = st.slider("AI Sensitivity (Threshold)", 0.05, 0.95, 0.15)
|
| 102 |
+
submit_btn = st.form_submit_button("π Run Analysis", use_container_width=True)
|
| 103 |
|
| 104 |
with col2:
|
| 105 |
+
st.markdown("### π Live Telemetry & Assessment")
|
| 106 |
+
if not submit_btn:
|
| 107 |
+
st.info("Awaiting transaction payload. Click 'Simulate' then 'Run Analysis'.")
|
| 108 |
+
|
| 109 |
+
if submit_btn:
|
| 110 |
+
payload = {"Time": st.session_state.t_time, "Amount": st.session_state.t_amount, **v_data}
|
| 111 |
+
|
| 112 |
+
# β
Save payload for SHAP page
|
| 113 |
+
st.session_state["last_payload"] = payload
|
| 114 |
+
|
| 115 |
try:
|
| 116 |
+
with st.spinner("Analyzing threat vectors..."):
|
| 117 |
+
response = requests.post(API_URL, json=payload, timeout=30)
|
| 118 |
+
result = response.json()
|
| 119 |
+
|
| 120 |
+
prob = result["fraud_probability"]
|
| 121 |
+
pred = 1 if prob > threshold else 0
|
| 122 |
+
action = "π« Block Transaction" if pred == 1 else "β
Allow Transaction"
|
| 123 |
+
|
| 124 |
+
if pred == 1:
|
| 125 |
+
st.error(f"π¨ FRAUD DETECTED: {action}")
|
| 126 |
else:
|
| 127 |
+
st.success(f"β
TRANSACTION SAFE: {action}")
|
| 128 |
+
|
| 129 |
+
m_col1, m_col2 = st.columns(2)
|
| 130 |
+
m_col1.metric("Risk Level", f"{prob:.4%}")
|
| 131 |
+
m_col2.metric("Prediction Output", pred)
|
| 132 |
+
|
| 133 |
st.progress(float(prob))
|
| 134 |
+
st.markdown("---")
|
|
|
|
|
|
|
| 135 |
|
| 136 |
except Exception as e:
|
| 137 |
+
st.error(f"β³ Error connecting to API: {e}")
|
| 138 |
|
| 139 |
|
| 140 |
+
# π₯ NEW SHAP PAGE
|
|
|
|
|
|
|
| 141 |
elif page == "Explainability (SHAP)":
|
|
|
|
| 142 |
st.title("π§ Explainable AI (SHAP)")
|
| 143 |
|
| 144 |
if "last_payload" not in st.session_state:
|
| 145 |
+
st.warning("β οΈ Please run a prediction first.")
|
| 146 |
else:
|
| 147 |
if st.button("Generate SHAP Explanation"):
|
| 148 |
try:
|
| 149 |
payload = st.session_state["last_payload"]
|
| 150 |
|
| 151 |
+
with st.spinner("Generating explanations..."):
|
| 152 |
+
input_df = pd.DataFrame([payload])
|
| 153 |
+
|
| 154 |
+
pipeline = PredictPipeline()
|
| 155 |
+
processed_df = pipeline.preprocess(input_df)
|
| 156 |
+
|
| 157 |
+
explainer = ShapExplainer()
|
| 158 |
+
shap_values = explainer.explain(processed_df)
|
| 159 |
+
|
| 160 |
+
fig, ax = plt.subplots(figsize=(8, 4))
|
| 161 |
+
shap.plots.waterfall(shap_values[0], show=False)
|
| 162 |
+
st.pyplot(fig)
|
| 163 |
|
| 164 |
except Exception as e:
|
| 165 |
+
st.error(f"β Error: {e}")
|
| 166 |
|
| 167 |
|
|
|
|
|
|
|
|
|
|
| 168 |
elif page == "Drift Monitoring":
|
| 169 |
+
st.title("π Data Drift Monitoring")
|
| 170 |
+
st.markdown("### Monitor model health & trigger verified retraining")
|
| 171 |
+
|
| 172 |
+
with st.expander("π οΈ Demo Tools: Force Data Drift", expanded=True):
|
| 173 |
+
st.write("Inject 50 heavily skewed transactions into the database to trigger a statistical drift warning.")
|
| 174 |
+
if st.button("π Inject Synthetic Drift Data", type="primary"):
|
| 175 |
+
with st.spinner("Injecting bad data into Cloud DB..."):
|
| 176 |
+
for _ in range(50):
|
| 177 |
+
skewed_data = {"Time": random.uniform(10, 50000), "Amount": random.uniform(1000, 5000)}
|
| 178 |
+
for i in range(1, 29):
|
| 179 |
+
skewed_data[f"V{i}"] = random.uniform(-15.0, 15.0)
|
| 180 |
+
save_to_db(skewed_data, pred=1, prob=0.99)
|
| 181 |
+
st.success("β
50 Skewed rows injected!")
|
| 182 |
+
|
| 183 |
+
st.markdown("---")
|
| 184 |
|
| 185 |
try:
|
| 186 |
from src.monitoring.drift import detect_drift
|
| 187 |
from src.pipeline.retrain_pipeline import retrain
|
| 188 |
except:
|
| 189 |
+
st.error("β οΈ Drift feature not supported in this environment")
|
| 190 |
st.stop()
|
| 191 |
|
| 192 |
+
if st.button("π Run Drift Detection"):
|
| 193 |
+
with st.spinner("Running statistical drift analysis..."):
|
| 194 |
+
try:
|
| 195 |
+
report_path = detect_drift("data/creditcard.csv")
|
| 196 |
+
if report_path:
|
| 197 |
+
st.success("β
Drift report generated!")
|
| 198 |
+
st.session_state["drift_done"] = True
|
| 199 |
+
else:
|
| 200 |
+
st.warning("β οΈ Not enough data in live DB")
|
| 201 |
+
except Exception as e:
|
| 202 |
+
st.error(f"β οΈ Error running drift: {e}")
|
| 203 |
+
|
| 204 |
+
report_path = "reports/drift_report.html"
|
| 205 |
+
if os.path.exists(report_path):
|
| 206 |
+
with open(report_path, "r", encoding="utf-8") as f:
|
| 207 |
+
html = f.read()
|
| 208 |
+
components.html(html, height=800, scrolling=True)
|
| 209 |
|
| 210 |
+
if st.session_state.get("drift_done", False):
|
| 211 |
+
if st.button("β‘ Retrain Model"):
|
| 212 |
+
try:
|
| 213 |
+
retrain()
|
| 214 |
+
st.success("β
Model retrained successfully!")
|
| 215 |
+
except Exception as e:
|
| 216 |
+
st.error(f"β Error: {e}")
|