| import streamlit as st |
| import torch |
| import torch.nn as nn |
| import pandas as pd |
| import numpy as np |
| import joblib |
| import json |
| from datetime import datetime |
| from huggingface_hub import hf_hub_download |
|
|
# Page chrome: wide layout, plus custom CSS for the submit button and the
# two verdict cards ("fraud-card" / "legit-card") rendered after analysis.
st.set_page_config(layout="wide", page_title="FinTech Fraud Guard")

_PAGE_CSS = """
<style>
.main { background-color: #0e1117; }
.stButton>button { width: 100%; border-radius: 20px; background: linear-gradient(45deg, #ff4b4b, #ff8f8f); color: white; border: none; }
.fraud-card { padding: 20px; border-radius: 15px; background-color: #1e2130; border-left: 5px solid #ff4b4b; margin-bottom: 20px; }
.legit-card { padding: 20px; border-radius: 15px; background-color: #1e2130; border-left: 5px solid #00ffcc; margin-bottom: 20px; }
</style>
"""

# unsafe_allow_html is required for the raw <style> tag to be injected.
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
|
|
class MoEFraudModel(nn.Module):
    """Mixture-of-experts network that scores one transaction with a single logit.

    Four expert MLPs each read a fixed subset of the input columns and emit a
    32-wide hidden vector. A gating network reads *all* columns and produces
    four softmax weights; the weighted sum of the expert outputs is passed to
    a small classifier head that returns one raw logit per row.

    Args:
        cat_dims: Mapping ``column name -> number of classes``. Key order
            fixes the column order expected in ``cat_input``.
        num_cols_map: Mapping whose keys are the numeric column names. Key
            order fixes the column order expected in ``num_input``.
        embed_dim: Width of every categorical embedding.
    """

    def __init__(self, cat_dims, num_cols_map, embed_dim=8):
        super().__init__()

        # One embedding table per categorical column, keyed by column name.
        self.embeddings = nn.ModuleDict(
            {name: nn.Embedding(cardinality, embed_dim) for name, cardinality in cat_dims.items()}
        )

        # Positional lookup tables: column name -> column index in the inputs.
        self.cat_cols = list(cat_dims.keys())
        self.num_cols = list(num_cols_map.keys())
        self.cat_idx = {name: pos for pos, name in enumerate(self.cat_cols)}
        self.num_idx = {name: pos for pos, name in enumerate(self.num_cols)}

        # The gate sees every embedding concatenated with every numeric column.
        gate_width = len(self.num_cols) + embed_dim * len(self.cat_cols)
        self.gating_network = nn.Sequential(
            nn.Linear(gate_width, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 4),
            nn.Softmax(dim=1),
        )

        # Column subsets read by each expert.
        self.e1_cols_num = ['amt', 'hour', 'day_of_week', 'is_weekend', 'unix_time']
        self.e1_cols_cat = ['category']

        self.e2_cols_num = ['city_pop', 'age', 'time_diff_cc', 'cc_avg_amt_last_5', 'cc_std_amt_last_5', 'cc_max_amt_last_5']
        self.e2_cols_cat = ['cc_num', 'gender', 'job', 'city', 'state', 'zip']

        self.e3_cols_num = ['merchant_fraud_rate', 'merchant_txn_count']
        self.e3_cols_cat = ['merchant', 'category']

        self.e4_cols_num = ['lat', 'long', 'merch_lat', 'merch_long', 'distance_customer_merchant', 'state_mismatch_flag']
        self.e4_cols_cat = []

        # Experts are created in order 1..4 and registered as expert1..expert4
        # so parameter names (and initialisation order) stay stable.
        for k in range(1, 5):
            cats = getattr(self, f"e{k}_cols_cat")
            nums = getattr(self, f"e{k}_cols_num")
            width = self._get_dim(cats, nums, embed_dim)
            setattr(self, f"expert{k}", self._make_expert(width))

        # Head mapping the blended 32-wide expert output to one logit.
        self.classifier = nn.Sequential(
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(16, 1),
        )

    def _get_dim(self, cats, nums, embed_dim):
        """Return the input width for an expert reading ``cats`` and ``nums``."""
        embedded_width = embed_dim * len(cats)
        return embedded_width + len(nums)

    def _make_expert(self, input_dim):
        """Build one expert MLP: ``input_dim -> 128 -> 32`` with ReLU outputs."""
        layers = [
            nn.Linear(input_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 32),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def get_features(self, cat_input, num_input, req_cat, req_num):
        """Assemble the feature matrix for the requested columns.

        The selected numeric columns come first, followed by the embedding of
        each requested categorical column in the order given.
        """
        pieces = []
        if req_num:
            positions = [self.num_idx[name] for name in req_num]
            pieces.append(num_input[:, positions])
        for name in (req_cat or []):
            column = cat_input[:, self.cat_idx[name]]
            pieces.append(self.embeddings[name](column))
        return torch.cat(pieces, dim=1)

    def forward(self, cat_input, num_input):
        """Return one raw logit per row (apply a sigmoid for a probability).

        ``cat_input`` is indexed as ``[:, i]`` with ``i`` following
        ``self.cat_cols``; ``num_input`` columns follow ``self.num_cols``.
        """
        # Gate input: all embeddings side by side, then all numeric columns.
        embedded = torch.cat(
            [self.embeddings[name](cat_input[:, self.cat_idx[name]]) for name in self.cat_cols],
            dim=1,
        )
        gate_input = torch.cat((embedded, num_input), dim=1)
        weights = self.gating_network(gate_input)

        expert_specs = (
            (self.expert1, self.e1_cols_cat, self.e1_cols_num),
            (self.expert2, self.e2_cols_cat, self.e2_cols_num),
            (self.expert3, self.e3_cols_cat, self.e3_cols_num),
            (self.expert4, self.e4_cols_cat, self.e4_cols_num),
        )
        hidden = [
            expert(self.get_features(cat_input, num_input, cats, nums))
            for expert, cats, nums in expert_specs
        ]

        # Weighted sum of expert outputs, accumulated left to right.
        blended = weights[:, 0:1] * hidden[0]
        for k in range(1, len(hidden)):
            blended = blended + weights[:, k:k + 1] * hidden[k]

        return self.classifier(blended)
|
|
def haversine(lat1, lon1, lat2, lon2):
    """Return the great-circle distance between two points on a sphere.

    Uses the haversine formula with a sphere radius of 6371 (the mean Earth
    radius in kilometres), so the result is in kilometres.

    Args:
        lat1, lon1: Latitude and longitude of the first point, in degrees.
        lat2, lon2: Latitude and longitude of the second point, in degrees.
            Scalars or NumPy arrays are accepted; arrays are combined
            element-wise.

    Returns:
        The distance as a NumPy float (or array for array inputs).
    """
    r = 6371  # sphere radius; fixes the unit of the result
    phi1, phi2 = np.radians(lat1), np.radians(lat2)
    dphi = np.radians(lat2 - lat1)
    dlambda = np.radians(lon2 - lon1)
    a = np.sin(dphi / 2) ** 2 + np.cos(phi1) * np.cos(phi2) * np.sin(dlambda / 2) ** 2
    # Rounding can leave `a` a hair outside [0, 1] for (near-)antipodal points,
    # which would make arcsin return NaN; clamp it back into the valid domain.
    a = np.clip(a, 0.0, 1.0)
    return 2 * r * np.arcsin(np.sqrt(a))
|
|
@st.cache_resource
def load_assets():
    """Download the model artifacts from the Hugging Face Hub and build the model.

    Fetches the weights, ``config.json``, the label encoders and the scaler
    from the ``rocky250/FinTech`` repository, instantiates ``MoEFraudModel``
    from the config, loads the weights on CPU and switches to eval mode.
    Decorated with ``st.cache_resource`` so the result is reused across reruns.

    Returns:
        tuple: ``(model, encoders, scaler, config)`` where ``encoders`` and
        ``scaler`` are whatever objects the joblib files contain and
        ``config`` is the parsed JSON dict.
    """
    REPO_ID = "rocky250/FinTech"
    w_p = hf_hub_download(repo_id=REPO_ID, filename="proposed_moe_model.pth")
    c_p = hf_hub_download(repo_id=REPO_ID, filename="config.json")
    e_p = hf_hub_download(repo_id=REPO_ID, filename="label_encoders.joblib")
    s_p = hf_hub_download(repo_id=REPO_ID, filename="scaler.joblib")

    with open(c_p, 'r', encoding='utf-8') as f:
        config = json.load(f)

    model = MoEFraudModel(config['cat_dims'], config['num_cols_map'], config['embed_dim'])
    # Only a state dict is needed, so restrict unpickling to tensors and
    # primitive containers instead of executing arbitrary pickled code.
    state_dict = torch.load(w_p, map_location=torch.device('cpu'), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # SECURITY NOTE(review): joblib.load unpickles arbitrary Python objects.
    # These files come from a remote repository; only load them if that
    # repository is trusted (consider pinning a revision in hf_hub_download).
    encoders = joblib.load(e_p)
    scaler = joblib.load(s_p)

    return model, encoders, scaler, config
|
|
# Model and preprocessing artifacts (cached by load_assets).
model, encoders, scaler, config = load_assets()

st.title("FinTech Fraud Guard")
st.markdown("MoE Transaction Verifier - Optimized Threshold: 0.9898")

# Input form: three columns of fields; nothing is processed until the
# submit button is pressed, at which point `submit` becomes True.
with st.form("transaction_form"):
    personal_col, txn_col, geo_col = st.columns(3)

    # --- Card holder and card ---
    personal_col.subheader("Personal & Card")
    first = personal_col.text_input("First Name", "Jeff")
    last = personal_col.text_input("Last Name", "Elliott")
    gender = personal_col.selectbox("Gender", ["M", "F"])
    dob = personal_col.text_input("Date of Birth (DD-MM-YYYY)", "19-03-1968")
    cc_num = personal_col.text_input("CC Number", "3725537864060026")
    job = personal_col.text_input("Job", "Mechanical engineer")

    # --- The transaction itself ---
    txn_col.subheader("Transaction Details")
    trans_dt = txn_col.text_input("Transaction Time (DD-MM-YYYY HH:MM)", "21-06-2020 12:14")
    merchant = txn_col.text_input("Merchant", "fraud_Kirlin and Sons")
    category = txn_col.text_input("Category", "personal_care")
    amt = txn_col.number_input("Amount ($)", value=2.86, min_value=0.01)
    unix_time = txn_col.number_input("Unix Time", value=1371816865.0)

    # --- Customer address and both coordinate pairs ---
    geo_col.subheader("Geography")
    street = geo_col.text_input("Street", "351 Darlene Green")
    city = geo_col.text_input("City", "Columbia")
    state = geo_col.text_input("State", "SC")
    zip_v = geo_col.text_input("Zip Code", "29209")
    city_pop = geo_col.number_input("City Population", value=333497, min_value=0)
    lat = geo_col.number_input("Customer Lat", value=33.9659)
    lon = geo_col.number_input("Customer Long", value=-80.9355)
    m_lat = geo_col.number_input("Merchant Lat", value=33.986391)
    m_lon = geo_col.number_input("Merchant Long", value=-81.200714)

    submit = st.form_submit_button("ANALYZE TRANSACTION", use_container_width=True)
|
|
# Decision threshold applied to the sigmoid output of the model.
FRAUD_THRESHOLD = 0.9898


def _integer_text(raw):
    """Normalize a numeric text field (card number, zip) to plain integer digits.

    Parsing straight to ``int`` keeps every digit. Going through ``float``
    first silently corrupts integers above 2**53 (for example 19-digit card
    numbers), so the float path is only a fallback for inputs like "29209.0".
    """
    text = str(raw).strip()
    try:
        return str(int(text))
    except ValueError:
        return str(int(float(text)))


def _verdict_card(css_class, heading, color, confidence_pct, note):
    """Return the HTML for the verdict banner shown above the metrics."""
    return f'''
<div class="{css_class}">
<h2>{heading}</h2>
<h3>Confidence: <span style="color: {color}; font-weight: bold;">{confidence_pct:.1f}%</span></h3>
<p>{note}</p>
</div>
'''


if submit:
    try:
        # --- Derived features -------------------------------------------------
        dt_obj = datetime.strptime(trans_dt, "%d-%m-%Y %H:%M")
        dob_obj = datetime.strptime(dob, "%d-%m-%Y")
        age = (dt_obj - dob_obj).days // 365
        distance = haversine(lat, lon, m_lat, m_lon)

        # Numeric features keyed by column name. The constants are placeholders
        # for values the form does not collect.
        # NOTE(review): confirm these placeholder values are sensible defaults.
        numeric_values = {
            'amt': amt,
            'hour': dt_obj.hour,
            'day_of_week': dt_obj.weekday(),
            'is_weekend': 1 if dt_obj.weekday() >= 5 else 0,
            'unix_time': float(unix_time),
            'city_pop': city_pop,
            'age': age,
            'time_diff_cc': 3600.0,
            'cc_avg_amt_last_5': 50.0,
            'cc_std_amt_last_5': 15.0,
            'cc_max_amt_last_5': 150.0,
            'merchant_fraud_rate': 0.01,
            'merchant_txn_count': 500.0,
            'lat': lat,
            'long': lon,
            'merch_lat': m_lat,
            'merch_long': m_lon,
            'distance_customer_merchant': distance,
            'state_mismatch_flag': 0.0,
        }

        # The model looks numeric columns up by their position in
        # config['num_cols_map'], so build the vector in exactly that order
        # rather than relying on a hand-maintained positional list.
        num_order = list(config['num_cols_map'])
        missing = [name for name in num_order if name not in numeric_values]
        if missing:
            raise ValueError(
                f"No value available for numeric feature(s): {', '.join(missing)}"
            )
        num_feats = [numeric_values[name] for name in num_order]

        # --- Categorical encoding --------------------------------------------
        raw_categoricals = {
            'merchant': merchant,
            'category': category,
            'gender': gender,
            'cc_num': cc_num,
            'job': job,
            'city': city,
            'state': state,
            'zip': zip_v,
        }
        integer_like = {'cc_num', 'zip'}

        cat_encoded = []
        unseen = []
        for col in config['cat_dims']:
            if col not in raw_categoricals:
                val = "unknown"
            elif col in integer_like:
                val = _integer_text(raw_categoricals[col])
            else:
                val = str(raw_categoricals[col])

            encoder = encoders[col] if col in encoders else None
            classes = getattr(encoder, 'classes_', None)
            if classes is not None and val in classes:
                cat_encoded.append(encoder.transform([val])[0])
            else:
                # Fall back to index 0 as before, but remember it so the user
                # is told the prediction used a substitute value.
                cat_encoded.append(0)
                unseen.append(f"{col}={val!r}")

        if unseen:
            st.warning(
                "Value(s) not seen by the encoders were encoded as index 0: "
                + ", ".join(unseen)
            )

        # --- Inference ---------------------------------------------------------
        num_input = scaler.transform([num_feats])

        with torch.no_grad():
            cat_t = torch.tensor([cat_encoded], dtype=torch.long)
            num_t = torch.tensor(num_input, dtype=torch.float32)
            logits = model(cat_t, num_t)
            prob = torch.sigmoid(logits).item()

        # --- Presentation ------------------------------------------------------
        st.divider()

        if prob > FRAUD_THRESHOLD:
            card = _verdict_card(
                "fraud-card", "FRAUD DETECTED", "#ff4b4b",
                prob * 100, "Transaction flagged as high risk",
            )
        else:
            card = _verdict_card(
                "legit-card", "SECURE", "#00ffcc",
                (1 - prob) * 100, "Transaction appears legitimate",
            )
        st.markdown(card, unsafe_allow_html=True)

        col1, col2 = st.columns(2)
        with col1:
            st.metric("Fraud Probability", f"{prob*100:.2f}%")
        with col2:
            st.metric("Optimal Threshold", f"{FRAUD_THRESHOLD * 100:.2f}%")

    except Exception as e:
        # UI boundary: surface any parsing/encoding/inference failure to the
        # user instead of crashing the script run.
        st.error(f"Processing Error: {str(e)}")
|
|
# Static sidebar panel: reported model metrics and a one-line description.
_SIDEBAR_METRICS_HTML = """
<div style='text-align: center; padding: 20px; background-color: #1e2130; border-radius: 10px;'>
<h3>Model Performance</h3>
<p><b>F1-Score:</b> 0.8350</p>
<p><b>Precision:</b> 0.8706</p>
<p><b>Recall:</b> 0.7995</p>
<p><b>Threshold:</b> 0.9898</p>
</div>
"""

with st.sidebar:
    st.markdown(_SIDEBAR_METRICS_HTML, unsafe_allow_html=True)
    st.info("MoE model with 4 specialized experts for fraud detection.")
|
|