# Fraud Guard - Streamlit dashboard.
#
# The sidebar selects one of three sections:
#   * Prediction            - posts a transaction to the scoring API (API_URL)
#                             and applies a user-chosen decision threshold.
#   * Explainability (SHAP) - waterfall plot for the last scored transaction.
#   * Drift Monitoring      - synthetic-drift injector, drift report viewer
#                             and a retraining trigger.

import os
import random
import sys

import matplotlib.pyplot as plt
import pandas as pd
import requests
import shap
import streamlit as st
import streamlit.components.v1 as components
from dotenv import load_dotenv

# Add the parent of this file's directory to sys.path so the `src.*` imports
# below resolve. This has to run before those imports.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

# Load .env first so API_URL can be overridden from it.
load_dotenv()
API_URL = os.getenv("API_URL", "http://127.0.0.1:8000/predict")

from src.pipeline.predict_pipeline import PredictPipeline  # noqa: E402
from src.explanability.shap_explainer import ShapExplainer  # noqa: E402
from src.monitoring.db import save_to_db  # noqa: E402

st.set_page_config(page_title="Fraud Guard", layout="wide")

# Indices of the V1..V28 inputs; session keys are "v_<i>", payload keys "V<i>".
V_INDICES = range(1, 29)

# Seconds to wait for the scoring API before giving up.
REQUEST_TIMEOUT_SECONDS = 50

# Canned transactions used by the "Simulate Fraud Attack" button.
# Each "V" list holds the values for V1..V28, in order.
FRAUD_SAMPLES = [
    {
        "Time": 406.0,
        "Amount": 0.00,
        "V": [
            -2.312, 1.951, -1.609, 3.997, -0.522, -1.426, -2.537,
            1.391, -2.770, -2.772, 3.202, -2.899, -0.595, -4.289,
            0.389, -1.140, -2.830, -0.016, 0.416, 0.126, 0.517,
            -0.035, -0.465, 0.320, 0.044, 0.177, 0.261, -0.143,
        ],
    },
    {
        "Time": 12500.0,
        "Amount": 99.99,
        "V": [
            -0.95, 0.52, -1.53, 0.85, -0.21, 0.11, -0.45,
            0.22, -0.63, -1.05, 1.20, -1.55, 0.30, -2.01,
            0.10, -0.55, -1.22, 0.20, 0.45, -0.10, 0.25,
            0.15, -0.12, 0.05, 0.22, -0.15, 0.02, 0.05,
        ],
    },
    {
        "Time": 4462.0,
        "Amount": 1.00,
        "V": [
            -2.303, 1.759, -0.359, 2.330, -0.821, -0.075, -0.560,
            1.214, -1.385, -2.776, 3.231, -2.719, -1.059, -3.535,
            -1.583, -1.488, -2.573, -0.739, 0.380, -0.430, -0.294,
            -0.932, 0.172, -0.087, -0.156, -0.542, 0.039, -0.153,
        ],
    },
]

# --- Initialize Session State for Auto-Fill Buttons ---
# The input widgets below are bound to these keys, so the defaults must exist
# before the widgets are created.
if "t_time" not in st.session_state:
    st.session_state.t_time = 10000.0
if "t_amount" not in st.session_state:
    st.session_state.t_amount = 100.0
for i in V_INDICES:
    if f"v_{i}" not in st.session_state:
        st.session_state[f"v_{i}"] = 0.0


def generate_sample(is_fraud=False):
    """Fill the input widgets' session state with a demo transaction.

    Args:
        is_fraud: When true, copy one randomly chosen entry of FRAUD_SAMPLES
            into the session state. Otherwise draw Time uniformly from
            [100, 150000], Amount from [5, 150] and each V value from
            [-1.0, 1.0].

    Writes ``t_time``, ``t_amount`` and ``v_1`` .. ``v_28`` in
    ``st.session_state``; returns nothing.
    """
    if is_fraud:
        sample = random.choice(FRAUD_SAMPLES)
        st.session_state.t_time = sample["Time"]
        st.session_state.t_amount = sample["Amount"]
        for i, value in enumerate(sample["V"], start=1):
            st.session_state[f"v_{i}"] = value
    else:
        st.session_state.t_time = random.uniform(100, 150000)
        st.session_state.t_amount = random.uniform(5, 150)
        for i in V_INDICES:
            st.session_state[f"v_{i}"] = random.uniform(-1.0, 1.0)


# --- UI Sidebar & Navigation ---
page = st.sidebar.selectbox(
    "📌 Choose Section",
    ["Prediction", "Explainability (SHAP)", "Drift Monitoring"],
)

if page == "Prediction":
    st.markdown("""

💳 Fraud Guard Intelligence

Real-Time Transaction Risk Analysis

""", unsafe_allow_html=True)

    col1, col2 = st.columns([1, 2])

    with col1:
        st.markdown("### 🛠️ Demo Controls")
        demo_col1, demo_col2 = st.columns(2)
        # The buttons write session state before the keyed widgets below are
        # created in this run, so the widgets pick up the generated values.
        with demo_col1:
            if st.button("✅ Simulate Normal User", use_container_width=True):
                generate_sample(is_fraud=False)
        with demo_col2:
            if st.button(
                "🚨 Simulate Fraud Attack",
                type="primary",
                use_container_width=True,
            ):
                generate_sample(is_fraud=True)

        st.markdown("### 📥 Transaction Input")
        with st.container(border=True):
            with st.form("transaction_form"):
                t_time = st.slider("Time (Sec)", 0.0, 172800.0, key="t_time")
                t_amount = st.slider("Amount ($)", 0.0, 5000.0, key="t_amount")

                with st.expander("PCA Feature Vectors (V1 - V28)", expanded=False):
                    v_data = {
                        f"V{i}": st.number_input(
                            f"V{i}", key=f"v_{i}", format="%.4f"
                        )
                        for i in V_INDICES
                    }

                st.markdown("---")
                threshold = st.slider("AI Sensitivity (Threshold)", 0.05, 0.95, 0.15)
                submit_btn = st.form_submit_button(
                    "🔍 Run Analysis", use_container_width=True
                )

    with col2:
        st.markdown("### 📊 Live Telemetry & Assessment")

        if not submit_btn:
            st.info(
                "Awaiting transaction payload. "
                "Click 'Simulate' then 'Run Analysis'."
            )

        if submit_btn:
            payload = {"Time": t_time, "Amount": t_amount, **v_data}

            # store for SHAP page
            st.session_state["last_payload"] = payload

            try:
                with st.spinner("Analyzing threat vectors..."):
                    response = requests.post(
                        API_URL, json=payload, timeout=REQUEST_TIMEOUT_SECONDS
                    )
                    # Fail on HTTP errors here instead of later with a
                    # confusing KeyError/JSON error on an error body.
                    response.raise_for_status()
                    result = response.json()

                prob = result["fraud_probability"]
                # The decision is made client-side from the chosen threshold.
                pred = 1 if prob > threshold else 0
                action = "🚫 Block Transaction" if pred == 1 else "✅ Allow Transaction"

                if pred == 1:
                    st.error(f"🚨 FRAUD DETECTED: {action}")
                else:
                    st.success(f"✅ TRANSACTION SAFE: {action}")

                m_col1, m_col2 = st.columns(2)
                m_col1.metric("Risk Level", f"{prob:.4%}")
                m_col2.metric("Prediction Output", pred)

                st.progress(float(prob))
                st.markdown("---")

            except Exception as e:
                # Top-level UI boundary: report the failure, don't crash the page.
                st.error(f"⏳ Error connecting to API: {e}")
                st.info("server is waking up wait,please")

# --- NEW SHAP PAGE ---
elif page == "Explainability (SHAP)":
    st.title("🧠 Explainable AI (SHAP)")

    if "last_payload" not in st.session_state:
        st.warning("⚠️ Please run a prediction first.")
    else:
        if st.button("Generate SHAP Explanation"):
            try:
                payload = st.session_state["last_payload"]

                with st.spinner("Generating explanations..."):
                    input_df = pd.DataFrame([payload])

                    pipeline = PredictPipeline()
                    processed_df = pipeline.preprocess(input_df)

                    explainer = ShapExplainer()
                    shap_values = explainer.explain(processed_df)

                    # The figure created here is the current pyplot figure
                    # when the waterfall plot is drawn.
                    fig, _ = plt.subplots(figsize=(8, 4))
                    try:
                        shap.plots.waterfall(shap_values[0], show=False)
                        st.pyplot(fig)
                    finally:
                        # pyplot keeps figures alive until closed; release it
                        # so repeated clicks don't accumulate figures.
                        plt.close(fig)

            except Exception as e:
                st.error(f"❌ Error: {e}")

# --- DRIFT MONITORING ---
elif page == "Drift Monitoring":
    st.title("📉 Data Drift Monitoring")
    st.markdown("### Monitor model health & trigger verified retraining")

    with st.expander("🛠️ Demo Tools: Force Data Drift", expanded=True):
        st.write(
            "Inject 50 heavily skewed transactions into the database "
            "to trigger a statistical drift warning."
        )
        if st.button("💉 Inject Synthetic Drift Data", type="primary"):
            with st.spinner("Injecting bad data into Cloud DB..."):
                for _ in range(50):
                    skewed_data = {
                        "Time": random.uniform(10, 50000),
                        "Amount": random.uniform(1000, 5000),
                    }
                    for i in V_INDICES:
                        skewed_data[f"V{i}"] = random.uniform(-15.0, 15.0)
                    save_to_db(skewed_data, pred=1, prob=0.99)
            st.success(
                "✅ 50 Skewed rows injected! "
                "Now click 'Run Drift Detection' below."
            )

    st.markdown("---")

    # These modules are imported only on this page; if they cannot be loaded
    # the page stops here. `except Exception` (not a bare `except`) keeps
    # SystemExit/KeyboardInterrupt from being swallowed.
    try:
        from src.monitoring.drift import detect_drift
        from src.pipeline.retrain_pipeline import retrain
    except Exception:
        st.error("⚠️ Drift feature not supported in this environment")
        st.stop()

    if st.button("🚀 Run Drift Detection"):
        with st.spinner("Running statistical drift analysis..."):
            try:
                generated_report = detect_drift("data/creditcard.csv")
                if generated_report:
                    st.success("✅ Drift report generated!")
                    # Unlocks the retrain button further down the page.
                    st.session_state["drift_done"] = True
                else:
                    st.warning(
                        "⚠️ Not enough data in live DB (Needs 50 rows). "
                        "Use the Demo Injector above!"
                    )
            except Exception as e:
                st.error(f"⚠️ Error running drift: {e}")

    # Show whatever report is on disk, whether or not detection ran this time.
    report_path = "reports/drift_report.html"
    if os.path.exists(report_path):
        with open(report_path, "r", encoding="utf-8") as f:
            html = f.read()
        components.html(html, height=800, scrolling=True)

    st.markdown("---")
    st.subheader("🔁 Human-in-the-Loop Retraining")
    st.write(
        "Ensure database contains human-verified `Actual_Class` labels "
        "before retraining."
    )

    if st.session_state.get("drift_done", False):
        if st.button("⚡ Retrain Model (Requires Verified Data)"):
            with st.spinner("Retraining model..."):
                try:
                    retrain()
                    st.success("✅ Model retrained successfully with verified data!")
                except ValueError as ve:
                    st.error(f"❌ {ve}")
                except Exception as e:
                    st.error(f"❌ Error: {e}")