rocky250 committed on
Commit
7b90ff3
·
verified ·
1 Parent(s): 864a918

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +180 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,182 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import pandas as pd
5
+ import numpy as np
6
+ import joblib
7
+ import json
8
+ from datetime import datetime
9
+ from huggingface_hub import hf_hub_download
10
+
11
# Browser-tab title, icon and wide layout for the whole app.
st.set_page_config(
    page_title="FinTech Fraud Guard",
    page_icon="🛡️",
    layout="wide",
)

# Stylesheet injected into the page: dark background, a gradient pill-shaped
# button, and the two result cards (red accent for fraud, teal for legit).
_CUSTOM_CSS = """
<style>
.main { background-color: #0e1117; }
.stButton>button { width: 100%; border-radius: 20px; background: linear-gradient(45deg, #ff4b4b, #ff8f8f); color: white; border: none; }
.fraud-card { padding: 20px; border-radius: 15px; background-color: #1e2130; border-left: 5px solid #ff4b4b; margin-bottom: 20px; }
.legit-card { padding: 20px; border-radius: 15px; background-color: #1e2130; border-left: 5px solid #00ffcc; margin-bottom: 20px; }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
21
+
22
class MoEFraudModel(nn.Module):
    """Mixture-of-experts classifier over categorical and numeric features.

    Every categorical column gets its own embedding table. A gating network
    looks at all embeddings plus all numeric inputs and produces four softmax
    weights; four expert MLPs each read a fixed subset of the columns and
    emit a 32-wide hidden vector. The gate-weighted sum of those vectors is
    passed to a small classifier head that returns one raw score per row.
    """

    def __init__(self, cat_dims, num_cols_map, embed_dim=8):
        """Build embeddings, gate, experts and classifier head.

        Args:
            cat_dims: mapping of categorical column name -> number of
                classes (size of that column's embedding table).
            num_cols_map: mapping whose keys are the numeric column names;
                only the key order is used, to locate columns in the
                numeric input.
            embed_dim: width of every categorical embedding.
        """
        super().__init__()

        # One embedding table per categorical column, in cat_dims order.
        self.embeddings = nn.ModuleDict()
        for col_name, n_classes in cat_dims.items():
            self.embeddings[col_name] = nn.Embedding(n_classes, embed_dim)

        self.cat_cols = list(cat_dims.keys())
        self.num_cols = list(num_cols_map.keys())
        # Column name -> position inside the categorical / numeric input.
        self.cat_idx = dict(zip(self.cat_cols, range(len(self.cat_cols))))
        self.num_idx = dict(zip(self.num_cols, range(len(self.num_cols))))

        # The gate sees every embedding concatenated with every numeric column.
        gate_width = embed_dim * len(self.cat_cols) + len(self.num_cols)
        self.gating_network = nn.Sequential(
            nn.Linear(gate_width, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 4),
            nn.Softmax(dim=1),
        )

        # Column subsets read by each of the four experts.
        self.e1_cols_num = ['amt', 'hour', 'day_of_week', 'is_weekend', 'unix_time']
        self.e1_cols_cat = ['category']
        self.e2_cols_num = ['city_pop', 'age', 'time_diff_cc', 'cc_avg_amt_last_5', 'cc_std_amt_last_5', 'cc_max_amt_last_5']
        self.e2_cols_cat = ['cc_num', 'gender', 'job', 'city', 'state', 'zip']
        self.e3_cols_num = ['merchant_fraud_rate', 'merchant_txn_count']
        self.e3_cols_cat = ['merchant', 'category']
        self.e4_cols_num = ['lat', 'long', 'merch_lat', 'merch_long', 'distance_customer_merchant', 'state_mismatch_flag']
        self.e4_cols_cat = []

        # Experts are created in order 1..4 so parameter initialisation
        # order (and checkpoint key names) stay fixed.
        self.expert1 = self._make_expert(
            self._get_dim(self.e1_cols_cat, self.e1_cols_num, embed_dim)
        )
        self.expert2 = self._make_expert(
            self._get_dim(self.e2_cols_cat, self.e2_cols_num, embed_dim)
        )
        self.expert3 = self._make_expert(
            self._get_dim(self.e3_cols_cat, self.e3_cols_num, embed_dim)
        )
        self.expert4 = self._make_expert(
            self._get_dim(self.e4_cols_cat, self.e4_cols_num, embed_dim)
        )

        self.classifier = nn.Sequential(
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(16, 1),
        )

    def _get_dim(self, cats, nums, embed_dim):
        """Return the input width of an expert: numeric count + embedded width."""
        embedded_width = embed_dim * len(cats)
        return embedded_width + len(nums)

    def _make_expert(self, input_dim):
        """Return an expert MLP mapping ``input_dim`` features to 32 outputs."""
        layers = [
            nn.Linear(input_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 32),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def get_features(self, cat_input, num_input, req_cat, req_num):
        """Assemble the feature matrix for one expert.

        Selects the requested numeric columns from ``num_input`` first, then
        appends the embedding of each requested categorical column (looked up
        from the matching column of ``cat_input``), concatenated along dim 1.
        """
        chunks = []
        if req_num:
            wanted = [self.num_idx[name] for name in req_num]
            chunks.append(num_input[:, wanted])
        if req_cat:
            chunks.extend(
                self.embeddings[name](cat_input[:, self.cat_idx[name]])
                for name in req_cat
            )
        return torch.cat(chunks, dim=1)

    def forward(self, cat_input, num_input):
        """Return the classifier output for a batch.

        ``cat_input`` is indexed as ``[:, i]`` in ``self.cat_cols`` order and
        ``num_input`` is indexed by ``self.num_idx`` positions. The result is
        the output of the final ``Linear(16, 1)`` layer (no sigmoid applied).
        """
        embedded = [
            self.embeddings[name](cat_input[:, pos])
            for pos, name in enumerate(self.cat_cols)
        ]
        embedded_block = torch.cat(embedded, dim=1)
        gate_input = torch.cat([embedded_block, num_input], dim=1)
        weights = self.gating_network(gate_input)

        # Experts are evaluated in order 1..4, each on its own column subset.
        expert_specs = (
            (self.expert1, self.e1_cols_cat, self.e1_cols_num),
            (self.expert2, self.e2_cols_cat, self.e2_cols_num),
            (self.expert3, self.e3_cols_cat, self.e3_cols_num),
            (self.expert4, self.e4_cols_cat, self.e4_cols_num),
        )
        hidden = [
            expert(self.get_features(cat_input, num_input, cats, nums))
            for expert, cats, nums in expert_specs
        ]

        # Gate-weighted sum, accumulated left to right.
        mixed = weights[:, 0:1] * hidden[0]
        for k in range(1, len(hidden)):
            mixed = mixed + weights[:, k:k + 1] * hidden[k]
        return self.classifier(mixed)
82
+
83
REPO_ID = "rocky250/FinTech"


@st.cache_resource
def load_assets():
    """Download and load the model, label encoders, scaler and config.

    Four files are fetched from the ``REPO_ID`` Hugging Face repository:
    the model weights, ``config.json``, the label encoders and the scaler.
    The model is rebuilt from the config, its weights are loaded on CPU and
    it is switched to eval mode. ``st.cache_resource`` wraps the function.

    Returns:
        A tuple ``(model, encoders, scaler, config)``.
    """
    weights_path = hf_hub_download(repo_id=REPO_ID, filename="proposed_moe_model.pth")
    config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
    encoders_path = hf_hub_download(repo_id=REPO_ID, filename="label_encoders.joblib")
    scaler_path = hf_hub_download(repo_id=REPO_ID, filename="scaler.joblib")

    # Explicit encoding so parsing does not depend on the host locale.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    model = MoEFraudModel(config['cat_dims'], config['num_cols_map'], config['embed_dim'])
    # The checkpoint is fed straight to load_state_dict, so it only needs to
    # hold tensors; weights_only=True stops torch.load from unpickling
    # arbitrary objects out of a downloaded file.
    state_dict = torch.load(
        weights_path,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()

    # NOTE(security): joblib.load unpickles its input and can execute
    # arbitrary code; these files are only as trustworthy as REPO_ID.
    encoders = joblib.load(encoders_path)
    scaler = joblib.load(scaler_path)
    return model, encoders, scaler, config
100
+
101
# Model, encoders, scaler and config used by the prediction step.
model, encoders, scaler, config = load_assets()

st.title("FinTech Fraud Analysis")
st.markdown("### Proposed Mixture-of-Experts (MoE) Architecture")

# Three-column input form; values are only submitted when the button is pressed.
with st.form("transaction_form"):
    customer_col, txn_col, location_col = st.columns(3)

    # Column 1: who is paying.
    customer_col.subheader("Customer Info")
    first = customer_col.text_input("First Name", "Jeff")
    last = customer_col.text_input("Last Name", "Elliott")
    gender = customer_col.selectbox("Gender", ["M", "F"])
    cc_num = customer_col.text_input("Credit Card Number", "2.29116E+15")

    # Column 2: what was bought, when, and for how much.
    txn_col.subheader("Transaction Info")
    trans_dt = txn_col.text_input("Date & Time (DD-MM-YYYY HH:MM)", "21-06-2020 12:14")
    merchant = txn_col.text_input("Merchant", "fraud_Kirlin and Sons")
    category = txn_col.text_input("Category", "personal_care")
    amt = txn_col.number_input("Amount ($)", value=2.86)

    # Column 3: customer address.
    location_col.subheader("Location Info")
    street = location_col.text_input("Street", "351 Darlene Green")
    city = location_col.text_input("City", "Birmingham")
    state = location_col.text_input("State", "AL")
    zip_code = location_col.text_input("Zip Code", "35201")

    submit = st.form_submit_button("ANALYZE TRANSACTION")
131
+
132
# Runs one inference when the form is submitted and renders the verdict.
if submit:
    try:
        # strptime raises ValueError on a malformed timestamp; the handler
        # at the bottom reports it to the user.
        dt_obj = datetime.strptime(trans_dt, "%d-%m-%Y %H:%M")

        # Build numerical feature vector (must match the 19 features in training).
        # Order: ['amt', 'hour', 'day_of_week', 'is_weekend', 'unix_time', 'city_pop', 'age',
        #         'time_diff_cc', 'cc_avg_amt_last_5', 'cc_std_amt_last_5', 'cc_max_amt_last_5',
        #         'merchant_fraud_rate', 'merchant_txn_count', 'lat', 'long', 'merch_lat',
        #         'merch_long', 'distance_customer_merchant', 'state_mismatch_flag']
        # Only the first five values come from the form; the remaining
        # fourteen are fixed placeholder constants.
        is_weekend = 1 if dt_obj.weekday() >= 5 else 0
        num_feats = [
            amt, dt_obj.hour, dt_obj.weekday(), is_weekend,
            # NOTE(review): dt_obj is naive, so timestamp() interprets it in
            # the server's local time zone - confirm this matches training.
            dt_obj.timestamp(),
            50000, 35, 3600, 50.0, 10.0, 100.0, 0.02, 1000,
            33.5, -86.8, 33.6, -86.9, 15.5, 0,
        ]

        # Categorical values keyed by column name, so the lookup below does
        # not depend on the key order of config['cat_dims'].
        # "Job" is a fixed placeholder; the form does not collect a job.
        cat_values = {
            'category': category,
            'cc_num': cc_num,
            'gender': gender,
            'job': "Job",
            'city': city,
            'state': state,
            'zip': zip_code,
            'merchant': merchant,
        }

        num_input = scaler.transform([num_feats])

        cat_encoded = []
        unseen_cols = []
        for col in config['cat_dims']:
            if col not in cat_values:
                raise KeyError(f"no form value for categorical column {col!r}")
            val = str(cat_values[col])
            encoder = encoders[col]
            if val in encoder.classes_:
                cat_encoded.append(int(encoder.transform([val])[0]))
            else:
                # Value not known to the encoder: fall back to index 0 and
                # remember the column so the user can be told.
                cat_encoded.append(0)
                unseen_cols.append(col)

        if unseen_cols:
            st.warning(
                "Unrecognized value for: " + ", ".join(unseen_cols)
                + ". A fallback encoding (index 0) was used, so the result may be unreliable."
            )

        with torch.no_grad():
            cat_t = torch.tensor([cat_encoded], dtype=torch.long)
            num_t = torch.tensor(num_input, dtype=torch.float32)
            logits = model(cat_t, num_t)
            # The model returns a raw score; sigmoid maps it to (0, 1).
            prob = torch.sigmoid(logits).item()

        st.divider()
        if prob > 0.5:
            st.markdown(f"""<div class="fraud-card">
<h2>High Risk Detected!</h2>
<p>Confidence: {prob*100:.2f}%</p>
<p>This transaction matches fraud patterns identified by the Expert Gating Network.</p>
</div>""", unsafe_allow_html=True)
        else:
            st.markdown(f"""<div class="legit-card">
<h2>Transaction Safe</h2>
<p>Confidence: {(1-prob)*100:.2f}%</p>
<p>Analysis shows this is a legitimate transaction based on customer behavior Experts.</p>
</div>""", unsafe_allow_html=True)

    except Exception as e:
        # Top-level UI boundary: show the failure in the page.
        st.error(f"Error in processing: {e}")

st.sidebar.info("This model uses a **MoE Architecture** with 4 specialized experts for Financial Fraud Detection.")