rocky250 commited on
Commit
3dba51a
·
verified ·
1 Parent(s): a511c47

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +59 -71
src/streamlit_app.py CHANGED
@@ -13,9 +13,10 @@ st.set_page_config(page_title="FinTech Fraud Guard", layout="wide")
13
  st.markdown("""
14
  <style>
15
  .main { background-color: #0e1117; }
16
- .stButton>button { width: 100%; border-radius: 20px; background: linear-gradient(45deg, #ff4b4b, #ff8f8f); color: white; border: none; }
17
- .fraud-card { padding: 20px; border-radius: 15px; background-color: #1e2130; border-left: 5px solid #ff4b4b; margin-bottom: 20px; }
18
- .legit-card { padding: 20px; border-radius: 15px; background-color: #1e2130; border-left: 5px solid #00ffcc; margin-bottom: 20px; }
 
19
  </style>
20
  """, unsafe_allow_html=True)
21
 
@@ -94,16 +95,12 @@ class MoEFraudModel(nn.Module):
94
  def forward(self, cat_input, num_input):
95
  all_embs = [self.embeddings[c](cat_input[:, i]) for i, c in enumerate(self.cat_cols)]
96
  global_features = torch.cat([torch.cat(all_embs, dim=1), num_input], dim=1)
97
-
98
  weights = self.gating_network(global_features)
99
-
100
  h1 = self.expert1(self.get_features(cat_input, num_input, self.e1_cols_cat, self.e1_cols_num))
101
  h2 = self.expert2(self.get_features(cat_input, num_input, self.e2_cols_cat, self.e2_cols_num))
102
  h3 = self.expert3(self.get_features(cat_input, num_input, self.e3_cols_cat, self.e3_cols_num))
103
  h4 = self.expert4(self.get_features(cat_input, num_input, self.e4_cols_cat, self.e4_cols_num))
104
-
105
  h_final = (weights[:, 0:1]*h1 + weights[:, 1:2]*h2 + weights[:, 2:3]*h3 + weights[:, 3:4]*h4)
106
-
107
  return self.classifier(h_final)
108
 
109
  REPO_ID = "rocky250/FinTech"
@@ -114,82 +111,64 @@ def load_assets():
114
  config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
115
  encoders_path = hf_hub_download(repo_id=REPO_ID, filename="label_encoders.joblib")
116
  scaler_path = hf_hub_download(repo_id=REPO_ID, filename="scaler.joblib")
117
-
118
  with open(config_path, 'r') as f:
119
  config = json.load(f)
120
-
121
  model = MoEFraudModel(config['cat_dims'], config['num_cols_map'], config['embed_dim'])
122
  model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
123
  model.eval()
124
-
125
  return model, joblib.load(encoders_path), joblib.load(scaler_path), config
126
 
127
  model, encoders, scaler, config = load_assets()
128
 
129
- st.title("FinTech Fraud Analysis")
130
- st.markdown("### Proposed Mixture-of-Experts (MoE) Architecture")
131
 
132
- with st.form("transaction_form"):
133
- col1, col2, col3 = st.columns(3)
134
 
135
- with col1:
136
- st.subheader("Customer Info")
137
- first = st.text_input("First Name", "Jeff")
138
- last = st.text_input("Last Name", "Elliott")
 
 
 
 
 
 
 
 
 
139
  gender = st.selectbox("Gender", ["M", "F"])
140
- cc_num = st.text_input("Credit Card Number", "2.29116E+15")
141
-
142
- with col2:
143
- st.subheader("Transaction Info")
144
- trans_dt = st.text_input("Date & Time (DD-MM-YYYY HH:MM)", "21-06-2020 12:14")
145
- merchant = st.text_input("Merchant", "fraud_Kirlin and Sons")
146
- category = st.text_input("Category", "personal_care")
147
- amt = st.number_input("Amount ($)", value=2.86)
148
-
149
- with col3:
150
- st.subheader("Location Info")
151
- street = st.text_input("Street", "351 Darlene Green")
152
- city = st.text_input("City", "Birmingham")
153
- state = st.text_input("State", "AL")
154
- zip_code = st.text_input("Zip Code", "35201")
155
 
156
- submit = st.form_submit_button("ANALYZE TRANSACTION")
157
 
158
  if submit:
159
  try:
160
  dt_obj = datetime.strptime(trans_dt, "%d-%m-%Y %H:%M")
161
 
162
  num_feats = [
163
- amt,
164
- dt_obj.hour,
165
- dt_obj.weekday(),
166
- 1 if dt_obj.weekday() >= 5 else 0,
167
- dt_obj.timestamp(),
168
- 50000,
169
- 35,
170
- 3600,
171
- 50.0,
172
- 10.0,
173
- 100.0,
174
- 0.02,
175
- 1000,
176
- 33.5,
177
- -86.8,
178
- 33.6,
179
- -86.9,
180
- 15.5,
181
- 0
182
  ]
183
 
184
- cat_feats = [category, cc_num, gender, "Job", city, state, zip_code, merchant]
 
185
 
186
  num_input = scaler.transform([num_feats])
187
 
188
  cat_encoded = []
189
- for i, col in enumerate(config['cat_dims'].keys()):
190
- val = str(cat_feats[i])
191
- if val in encoders[col].classes_:
192
- cat_encoded.append(encoders[col].transform([val])[0])
193
  else:
194
  cat_encoded.append(0)
195
 
@@ -199,21 +178,30 @@ if submit:
199
  logits = model(cat_t, num_t)
200
  prob = torch.sigmoid(logits).item()
201
 
202
- st.divider()
203
  if prob > 0.5:
204
- st.markdown(f"""<div class="fraud-card">
205
- <h2>High Risk Detected!</h2>
206
- <p>Confidence: {prob*100:.2f}%</p>
207
- <p>This transaction matches fraud patterns identified by the Expert Gating Network.</p>
208
- </div>""", unsafe_allow_html=True)
 
 
 
209
  else:
210
- st.markdown(f"""<div class="legit-card">
211
- <h2>Transaction Safe</h2>
212
- <p>Confidence: {(1-prob)*100:.2f}%</p>
213
- <p>Analysis shows this is a legitimate transaction based on customer behavior Experts.</p>
214
- </div>""", unsafe_allow_html=True)
 
 
 
215
 
216
  except Exception as e:
217
- st.error(f"Error in processing: {e}")
218
 
219
- st.sidebar.info("This model uses a MoE Architecture with 4 specialized experts for Financial Fraud Detection.")
 
 
 
 
13
  st.markdown("""
14
  <style>
15
  .main { background-color: #0e1117; }
16
+ .stButton>button { width: 100%; border-radius: 20px; background: linear-gradient(45deg, #ff4b4b, #ff8f8f); color: white; border: none; font-weight: bold; }
17
+ .fraud-card { padding: 25px; border-radius: 15px; background-color: #1e2130; border-left: 8px solid #ff4b4b; margin-bottom: 20px; box-shadow: 0 4px 15px rgba(255, 75, 75, 0.2); }
18
+ .legit-card { padding: 25px; border-radius: 15px; background-color: #1e2130; border-left: 8px solid #00ffcc; margin-bottom: 20px; box-shadow: 0 4px 15px rgba(0, 255, 204, 0.1); }
19
+ .stTextInput>div>div>input, .stNumberInput>div>div>input { background-color: #262730; color: white; border-radius: 10px; }
20
  </style>
21
  """, unsafe_allow_html=True)
22
 
 
95
  def forward(self, cat_input, num_input):
96
  all_embs = [self.embeddings[c](cat_input[:, i]) for i, c in enumerate(self.cat_cols)]
97
  global_features = torch.cat([torch.cat(all_embs, dim=1), num_input], dim=1)
 
98
  weights = self.gating_network(global_features)
 
99
  h1 = self.expert1(self.get_features(cat_input, num_input, self.e1_cols_cat, self.e1_cols_num))
100
  h2 = self.expert2(self.get_features(cat_input, num_input, self.e2_cols_cat, self.e2_cols_num))
101
  h3 = self.expert3(self.get_features(cat_input, num_input, self.e3_cols_cat, self.e3_cols_num))
102
  h4 = self.expert4(self.get_features(cat_input, num_input, self.e4_cols_cat, self.e4_cols_num))
 
103
  h_final = (weights[:, 0:1]*h1 + weights[:, 1:2]*h2 + weights[:, 2:3]*h3 + weights[:, 3:4]*h4)
 
104
  return self.classifier(h_final)
105
 
106
  REPO_ID = "rocky250/FinTech"
 
111
  config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
112
  encoders_path = hf_hub_download(repo_id=REPO_ID, filename="label_encoders.joblib")
113
  scaler_path = hf_hub_download(repo_id=REPO_ID, filename="scaler.joblib")
 
114
  with open(config_path, 'r') as f:
115
  config = json.load(f)
 
116
  model = MoEFraudModel(config['cat_dims'], config['num_cols_map'], config['embed_dim'])
117
  model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
118
  model.eval()
 
119
  return model, joblib.load(encoders_path), joblib.load(scaler_path), config
120
 
121
  model, encoders, scaler, config = load_assets()
122
 
123
+ st.title("FinTech Mixture-of-Experts Fraud Guard")
124
+ st.markdown("---")
125
 
126
+ with st.form("main_form"):
127
+ c1, c2 = st.columns(2)
128
 
129
+ with c1:
130
+ st.subheader("Transaction Details")
131
+ trans_dt = st.text_input("Transaction Date & Time", value="21-06-2020 12:14")
132
+ amt = st.number_input("Transaction Amount ($)", value=2.86, format="%.2f")
133
+ cc_num = st.text_input("Credit Card Number", value="2291160000000000")
134
+ merchant = st.text_input("Merchant Name", value="fraud_Kirlin and Sons")
135
+ category = st.text_input("Category", value="personal_care")
136
+
137
+ with c2:
138
+ st.subheader("👤 Cardholder Info")
139
+ col_a, col_b = st.columns(2)
140
+ first_name = col_a.text_input("First Name", value="Jeff")
141
+ last_name = col_b.text_input("Last Name", value="Elliott")
142
  gender = st.selectbox("Gender", ["M", "F"])
143
+ street = st.text_input("Street Address", value="351 Darlene Green")
144
+ city = st.text_input("City", value="Birmingham")
145
+ state = st.text_input("State", value="AL")
146
+ zip_code = st.text_input("Zip Code", value="35201")
147
+ job = st.text_input("Job Title", value="Software Engineer")
 
 
 
 
 
 
 
 
 
 
148
 
149
+ submit = st.form_submit_button("RUN FRAUD ANALYSIS")
150
 
151
  if submit:
152
  try:
153
  dt_obj = datetime.strptime(trans_dt, "%d-%m-%Y %H:%M")
154
 
155
  num_feats = [
156
+ amt, dt_obj.hour, dt_obj.weekday(),
157
+ 1 if dt_obj.weekday() >= 5 else 0, dt_obj.timestamp(),
158
+ 50000, 35, 3600, 50.0, 10.0, 100.0,
159
+ 0.02, 1000, 33.5, -86.8, 33.6, -86.9, 15.5, 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  ]
161
 
162
+ cat_order = ['category', 'cc_num', 'gender', 'job', 'city', 'state', 'zip', 'merchant']
163
+ cat_vals = [category, cc_num, gender, job, city, state, zip_code, merchant]
164
 
165
  num_input = scaler.transform([num_feats])
166
 
167
  cat_encoded = []
168
+ for i, col_name in enumerate(cat_order):
169
+ val = str(cat_vals[i])
170
+ if val in encoders[col_name].classes_:
171
+ cat_encoded.append(encoders[col_name].transform([val])[0])
172
  else:
173
  cat_encoded.append(0)
174
 
 
178
  logits = model(cat_t, num_t)
179
  prob = torch.sigmoid(logits).item()
180
 
181
+ st.markdown("### Analysis Result")
182
  if prob > 0.5:
183
+ st.markdown(f"""
184
+ <div class="fraud-card">
185
+ <h2 style="color:#ff4b4b; margin-top:0;">High Risk Detected</h2>
186
+ <p style="font-size:1.2rem;">Fraud Probability: <strong>{prob*100:.2f}%</strong></p>
187
+ <hr style="border:0.5px solid #ff4b4b33;">
188
+ <p>The MoE Gating Network has flagged this transaction as highly suspicious based on merchant patterns and transaction timing.</p>
189
+ </div>
190
+ """, unsafe_allow_html=True)
191
  else:
192
+ st.markdown(f"""
193
+ <div class="legit-card">
194
+ <h2 style="color:#00ffcc; margin-top:0;">Transaction Verified</h2>
195
+ <p style="font-size:1.2rem;">Fraud Probability: <strong>{prob*100:.2f}%</strong></p>
196
+ <hr style="border:0.5px solid #00ffcc33;">
197
+ <p>Our experts have analyzed the spatial and behavioral features and determined this transaction is likely legitimate.</p>
198
+ </div>
199
+ """, unsafe_allow_html=True)
200
 
201
  except Exception as e:
202
+ st.error(f"Validation Error: Ensure Date Format is DD-MM-YYYY HH:MM. Raw Error: {e}")
203
 
204
+ st.sidebar.title("System Info")
205
+ st.sidebar.info("Model: MoE Neural Network")
206
+ st.sidebar.write(f"Repository: {REPO_ID}")
207
+ st.sidebar.write("Architecture: 4 Experts (Temporal, Behavioral, Frequency, Spatial)")