Spaces:
Sleeping
Sleeping
File size: 10,596 Bytes
8a08300 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 |
import json
import os
from datetime import datetime, time

import pandas as pd
import plotly.express as px
import plotly.graph_objects as ob
import requests
import streamlit as st

from src.features.constants import category_names, job_names
# 1. Configuration & Layout
# Page-level settings: browser tab title, tab icon and full-width layout.
st.set_page_config(
    page_title="PayShield Monitor",
    page_icon="🛡️",
    layout="wide",
)

# Custom CSS for the FinTech look: light page background, card-style metric
# tiles, and the coloured status-banner classes (safe / fraud / shadow) that
# the dashboard emits as raw HTML.
_FINTECH_CSS = """
<style>
.main {
background-color: #f8f9fa;
}
.stMetric {
background-color: #ffffff;
padding: 15px;
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
.status-banner {
padding: 20px;
border-radius: 10px;
color: white;
text-align: center;
font-weight: bold;
font-size: 24px;
margin-bottom: 20px;
}
.banner-safe {
background-color: #28a745;
}
.banner-fraud {
background-color: #dc3545;
}
.banner-shadow {
background-color: #ffc107;
color: #212529;
}
</style>
"""
st.markdown(_FINTECH_CSS, unsafe_allow_html=True)

# Page header followed by a horizontal rule.
st.title("🛡️ PayShield-ML Analyst Workbench")
st.markdown("---")
# 2. Sidebar: Transaction Simulator
# Collects the user profile, transaction, location and optional feature
# overrides; these widget values are read by the main dashboard logic when
# the "Analyze Transaction" button is pressed.
with st.sidebar:
    st.header("🛒 Transaction Simulator")

    with st.expander("👤 User Profile", expanded=True):
        user_id = st.text_input("User ID", value="u12345")
        # Sort once and reuse for both the option list and the default lookup.
        job_options = sorted(job_names)
        job = st.selectbox(
            "Job Title",
            options=job_options,
            index=job_options.index("Engineer, biomedical") if "Engineer, biomedical" in job_options else 0,
        )
        # st.date_input restricts the selectable range to value +/- 10 years
        # unless bounds are given, which would confine DOB to 1975-1995.
        # Allow any plausible birth date up to today instead.
        dob = st.date_input(
            "Date of Birth",
            value=datetime(1985, 3, 20),
            min_value=datetime(1900, 1, 1),
            max_value=datetime.now(),
        )
        gender = st.radio("Gender", ["M", "F"], horizontal=True)

    with st.expander("💳 Transaction Details", expanded=True):
        amt = st.number_input("Amount ($)", min_value=0.01, value=150.0, step=10.0)
        category_options = sorted(category_names)
        category = st.selectbox(
            "Category",
            options=category_options,
            index=category_options.index("grocery_pos") if "grocery_pos" in category_options else 0,
        )
        trans_date = st.date_input("Transaction Date", value=datetime.now())
        trans_time = st.time_input("Transaction Time", value=time(14, 30))

    with st.expander("📍 Location Details", expanded=True):
        st.caption("User Coordinates")
        lat = st.number_input("User Lat", value=40.7128, format="%.4f")
        long = st.number_input("User Long", value=-74.0060, format="%.4f")
        st.caption("Merchant Coordinates")
        merch_lat = st.number_input("Merchant Lat", value=40.7200, format="%.4f")
        merch_long = st.number_input("Merchant Long", value=-74.0100, format="%.4f")

    with st.expander("🛠️ Advanced: Feature Overrides", expanded=False):
        st.caption("Force specific values for sensitivity analysis")
        # A value of 0 means "no override"; only positive values are added to
        # the API payload by the main dashboard logic.
        override_trans_count = st.number_input("Count (24h)", min_value=0, value=0, help="Leave 0 to use real-time data")
        override_avg_spend = st.number_input("Avg Spend (24h)", min_value=0.0, value=0.0, help="Leave 0 to use real-time data")
        override_user_avg_all_time = st.number_input("User Avg (All Time)", min_value=0.0, value=0.0, help="Leave 0 to use calculated value")

    analyze_btn = st.button("🚀 Analyze Transaction", use_container_width=True, type="primary")
# 3. Main Dashboard Logic
# Resolve the endpoint once so the request and the connection-error message
# always refer to the same URL (API_URL can point at a docker-internal host).
api_url = os.getenv("API_URL", "http://127.0.0.1:8000/v1/predict")
# Seconds to wait for the inference API before giving up.
REQUEST_TIMEOUT_S = 5
# Score at which the gauge switches to its red band and draws the threshold line.
# NOTE(review): purely visual -- the BLOCK/APPROVE decision comes from the API;
# confirm this matches the server-side decision threshold.
RISK_BLOCK_THRESHOLD = 82


def _render_risk_gauge(score):
    """Draw a 0-100 gauge for ``score`` with green/amber/red bands and a threshold marker."""
    st.subheader("🎯 Risk Gauge")
    fig = ob.Figure(ob.Indicator(
        mode="gauge+number",
        value=score,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': "Confidence Score (%)"},
        gauge={
            'axis': {'range': [0, 100]},
            'bar': {'color': "#333"},
            'steps': [
                {'range': [0, 50], 'color': "rgba(40, 167, 69, 0.3)"},
                {'range': [50, RISK_BLOCK_THRESHOLD], 'color': "rgba(255, 193, 7, 0.3)"},
                {'range': [RISK_BLOCK_THRESHOLD, 100], 'color': "rgba(220, 53, 69, 0.3)"},
            ],
            'threshold': {
                'line': {'color': "black", 'width': 4},
                'thickness': 0.75,
                'value': RISK_BLOCK_THRESHOLD,
            },
        },
    ))
    fig.update_layout(height=350, margin=dict(l=20, r=20, t=50, b=20))
    st.plotly_chart(fig, use_container_width=True)


def _render_shap_chart(shap_data):
    """Draw per-feature SHAP contributions as horizontal bars, or an info note if none.

    ``shap_data`` maps feature name -> SHAP value (numeric), as returned under
    the API response's "shap_values" key.
    """
    st.subheader("📊 Feature Explainability (SHAP)")
    if not shap_data:
        st.info("SHAP explainability not available. Enable it in API settings.")
        return

    # Largest absolute contributions end up at the top of a horizontal bar chart.
    shap_df = pd.DataFrame([
        {"Feature": k, "Impact": v, "Abs_Impact": abs(v)}
        for k, v in shap_data.items()
    ]).sort_values("Abs_Impact", ascending=True)

    # Colour each bar by the sign of its contribution so the chart matches the
    # caption below. A continuous colour scale is stretched over the data's
    # min..max, so e.g. all-positive values would still show some green bars.
    colors = ["#dc3545" if x > 0 else "#28a745" for x in shap_df["Impact"]]
    fig_shap = px.bar(
        shap_df,
        x="Impact",
        y="Feature",
        orientation="h",
        title="Top Feature Contributions to Risk Score",
        labels={"Impact": "SHAP Value (Impact on Prediction)"},
    )
    fig_shap.update_traces(marker_color=colors)
    fig_shap.update_layout(
        height=350,
        margin=dict(l=20, r=20, t=50, b=20),
        showlegend=False,
    )
    st.plotly_chart(fig_shap, use_container_width=True)
    st.caption("🔴 Red = Increases fraud risk | 🟢 Green = Decreases fraud risk")


def _render_feature_store(features):
    """Show the feature values the API reports it used for inference."""
    st.subheader("🗄️ Feature Store Payload")
    if not features:
        st.warning("No feature data returned from API.")
        return

    st.info("Real-time features used for inference (from Redis or Overrides):")
    # `or 0` also covers keys that are present but null in the JSON response,
    # which would otherwise crash the numeric format spec.
    display_features = {
        "Velocity (24h)": features.get("trans_count_24h"),
        "Avg Spend (24h)": f"${(features.get('avg_spend_24h') or 0):.2f}",
        "Current/Avg Ratio": f"{(features.get('amt_to_avg_ratio_24h') or 0):.2f}x",
        "User Avg (All Time)": f"${(features.get('user_avg_amt_all_time') or 0):.2f}",
    }
    st.table(pd.DataFrame([display_features]))


def _show_api_error(response):
    """Report a non-200 API response, showing its body as JSON when possible."""
    st.error(f"API Error: Status {response.status_code}")
    # Error bodies are not guaranteed to be JSON (e.g. a proxy's HTML error
    # page); fall back to the raw text instead of raising into the generic
    # "Unexpected Error" handler.
    try:
        st.json(response.json())
    except ValueError:
        st.code(response.text)


if analyze_btn:
    # Construct the request payload from the sidebar inputs.
    trans_dt = datetime.combine(trans_date, trans_time)
    payload = {
        "user_id": user_id,
        "trans_date_trans_time": trans_dt.strftime("%Y-%m-%d %H:%M:%S"),
        "amt": amt,
        "lat": lat,
        "long": long,
        "merch_lat": merch_lat,
        "merch_long": merch_long,
        "job": job,
        "category": category,
        "gender": gender,
        "dob": dob.strftime("%Y-%m-%d"),
    }
    # Overrides: 0 means "use the live/calculated value", so only send positives.
    if override_trans_count > 0:
        payload["trans_count_24h"] = override_trans_count
    if override_avg_spend > 0:
        payload["avg_spend_24h"] = override_avg_spend
    if override_user_avg_all_time > 0:
        payload["user_avg_amt_all_time"] = override_user_avg_all_time

    try:
        # Step 1: Call API
        with st.spinner("Analyzing with XGBoost Engine..."):
            response = requests.post(api_url, json=payload, timeout=REQUEST_TIMEOUT_S)

        if response.status_code == 200:
            res = response.json()
            score = res["risk_score"]
            decision = res["decision"]
            latency = res["latency_ms"]
            is_shadow = res.get("shadow_mode", False)
            features = res.get("features", {})  # features actually used for inference

            # High-level metrics
            m1, m2, m3 = st.columns(3)
            with m1:
                st.metric("Decision Result", decision)
            with m2:
                st.metric("Risk Score", f"{score:.1f}/100")
            with m3:
                st.metric("Inference Latency", f"{latency:.2f}ms")

            # Decision banner (CSS classes come from the page's custom stylesheet).
            if decision == "BLOCK":
                st.markdown('<div class="status-banner banner-fraud">❌ FRAUD DETECTED - TRANSACTION BLOCKED</div>', unsafe_allow_html=True)
            else:
                st.markdown('<div class="status-banner banner-safe">✅ TRANSACTION APPROVED</div>', unsafe_allow_html=True)
            if is_shadow and decision == "BLOCK":
                st.markdown('<div class="status-banner banner-shadow">⚠️ SHADOW MODE: Transaction allowed in simulation.</div>', unsafe_allow_html=True)

            # 4. Visualizations
            v1, v2 = st.columns([1, 1])
            with v1:
                _render_risk_gauge(score)
            with v2:
                _render_shap_chart(res.get("shap_values", {}))

            # 5. Internal Data State (Architectural Demo)
            st.markdown("---")
            d1, d2 = st.columns([1, 1])
            with d1:
                _render_feature_store(features)
            with d2:
                st.subheader("🔍 API RAW Response")
                with st.expander("View JSON", expanded=False):
                    st.json(res)
        else:
            _show_api_error(response)
    except requests.exceptions.ConnectionError:
        st.error(f"🛑 Connection Failed: Could not connect to Inference API. Is the server running at {api_url}?")
        st.info("Try running: `uv run uvicorn src.api.main:app --reload`")
    except requests.exceptions.Timeout:
        # Read timeouts land here; connect timeouts are ConnectionErrors and
        # are handled by the branch above.
        st.error(f"⏱️ Request Timed Out: Inference API at {api_url} did not respond within {REQUEST_TIMEOUT_S}s.")
    except Exception as e:
        # UI boundary: surface anything else (e.g. a malformed response) to the analyst.
        st.error(f"⚠️ Unexpected Error: {str(e)}")
else:
    # Landing state shown before the first analysis.
    st.info("👈 Fill in the details in the sidebar and click 'Analyze Transaction' to start.")
    # Hero/Demo Content
    c1, c2, c3 = st.columns(3)
    c1.markdown("### ⚡ Low Latency\nSub-50ms inference utilizing XGBoost and Redis.")
    c2.markdown("### 📋 Explainable\nSHAP integration for transparent fraud scoring.")
    c3.markdown("### 🧪 Shadow Mode\nSafe production testing of new model versions.")
|