FinTech / src /streamlit_app.py
rocky250's picture
Update src/streamlit_app.py
67009e2 verified
import streamlit as st
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import joblib
import json
from datetime import datetime
from huggingface_hub import hf_hub_download
# Page chrome: this must stay the first Streamlit call in the script.
st.set_page_config(layout="wide", page_title="FinTech Fraud Guard")

# Stylesheet injected into the page: dark main background, a pill-shaped
# gradient button, and the two result cards (red edge = fraud, teal = legit).
_PAGE_CSS = """
<style>
.main { background-color: #0e1117; }
.stButton>button { width: 100%; border-radius: 20px; background: linear-gradient(45deg, #ff4b4b, #ff8f8f); color: white; border: none; }
.fraud-card { padding: 20px; border-radius: 15px; background-color: #1e2130; border-left: 5px solid #ff4b4b; margin-bottom: 20px; }
.legit-card { padding: 20px; border-radius: 15px; background-color: #1e2130; border-left: 5px solid #00ffcc; margin-bottom: 20px; }
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
class MoEFraudModel(nn.Module):
    """Mixture-of-experts fraud classifier over categorical and numeric features.

    Each categorical column has its own embedding table.  A gating network
    sees every embedding plus every numeric feature and produces four softmax
    weights.  Four experts each see a fixed subset of the columns and produce
    a 32-wide hidden vector; the gate-weighted sum of those vectors goes
    through a small classifier head that returns one raw logit per row (no
    sigmoid is applied here).

    The submodule names (``embeddings``, ``gating_network``, ``expert1`` ..
    ``expert4``, ``classifier``) and the layer order inside each of them
    define the ``state_dict`` keys, so they are kept stable for saved weights.
    """

    def __init__(self, cat_dims, num_cols_map, embed_dim=8):
        """Build the embeddings, gate, experts and classifier head.

        Args:
            cat_dims: Mapping of categorical column name to its number of
                classes.  Iteration order defines the column order expected
                in ``cat_input``.
            num_cols_map: Mapping whose keys are the numeric column names.
                Key order defines the column order expected in ``num_input``;
                the values are not read.
            embed_dim: Width of every categorical embedding.
        """
        super().__init__()

        self.embeddings = nn.ModuleDict({
            col: nn.Embedding(num_classes, embed_dim)
            for col, num_classes in cat_dims.items()
        })

        self.cat_cols = list(cat_dims.keys())
        self.num_cols = list(num_cols_map.keys())
        # Name -> column position lookups used to slice the input tensors.
        self.cat_idx = {name: i for i, name in enumerate(self.cat_cols)}
        self.num_idx = {name: i for i, name in enumerate(self.num_cols)}

        total_input_dim = (len(self.cat_cols) * embed_dim) + len(self.num_cols)

        # The gate scores the four experts from the full feature vector.
        self.gating_network = nn.Sequential(
            nn.Linear(total_input_dim, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 4),
            nn.Softmax(dim=1)
        )

        # Column subsets per expert.  Every name listed here must be a key of
        # ``cat_dims`` / ``num_cols_map`` or ``forward`` fails on the lookup.
        self.e1_cols_num = ['amt', 'hour', 'day_of_week', 'is_weekend', 'unix_time']
        self.e1_cols_cat = ['category']

        self.e2_cols_num = ['city_pop', 'age', 'time_diff_cc', 'cc_avg_amt_last_5', 'cc_std_amt_last_5', 'cc_max_amt_last_5']
        self.e2_cols_cat = ['cc_num', 'gender', 'job', 'city', 'state', 'zip']

        self.e3_cols_num = ['merchant_fraud_rate', 'merchant_txn_count']
        self.e3_cols_cat = ['merchant', 'category']

        self.e4_cols_num = ['lat', 'long', 'merch_lat', 'merch_long', 'distance_customer_merchant', 'state_mismatch_flag']
        self.e4_cols_cat = []

        self.expert1 = self._make_expert(self._get_dim(self.e1_cols_cat, self.e1_cols_num, embed_dim))
        self.expert2 = self._make_expert(self._get_dim(self.e2_cols_cat, self.e2_cols_num, embed_dim))
        self.expert3 = self._make_expert(self._get_dim(self.e3_cols_cat, self.e3_cols_num, embed_dim))
        self.expert4 = self._make_expert(self._get_dim(self.e4_cols_cat, self.e4_cols_num, embed_dim))

        # Input width 32 matches the output width of every expert.
        self.classifier = nn.Sequential(
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(16, 1)
        )

    def _get_dim(self, cats, nums, embed_dim):
        """Return the input width for an expert fed ``cats`` and ``nums``."""
        return len(nums) + (len(cats) * embed_dim)

    def _make_expert(self, input_dim):
        """Return a fresh expert network mapping ``input_dim`` features to 32."""
        return nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 32),
            nn.ReLU()
        )

    def get_features(self, cat_input, num_input, req_cat, req_num, embedded=None):
        """Assemble the feature matrix for one expert.

        Numeric columns come first (in ``req_num`` order), followed by the
        embeddings of ``req_cat`` (in that order), concatenated along dim 1.
        At least one of the two lists must be non-empty because the parts are
        joined with ``torch.cat``.

        Args:
            cat_input: Integer tensor indexed as ``cat_input[:, column]``.
            num_input: Tensor indexed as ``num_input[:, columns]``.
            req_cat: Categorical column names to embed.
            req_num: Numeric column names to select.
            embedded: Optional mapping of column name to an already computed
                embedding tensor.  Columns found here are reused instead of
                being looked up again; missing columns are embedded on demand.

        Returns:
            The concatenated feature tensor.
        """
        parts = []
        if req_num:
            indices = [self.num_idx[c] for c in req_num]
            parts.append(num_input[:, indices])
        if req_cat:
            for c in req_cat:
                if embedded is not None and c in embedded:
                    parts.append(embedded[c])
                else:
                    parts.append(self.embeddings[c](cat_input[:, self.cat_idx[c]]))
        return torch.cat(parts, dim=1)

    def forward(self, cat_input, num_input):
        """Return one raw logit per row.

        Args:
            cat_input: Integer tensor with one column per entry of
                ``self.cat_cols``, in that order.
            num_input: Tensor with one column per entry of ``self.num_cols``,
                in that order.
        """
        # Embed every categorical column exactly once; the gate and all
        # experts share these tensors instead of repeating the lookups.
        embedded = {
            col: self.embeddings[col](cat_input[:, pos])
            for pos, col in enumerate(self.cat_cols)
        }

        global_features = torch.cat(
            [torch.cat(list(embedded.values()), dim=1), num_input], dim=1
        )
        weights = self.gating_network(global_features)

        h1 = self.expert1(self.get_features(cat_input, num_input, self.e1_cols_cat, self.e1_cols_num, embedded))
        h2 = self.expert2(self.get_features(cat_input, num_input, self.e2_cols_cat, self.e2_cols_num, embedded))
        h3 = self.expert3(self.get_features(cat_input, num_input, self.e3_cols_cat, self.e3_cols_num, embedded))
        h4 = self.expert4(self.get_features(cat_input, num_input, self.e4_cols_cat, self.e4_cols_num, embedded))

        # Column k of ``weights`` scales expert k+1; slicing with k:k+1 keeps
        # a trailing dimension so the weight broadcasts over the 32 features.
        h_final = (weights[:, 0:1]*h1 + weights[:, 1:2]*h2 + weights[:, 2:3]*h3 + weights[:, 3:4]*h4)
        return self.classifier(h_final)
def haversine(lat1, lon1, lat2, lon2, radius_km=6371):
    """Return the great-circle distance between two points on a sphere.

    Coordinates are in decimal degrees.  Scalars and NumPy arrays are both
    accepted; array inputs are combined element-wise.

    Args:
        lat1: Latitude of the first point.
        lon1: Longitude of the first point.
        lat2: Latitude of the second point.
        lon2: Longitude of the second point.
        radius_km: Sphere radius; the result is in the same unit.  The
            default of 6371 gives kilometres on Earth.

    Returns:
        The haversine distance, in the unit of ``radius_km``.
    """
    phi1, phi2 = np.radians(lat1), np.radians(lat2)
    dphi = np.radians(lat2 - lat1)
    dlambda = np.radians(lon2 - lon1)
    a = np.sin(dphi/2)**2 + np.cos(phi1)*np.cos(phi2)*np.sin(dlambda/2)**2
    # Rounding can nudge ``a`` marginally outside [0, 1] for (near-)antipodal
    # or identical points, which would make sqrt/arcsin return NaN.
    a = np.clip(a, 0.0, 1.0)
    return 2 * radius_km * np.arcsin(np.sqrt(a))
@st.cache_resource
def load_assets():
    """Fetch the model artefacts from the Hugging Face Hub and load them.

    Four files are downloaded from the ``rocky250/FinTech`` repo: the model
    weights, the JSON config, the label encoders and the scaler.  The model
    is built from the config, its weights are loaded onto the CPU, and it is
    switched to eval mode.

    Returns:
        A ``(model, encoders, scaler, config)`` tuple.
    """
    repo = "rocky250/FinTech"

    def fetch(name):
        # Returns the local path of the downloaded file.
        return hf_hub_download(repo_id=repo, filename=name)

    weights_path = fetch("proposed_moe_model.pth")
    config_path = fetch("config.json")
    encoders_path = fetch("label_encoders.joblib")
    scaler_path = fetch("scaler.joblib")

    with open(config_path, 'r') as handle:
        config = json.load(handle)

    model = MoEFraudModel(config['cat_dims'], config['num_cols_map'], config['embed_dim'])

    # NOTE(review): torch.load and joblib.load can execute code embedded in
    # the files they read; this is only acceptable while the repo is trusted.
    state = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(state)
    model.eval()

    encoders = joblib.load(encoders_path)
    scaler = joblib.load(scaler_path)
    return model, encoders, scaler, config
# Everything the page needs for inference, loaded before any widget is drawn.
model, encoders, scaler, config = load_assets()

st.title("FinTech Fraud Guard")
st.markdown("MoE Transaction Verifier - Optimized Threshold: 0.9898")

# Input form: three side-by-side columns of fields plus one submit button.
# The variables bound here are read by the submit handler below the form.
with st.form("transaction_form"):
    personal_col, txn_col, geo_col = st.columns(3)

    with personal_col:
        st.subheader("Personal & Card")
        first = st.text_input("First Name", value="Jeff")
        last = st.text_input("Last Name", value="Elliott")
        gender = st.selectbox("Gender", options=["M", "F"])
        dob = st.text_input("Date of Birth (DD-MM-YYYY)", value="19-03-1968")
        cc_num = st.text_input("CC Number", value="3725537864060026")
        job = st.text_input("Job", value="Mechanical engineer")

    with txn_col:
        st.subheader("Transaction Details")
        trans_dt = st.text_input(
            "Transaction Time (DD-MM-YYYY HH:MM)", value="21-06-2020 12:14"
        )
        merchant = st.text_input("Merchant", value="fraud_Kirlin and Sons")
        category = st.text_input("Category", value="personal_care")
        amt = st.number_input("Amount ($)", value=2.86, min_value=0.01)
        unix_time = st.number_input("Unix Time", value=1371816865.0)

    with geo_col:
        st.subheader("Geography")
        street = st.text_input("Street", value="351 Darlene Green")
        city = st.text_input("City", value="Columbia")
        state = st.text_input("State", value="SC")
        zip_v = st.text_input("Zip Code", value="29209")
        city_pop = st.number_input("City Population", value=333497, min_value=0)
        lat = st.number_input("Customer Lat", value=33.9659)
        lon = st.number_input("Customer Long", value=-80.9355)
        m_lat = st.number_input("Merchant Lat", value=33.986391)
        m_lon = st.number_input("Merchant Long", value=-81.200714)

    # True only on the rerun triggered by pressing the button.
    submit = st.form_submit_button("ANALYZE TRANSACTION", use_container_width=True)
def _integer_text(raw):
    """Return ``raw`` normalised to a plain base-10 integer string.

    Plain integer text is parsed with ``int`` so every digit is kept.  Going
    through ``float`` first (as this code used to) silently alters integers
    above 2**53, which long card numbers exceed, so the label-encoder lookup
    could never match them.  Float-formatted text such as ``"29209.0"`` still
    takes the float route, and invalid text raises the same ``ValueError``.
    """
    text = str(raw)
    try:
        return str(int(text))
    except ValueError:
        return str(int(float(text)))


# Score the submitted transaction and render the verdict.
if submit:
    try:
        dt_obj = datetime.strptime(trans_dt, "%d-%m-%Y %H:%M")
        dob_obj = datetime.strptime(dob, "%d-%m-%Y")
        # Whole years between birth date and transaction time (365-day years).
        age = (dt_obj - dob_obj).days // 365
        distance = haversine(lat, lon, m_lat, m_lon)

        # NOTE(review): positional feature vector; the order presumably
        # mirrors the keys of config['num_cols_map'] and the scaler's fit
        # order - confirm against the training pipeline.  The bare constants
        # are fixed placeholders for inputs the form does not collect.
        num_feats = [
            amt,
            dt_obj.hour,
            dt_obj.weekday(),
            1 if dt_obj.weekday() >= 5 else 0,
            float(unix_time),
            city_pop,
            age,
            3600.0,
            50.0,
            15.0,
            150.0,
            0.01,
            500.0,
            lat,
            lon,
            m_lat,
            m_lon,
            distance,
            0.0,
        ]

        # Raw form value for each categorical column the app knows about.
        raw_by_col = {
            'merchant': merchant,
            'category': category,
            'gender': gender,
            'cc_num': cc_num,
            'job': job,
            'city': city,
            'state': state,
            'zip': zip_v,
        }
        # Columns whose text is normalised as an integer before lookup.
        integer_like = {'cc_num', 'zip'}

        cat_encoded = []
        for col in config['cat_dims']:
            # Conversion happens per configured column, so a malformed
            # integer field only matters when the model actually uses it.
            if col not in raw_by_col:
                val = "unknown"
            elif col in integer_like:
                val = _integer_text(raw_by_col[col])
            else:
                val = str(raw_by_col[col])

            known = (
                col in encoders
                and hasattr(encoders[col], 'classes_')
                and val in encoders[col].classes_
            )
            # Values the encoder has never seen fall back to index 0.
            cat_encoded.append(encoders[col].transform([val])[0] if known else 0)

        num_input = scaler.transform([num_feats])

        with torch.no_grad():
            cat_t = torch.tensor([cat_encoded], dtype=torch.long)
            num_t = torch.tensor(num_input, dtype=torch.float32)
            logits = model(cat_t, num_t)
            # The model returns a raw logit; sigmoid turns it into a probability.
            prob = torch.sigmoid(logits).item()

        threshold = 0.9898

        st.divider()
        if prob > threshold:
            st.markdown(f'''
<div class="fraud-card">
<h2>FRAUD DETECTED</h2>
<h3>Confidence: <span style="color: #ff4b4b; font-weight: bold;">{prob*100:.1f}%</span></h3>
<p>Transaction flagged as high risk</p>
</div>
''', unsafe_allow_html=True)
        else:
            st.markdown(f'''
<div class="legit-card">
<h2>SECURE</h2>
<h3>Confidence: <span style="color: #00ffcc; font-weight: bold;">{(1-prob)*100:.1f}%</span></h3>
<p>Transaction appears legitimate</p>
</div>
''', unsafe_allow_html=True)

        col1, col2 = st.columns(2)
        with col1:
            st.metric("Fraud Probability", f"{prob*100:.2f}%")
        with col2:
            st.metric("Optimal Threshold", f"{threshold*100:.2f}%")
    except Exception as e:
        # UI boundary: show any parsing or inference failure to the user
        # instead of letting the script crash.
        st.error(f"Processing Error: {str(e)}")
# Static sidebar panel: headline metrics followed by a one-line description.
_SIDEBAR_METRICS_HTML = """
<div style='text-align: center; padding: 20px; background-color: #1e2130; border-radius: 10px;'>
<h3>Model Performance</h3>
<p><b>F1-Score:</b> 0.8350</p>
<p><b>Precision:</b> 0.8706</p>
<p><b>Recall:</b> 0.7995</p>
<p><b>Threshold:</b> 0.9898</p>
</div>
"""
sidebar = st.sidebar
sidebar.markdown(_SIDEBAR_METRICS_HTML, unsafe_allow_html=True)
sidebar.info("MoE model with 4 specialized experts for fraud detection.")