rocky250 commited on
Commit
a511c47
·
verified ·
1 Parent(s): 69484be

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +54 -17
src/streamlit_app.py CHANGED
@@ -8,7 +8,7 @@ import json
8
  from datetime import datetime
9
  from huggingface_hub import hf_hub_download
10
 
11
- st.set_page_config(page_title="FinTech Fraud Guard", page_icon="🛡️", layout="wide")
12
 
13
  st.markdown("""
14
  <style>
@@ -32,17 +32,25 @@ class MoEFraudModel(nn.Module):
32
  self.num_idx = {name: i for i, name in enumerate(self.num_cols)}
33
 
34
  total_input_dim = (len(self.cat_cols) * embed_dim) + len(self.num_cols)
 
35
  self.gating_network = nn.Sequential(
36
- nn.Linear(total_input_dim, 64), nn.BatchNorm1d(64), nn.ReLU(),
37
- nn.Dropout(0.2), nn.Linear(64, 4), nn.Softmax(dim=1)
 
 
 
 
38
  )
39
 
40
  self.e1_cols_num = ['amt', 'hour', 'day_of_week', 'is_weekend', 'unix_time']
41
  self.e1_cols_cat = ['category']
 
42
  self.e2_cols_num = ['city_pop', 'age', 'time_diff_cc', 'cc_avg_amt_last_5', 'cc_std_amt_last_5', 'cc_max_amt_last_5']
43
  self.e2_cols_cat = ['cc_num', 'gender', 'job', 'city', 'state', 'zip']
 
44
  self.e3_cols_num = ['merchant_fraud_rate', 'merchant_txn_count']
45
  self.e3_cols_cat = ['merchant', 'category']
 
46
  self.e4_cols_num = ['lat', 'long', 'merch_lat', 'merch_long', 'distance_customer_merchant', 'state_mismatch_flag']
47
  self.e4_cols_cat = []
48
 
@@ -51,11 +59,25 @@ class MoEFraudModel(nn.Module):
51
  self.expert3 = self._make_expert(self._get_dim(self.e3_cols_cat, self.e3_cols_num, embed_dim))
52
  self.expert4 = self._make_expert(self._get_dim(self.e4_cols_cat, self.e4_cols_num, embed_dim))
53
 
54
- self.classifier = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Dropout(0.2), nn.Linear(16, 1))
 
 
 
 
 
 
 
 
55
 
56
- def _get_dim(self, cats, nums, embed_dim): return len(nums) + (len(cats) * embed_dim)
57
  def _make_expert(self, input_dim):
58
- return nn.Sequential(nn.Linear(input_dim, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.2), nn.Linear(128, 32), nn.ReLU())
 
 
 
 
 
 
 
59
 
60
  def get_features(self, cat_input, num_input, req_cat, req_num):
61
  parts = []
@@ -72,12 +94,16 @@ class MoEFraudModel(nn.Module):
72
  def forward(self, cat_input, num_input):
73
  all_embs = [self.embeddings[c](cat_input[:, i]) for i, c in enumerate(self.cat_cols)]
74
  global_features = torch.cat([torch.cat(all_embs, dim=1), num_input], dim=1)
 
75
  weights = self.gating_network(global_features)
 
76
  h1 = self.expert1(self.get_features(cat_input, num_input, self.e1_cols_cat, self.e1_cols_num))
77
  h2 = self.expert2(self.get_features(cat_input, num_input, self.e2_cols_cat, self.e2_cols_num))
78
  h3 = self.expert3(self.get_features(cat_input, num_input, self.e3_cols_cat, self.e3_cols_num))
79
  h4 = self.expert4(self.get_features(cat_input, num_input, self.e4_cols_cat, self.e4_cols_num))
 
80
  h_final = (weights[:, 0:1]*h1 + weights[:, 1:2]*h2 + weights[:, 2:3]*h3 + weights[:, 3:4]*h4)
 
81
  return self.classifier(h_final)
82
 
83
  REPO_ID = "rocky250/FinTech"
@@ -133,19 +159,30 @@ if submit:
133
  try:
134
  dt_obj = datetime.strptime(trans_dt, "%d-%m-%Y %H:%M")
135
 
136
- Build numerical feature vector (Must match the 19 features in training)
137
- Order: ['amt', 'hour', 'day_of_week', 'is_weekend', 'unix_time', 'city_pop', 'age',
138
- 'time_diff_cc', 'cc_avg_amt_last_5', 'cc_std_amt_last_5', 'cc_max_amt_last_5',
139
- 'merchant_fraud_rate', 'merchant_txn_count', 'lat', 'long', 'merch_lat',
140
- 'merch_long', 'distance_customer_merchant', 'state_mismatch_flag']
141
-
142
  num_feats = [
143
- amt, dt_obj.hour, dt_obj.weekday(), 1 if dt_obj.weekday() >= 5 else 0, dt_obj.timestamp(),
144
- 50000, 35, 3600, 50.0, 10.0, 100.0, 0.02, 1000, 33.5, -86.8, 33.6, -86.9, 15.5, 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  ]
146
 
147
  cat_feats = [category, cc_num, gender, "Job", city, state, zip_code, merchant]
148
-
149
  num_input = scaler.transform([num_feats])
150
 
151
  cat_encoded = []
@@ -154,7 +191,7 @@ if submit:
154
  if val in encoders[col].classes_:
155
  cat_encoded.append(encoders[col].transform([val])[0])
156
  else:
157
- cat_encoded.append(0)
158
 
159
  with torch.no_grad():
160
  cat_t = torch.tensor([cat_encoded], dtype=torch.long)
@@ -179,4 +216,4 @@ if submit:
179
  except Exception as e:
180
  st.error(f"Error in processing: {e}")
181
 
182
- st.sidebar.info("This model uses a **MoE Architecture** with 4 specialized experts for Financial Fraud Detection.")
 
8
  from datetime import datetime
9
  from huggingface_hub import hf_hub_download
10
 
11
+ st.set_page_config(page_title="FinTech Fraud Guard", layout="wide")
12
 
13
  st.markdown("""
14
  <style>
 
32
  self.num_idx = {name: i for i, name in enumerate(self.num_cols)}
33
 
34
  total_input_dim = (len(self.cat_cols) * embed_dim) + len(self.num_cols)
35
+
36
  self.gating_network = nn.Sequential(
37
+ nn.Linear(total_input_dim, 64),
38
+ nn.BatchNorm1d(64),
39
+ nn.ReLU(),
40
+ nn.Dropout(0.2),
41
+ nn.Linear(64, 4),
42
+ nn.Softmax(dim=1)
43
  )
44
 
45
  self.e1_cols_num = ['amt', 'hour', 'day_of_week', 'is_weekend', 'unix_time']
46
  self.e1_cols_cat = ['category']
47
+
48
  self.e2_cols_num = ['city_pop', 'age', 'time_diff_cc', 'cc_avg_amt_last_5', 'cc_std_amt_last_5', 'cc_max_amt_last_5']
49
  self.e2_cols_cat = ['cc_num', 'gender', 'job', 'city', 'state', 'zip']
50
+
51
  self.e3_cols_num = ['merchant_fraud_rate', 'merchant_txn_count']
52
  self.e3_cols_cat = ['merchant', 'category']
53
+
54
  self.e4_cols_num = ['lat', 'long', 'merch_lat', 'merch_long', 'distance_customer_merchant', 'state_mismatch_flag']
55
  self.e4_cols_cat = []
56
 
 
59
  self.expert3 = self._make_expert(self._get_dim(self.e3_cols_cat, self.e3_cols_num, embed_dim))
60
  self.expert4 = self._make_expert(self._get_dim(self.e4_cols_cat, self.e4_cols_num, embed_dim))
61
 
62
+ self.classifier = nn.Sequential(
63
+ nn.Linear(32, 16),
64
+ nn.ReLU(),
65
+ nn.Dropout(0.2),
66
+ nn.Linear(16, 1)
67
+ )
68
+
69
+ def _get_dim(self, cats, nums, embed_dim):
70
+ return len(nums) + (len(cats) * embed_dim)
71
 
 
72
  def _make_expert(self, input_dim):
73
+ return nn.Sequential(
74
+ nn.Linear(input_dim, 128),
75
+ nn.BatchNorm1d(128),
76
+ nn.ReLU(),
77
+ nn.Dropout(0.2),
78
+ nn.Linear(128, 32),
79
+ nn.ReLU()
80
+ )
81
 
82
  def get_features(self, cat_input, num_input, req_cat, req_num):
83
  parts = []
 
94
  def forward(self, cat_input, num_input):
95
  all_embs = [self.embeddings[c](cat_input[:, i]) for i, c in enumerate(self.cat_cols)]
96
  global_features = torch.cat([torch.cat(all_embs, dim=1), num_input], dim=1)
97
+
98
  weights = self.gating_network(global_features)
99
+
100
  h1 = self.expert1(self.get_features(cat_input, num_input, self.e1_cols_cat, self.e1_cols_num))
101
  h2 = self.expert2(self.get_features(cat_input, num_input, self.e2_cols_cat, self.e2_cols_num))
102
  h3 = self.expert3(self.get_features(cat_input, num_input, self.e3_cols_cat, self.e3_cols_num))
103
  h4 = self.expert4(self.get_features(cat_input, num_input, self.e4_cols_cat, self.e4_cols_num))
104
+
105
  h_final = (weights[:, 0:1]*h1 + weights[:, 1:2]*h2 + weights[:, 2:3]*h3 + weights[:, 3:4]*h4)
106
+
107
  return self.classifier(h_final)
108
 
109
  REPO_ID = "rocky250/FinTech"
 
159
  try:
160
  dt_obj = datetime.strptime(trans_dt, "%d-%m-%Y %H:%M")
161
 
 
 
 
 
 
 
162
  num_feats = [
163
+ amt,
164
+ dt_obj.hour,
165
+ dt_obj.weekday(),
166
+ 1 if dt_obj.weekday() >= 5 else 0,
167
+ dt_obj.timestamp(),
168
+ 50000,
169
+ 35,
170
+ 3600,
171
+ 50.0,
172
+ 10.0,
173
+ 100.0,
174
+ 0.02,
175
+ 1000,
176
+ 33.5,
177
+ -86.8,
178
+ 33.6,
179
+ -86.9,
180
+ 15.5,
181
+ 0
182
  ]
183
 
184
  cat_feats = [category, cc_num, gender, "Job", city, state, zip_code, merchant]
185
+
186
  num_input = scaler.transform([num_feats])
187
 
188
  cat_encoded = []
 
191
  if val in encoders[col].classes_:
192
  cat_encoded.append(encoders[col].transform([val])[0])
193
  else:
194
+ cat_encoded.append(0)
195
 
196
  with torch.no_grad():
197
  cat_t = torch.tensor([cat_encoded], dtype=torch.long)
 
216
  except Exception as e:
217
  st.error(f"Error in processing: {e}")
218
 
219
+ st.sidebar.info("This model uses a MoE Architecture with 4 specialized experts for Financial Fraud Detection.")