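# Sibi Krishnamoorthy / prod / 8a08300
# NOTE(review): the line above looks like repository-viewer header residue
# (presumably author / branch / commit); it is not part of the program.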
import json
import os
from datetime import datetime, time

import pandas as pd
import plotly.express as px
import plotly.graph_objects as ob
import requests
import streamlit as st

from src.features.constants import category_names, job_names
# 1. Configuration & Layout
# Page-level settings are applied before any other Streamlit call in this script.
st.set_page_config(
    page_title="PayShield Monitor",
    page_icon="🛡️",
    layout="wide",
)

# Custom CSS for FinTech Aesthetic
# Styles the metric cards and the three status-banner variants
# (banner-safe / banner-fraud / banner-shadow) used by the result view.
_CUSTOM_CSS = """
<style>
.main {
background-color: #f8f9fa;
}
.stMetric {
background-color: #ffffff;
padding: 15px;
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
.status-banner {
padding: 20px;
border-radius: 10px;
color: white;
text-align: center;
font-weight: bold;
font-size: 24px;
margin-bottom: 20px;
}
.banner-safe {
background-color: #28a745;
}
.banner-fraud {
background-color: #dc3545;
}
.banner-shadow {
background-color: #ffc107;
color: #212529;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# Page heading followed by a horizontal rule.
st.title("🛡️ PayShield-ML Analyst Workbench")
st.markdown("---")
# 2. Sidebar: Transaction Simulator
def _default_index(options, preferred):
    """Return the position of *preferred* in *options*, or 0 when it is absent."""
    return options.index(preferred) if preferred in options else 0


# Sort each option list once and reuse it for both the widget options and
# the default-index lookup.
_job_options = sorted(job_names)
_category_options = sorted(category_names)

with st.sidebar:
    st.header("🛒 Transaction Simulator")

    with st.expander("👤 User Profile", expanded=True):
        user_id = st.text_input("User ID", value="u12345")
        job = st.selectbox(
            "Job Title",
            options=_job_options,
            index=_default_index(_job_options, "Engineer, biomedical"),
        )
        # Without explicit bounds, st.date_input restricts the picker to
        # value -/+ 10 years, i.e. only 1975-1995 for this default. Widen it
        # so any plausible birth date up to today can be selected.
        dob = st.date_input(
            "Date of Birth",
            value=datetime(1985, 3, 20),
            min_value=datetime(1900, 1, 1),
            max_value=datetime.now(),
        )
        gender = st.radio("Gender", ["M", "F"], horizontal=True)

    with st.expander("💳 Transaction Details", expanded=True):
        amt = st.number_input("Amount ($)", min_value=0.01, value=150.0, step=10.0)
        category = st.selectbox(
            "Category",
            options=_category_options,
            index=_default_index(_category_options, "grocery_pos"),
        )
        trans_date = st.date_input("Transaction Date", value=datetime.now())
        trans_time = st.time_input("Transaction Time", value=time(14, 30))

    with st.expander("📍 Location Details", expanded=True):
        st.caption("User Coordinates")
        lat = st.number_input("User Lat", value=40.7128, format="%.4f")
        long = st.number_input("User Long", value=-74.0060, format="%.4f")
        st.caption("Merchant Coordinates")
        merch_lat = st.number_input("Merchant Lat", value=40.7200, format="%.4f")
        merch_long = st.number_input("Merchant Long", value=-74.0100, format="%.4f")

    with st.expander("🛠️ Advanced: Feature Overrides", expanded=False):
        st.caption("Force specific values for sensitivity analysis")
        # A value of 0 means "no override"; the main logic only forwards
        # overrides that are greater than zero.
        override_trans_count = st.number_input(
            "Count (24h)", min_value=0, value=0, help="Leave 0 to use real-time data"
        )
        override_avg_spend = st.number_input(
            "Avg Spend (24h)", min_value=0.0, value=0.0, help="Leave 0 to use real-time data"
        )
        override_user_avg_all_time = st.number_input(
            "User Avg (All Time)", min_value=0.0, value=0.0, help="Leave 0 to use calculated value"
        )

    analyze_btn = st.button("🚀 Analyze Transaction", use_container_width=True, type="primary")
# 3. Main Dashboard Logic
# Endpoint used when the API_URL environment variable is not set.
_DEFAULT_API_URL = "http://127.0.0.1:8000/v1/predict"
# Seconds to wait for the inference API before giving up.
_REQUEST_TIMEOUT_S = 5
# Score at which the gauge draws its marker line and the red band begins.
# NOTE(review): presumably mirrors the API's BLOCK threshold - confirm.
_GAUGE_ALERT_SCORE = 82
# Bar colours for SHAP contributions, matching the chart caption.
_RISK_UP_COLOR = "#dc3545"
_RISK_DOWN_COLOR = "#28a745"


def _render_banner(css_class, text):
    """Render a status banner using the given ``banner-*`` CSS class."""
    st.markdown(
        f'<div class="status-banner {css_class}">{text}</div>',
        unsafe_allow_html=True,
    )


def _render_risk_gauge(score):
    """Draw the 0-100 risk gauge with green / amber / red bands for *score*."""
    st.subheader("🎯 Risk Gauge")
    fig = ob.Figure(ob.Indicator(
        mode="gauge+number",
        value=score,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': "Confidence Score (%)"},
        gauge={
            'axis': {'range': [0, 100]},
            'bar': {'color': "#333"},
            'steps': [
                {'range': [0, 50], 'color': "rgba(40, 167, 69, 0.3)"},
                {'range': [50, _GAUGE_ALERT_SCORE], 'color': "rgba(255, 193, 7, 0.3)"},
                {'range': [_GAUGE_ALERT_SCORE, 100], 'color': "rgba(220, 53, 69, 0.3)"},
            ],
            'threshold': {
                'line': {'color': "black", 'width': 4},
                'thickness': 0.75,
                'value': _GAUGE_ALERT_SCORE,
            },
        },
    ))
    fig.update_layout(height=350, margin=dict(l=20, r=20, t=50, b=20))
    st.plotly_chart(fig, use_container_width=True)


def _render_shap_chart(shap_data):
    """Draw a horizontal bar chart of SHAP values, or a notice when none exist.

    *shap_data* maps feature name to SHAP value. Bars are ordered by absolute
    impact and coloured by sign so the chart agrees with its caption.
    """
    st.subheader("📊 Feature Explainability (SHAP)")
    if not shap_data:
        st.info("SHAP explainability not available. Enable it in API settings.")
        return

    shap_df = pd.DataFrame([
        {"Feature": name, "Impact": value, "Abs_Impact": abs(value)}
        for name, value in shap_data.items()
    ]).sort_values("Abs_Impact", ascending=True)

    # A continuous colour scale is relative to the min/max of the data, so
    # with all-positive values the smallest bar would render green. Colouring
    # by sign keeps "red = raises risk, green = lowers risk" always true.
    bar_colors = [
        _RISK_UP_COLOR if impact > 0 else _RISK_DOWN_COLOR
        for impact in shap_df["Impact"]
    ]
    fig_shap = px.bar(
        shap_df,
        x="Impact",
        y="Feature",
        orientation="h",
        title="Top Feature Contributions to Risk Score",
        labels={"Impact": "SHAP Value (Impact on Prediction)"},
    )
    fig_shap.update_traces(marker_color=bar_colors)
    fig_shap.update_layout(
        height=350,
        margin=dict(l=20, r=20, t=50, b=20),
        showlegend=False,
    )
    st.plotly_chart(fig_shap, use_container_width=True)
    st.caption("🔴 Red = Increases fraud risk | 🟢 Green = Decreases fraud risk")


def _render_feature_table(features):
    """Show the features the API reports having used, or a warning if absent."""
    st.subheader("🗄️ Feature Store Payload")
    if not features:
        st.warning("No feature data returned from API.")
        return

    def _number(key):
        # Treat a missing key and an explicit null alike so the numeric
        # format specs below cannot fail on None.
        value = features.get(key)
        return 0 if value is None else value

    st.info("Real-time features used for inference (from Redis or Overrides):")
    # Format for better readability
    display_features = {
        "Velocity (24h)": features.get("trans_count_24h"),
        "Avg Spend (24h)": f"${_number('avg_spend_24h'):.2f}",
        "Current/Avg Ratio": f"{_number('amt_to_avg_ratio_24h'):.2f}x",
        "User Avg (All Time)": f"${_number('user_avg_amt_all_time'):.2f}",
    }
    st.table(pd.DataFrame([display_features]))


def _render_prediction(res):
    """Render the full result view for a successful prediction response.

    *res* is the decoded JSON body. ``risk_score``, ``decision`` and
    ``latency_ms`` are required (a missing key raises ``KeyError``);
    ``shadow_mode``, ``features`` and ``shap_values`` are optional.
    """
    score = res["risk_score"]
    decision = res["decision"]
    latency = res["latency_ms"]
    is_shadow = res.get("shadow_mode", False)
    features = res.get("features", {})  # Get real features used

    # Columns for high-level metrics
    m1, m2, m3 = st.columns(3)
    with m1:
        st.metric("Decision Result", decision)
    with m2:
        st.metric("Risk Score", f"{score:.1f}/100")
    with m3:
        st.metric("Inference Latency", f"{latency:.2f}ms")

    # Decision Banner
    blocked = decision == "BLOCK"
    if blocked:
        _render_banner("banner-fraud", "❌ FRAUD DETECTED - TRANSACTION BLOCKED")
    else:
        _render_banner("banner-safe", "✅ TRANSACTION APPROVED")
    if is_shadow and blocked:
        _render_banner("banner-shadow", "⚠️ SHADOW MODE: Transaction allowed in simulation.")

    # 4. Visualizations
    v1, v2 = st.columns([1, 1])
    with v1:
        _render_risk_gauge(score)
    with v2:
        _render_shap_chart(res.get("shap_values", {}))

    # 5. Internal Data State (Architectural Demo)
    st.markdown("---")
    d1, d2 = st.columns([1, 1])
    with d1:
        _render_feature_table(features)
    with d2:
        st.subheader("🔍 API RAW Response")
        with st.expander("View JSON", expanded=False):
            st.json(res)


def _render_api_error(response):
    """Report a non-200 response, showing its body as JSON when it parses."""
    st.error(f"API Error: Status {response.status_code}")
    try:
        st.json(response.json())
    except ValueError:
        # Error pages from proxies/gateways are often not JSON; show the raw
        # text instead of letting the decode error mask the status above.
        st.code(response.text or "(empty response body)")


if analyze_btn:
    # Construct Payload
    trans_dt = datetime.combine(trans_date, trans_time)
    payload = {
        "user_id": user_id,
        "trans_date_trans_time": trans_dt.strftime("%Y-%m-%d %H:%M:%S"),
        "amt": amt,
        "lat": lat,
        "long": long,
        "merch_lat": merch_lat,
        "merch_long": merch_long,
        "job": job,
        "category": category,
        "gender": gender,
        "dob": dob.strftime("%Y-%m-%d"),
    }

    # Add overrides if set (0 means "not overridden" and is left out).
    overrides = {
        "trans_count_24h": override_trans_count,
        "avg_spend_24h": override_avg_spend,
        "user_avg_amt_all_time": override_user_avg_all_time,
    }
    payload.update({key: value for key, value in overrides.items() if value > 0})

    # Use docker-internal DNS or localhost depending on environment.
    # Resolved before the try so the error handlers can report the real URL.
    api_url = os.getenv("API_URL", _DEFAULT_API_URL)

    # Rendering stays inside the try on purpose: any failure while displaying
    # the result is reported in-page rather than as a raw traceback.
    try:
        # Step 1: Call API
        with st.spinner("Analyzing with XGBoost Engine..."):
            response = requests.post(api_url, json=payload, timeout=_REQUEST_TIMEOUT_S)

        if response.status_code == 200:
            _render_prediction(response.json())
        else:
            _render_api_error(response)
    except requests.exceptions.ConnectionError:
        st.error(f"🛑 Connection Failed: Could not connect to Inference API. Is the server running at {api_url}?")
        st.info("Try running: `uv run uvicorn src.api.main:app --reload`")
    except requests.exceptions.Timeout:
        st.error(f"⏱️ Request Timed Out: Inference API at {api_url} did not respond within {_REQUEST_TIMEOUT_S} seconds.")
    except Exception as e:
        st.error(f"⚠️ Unexpected Error: {str(e)}")
else:
    # Landing State
    st.info("👈 Fill in the details in the sidebar and click 'Analyze Transaction' to start.")

    # Hero/Demo Content
    c1, c2, c3 = st.columns(3)
    c1.markdown("### ⚡ Low Latency\nSub-50ms inference utilizing XGBoost and Redis.")
    c2.markdown("### 📋 Explainable\nSHAP integration for transparent fraud scoring.")
    c3.markdown("### 🧪 Shadow Mode\nSafe production testing of new model versions.")