rocky250 commited on
Commit
39fdda1
·
verified ·
1 Parent(s): 47b09f7

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +64 -65
src/streamlit_app.py CHANGED
@@ -13,10 +13,9 @@ st.set_page_config(page_title="FinTech Fraud Guard", layout="wide")
13
  st.markdown("""
14
  <style>
15
  .main { background-color: #0e1117; }
16
- .stButton>button { width: 100%; border-radius: 20px; background: linear-gradient(45deg, #ff4b4b, #ff8f8f); color: white; border: none; font-weight: bold; }
17
- .fraud-card { padding: 25px; border-radius: 15px; background-color: #1e2130; border-left: 8px solid #ff4b4b; margin-bottom: 20px; box-shadow: 0 4px 15px rgba(255, 75, 75, 0.2); }
18
- .legit-card { padding: 25px; border-radius: 15px; background-color: #1e2130; border-left: 8px solid #00ffcc; margin-bottom: 20px; box-shadow: 0 4px 15px rgba(0, 255, 204, 0.1); }
19
- .stTextInput>div>div>input, .stNumberInput>div>div>input { background-color: #262730; color: white; border-radius: 10px; }
20
  </style>
21
  """, unsafe_allow_html=True)
22
 
@@ -103,100 +102,100 @@ class MoEFraudModel(nn.Module):
103
  h_final = (weights[:, 0:1]*h1 + weights[:, 1:2]*h2 + weights[:, 2:3]*h3 + weights[:, 3:4]*h4)
104
  return self.classifier(h_final)
105
 
 
 
 
 
 
 
 
 
106
  REPO_ID = "rocky250/FinTech"
107
 
108
  @st.cache_resource
109
  def load_assets():
110
- weights_path = hf_hub_download(repo_id=REPO_ID, filename="proposed_moe_model.pth")
111
- config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
112
- encoders_path = hf_hub_download(repo_id=REPO_ID, filename="label_encoders.joblib")
113
- scaler_path = hf_hub_download(repo_id=REPO_ID, filename="scaler.joblib")
114
- with open(config_path, 'r') as f:
115
  config = json.load(f)
116
  model = MoEFraudModel(config['cat_dims'], config['num_cols_map'], config['embed_dim'])
117
- model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
118
  model.eval()
119
- return model, joblib.load(encoders_path), joblib.load(scaler_path), config
120
 
121
  model, encoders, scaler, config = load_assets()
122
 
123
- st.title("FinTech Mixture-of-Experts Fraud Guard")
124
- st.markdown("---")
125
 
126
- with st.form("main_form"):
127
- c1, c2 = st.columns(2)
128
-
129
  with c1:
130
- st.subheader("Transaction Details")
131
- trans_dt = st.text_input("Transaction Date & Time", value="21-06-2020 12:14")
132
- amt = st.number_input("Transaction Amount ($)", value=2.86, format="%.2f")
133
- cc_num = st.text_input("Credit Card Number", "2.29116E+15")
134
- merchant = st.text_input("Merchant Name", value="fraud_Kirlin and Sons")
135
- category = st.text_input("Category", value="personal_care")
136
-
137
- with c2:
138
- st.subheader("Cardholder Info")
139
- col_a, col_b = st.columns(2)
140
- first_name = col_a.text_input("First Name", value="Jeff")
141
- last_name = col_b.text_input("Last Name", value="Elliott")
142
  gender = st.selectbox("Gender", ["M", "F"])
143
- street = st.text_input("Street Address", value="351 Darlene Green")
144
-
145
- submit = st.form_submit_button("RUN FRAUD ANALYSIS")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  if submit:
148
  try:
149
  dt_obj = datetime.strptime(trans_dt, "%d-%m-%Y %H:%M")
 
 
 
150
 
151
  num_feats = [
152
- amt, dt_obj.hour, dt_obj.weekday(),
153
- 1 if dt_obj.weekday() >= 5 else 0, dt_obj.timestamp(),
154
- 50000, 35, 3600, 50.0, 10.0, 100.0,
155
- 0.02, 1000, 33.5, -86.8, 33.6, -86.9, 15.5, 0
156
  ]
157
 
158
- cat_order = ['category', 'cc_num', 'gender', 'job', 'city', 'state', 'zip', 'merchant']
159
- cat_vals = [category, cc_num, gender, job, city, state, zip_code, merchant]
160
 
161
  num_input = scaler.transform([num_feats])
162
-
163
  cat_encoded = []
164
- for i, col_name in enumerate(cat_order):
165
- val = str(cat_vals[i])
166
- if val in encoders[col_name].classes_:
167
- cat_encoded.append(encoders[col_name].transform([val])[0])
168
  else:
169
  cat_encoded.append(0)
170
 
171
  with torch.no_grad():
172
  cat_t = torch.tensor([cat_encoded], dtype=torch.long)
173
  num_t = torch.tensor(num_input, dtype=torch.float32)
174
- logits = model(cat_t, num_t)
175
- prob = torch.sigmoid(logits).item()
176
 
177
- st.markdown("### Analysis Result")
178
  if prob > 0.5:
179
- st.markdown(f"""
180
- <div class="fraud-card">
181
- <h2 style="color:#ff4b4b; margin-top:0;">High Risk Detected</h2>
182
- <p style="font-size:1.2rem;">Fraud Probability: <strong>{prob*100:.2f}%</strong></p>
183
- <hr style="border:0.5px solid #ff4b4b33;">
184
- <p>The MoE Gating Network has flagged this transaction as highly suspicious based on merchant patterns and transaction timing.</p>
185
- </div>
186
- """, unsafe_allow_html=True)
187
  else:
188
- st.markdown(f"""
189
- <div class="legit-card">
190
- <h2 style="color:#00ffcc; margin-top:0;">Transaction Verified</h2>
191
- <p style="font-size:1.2rem;">Fraud Probability: <strong>{prob*100:.2f}%</strong></p>
192
- <hr style="border:0.5px solid #00ffcc33;">
193
- <p>Our experts have analyzed the spatial and behavioral features and determined this transaction is likely legitimate.</p>
194
- </div>
195
- """, unsafe_allow_html=True)
196
 
197
  except Exception as e:
198
- st.error(f"Validation Error: Ensure Date Format is DD-MM-YYYY HH:MM. Raw Error: {e}")
199
 
200
- st.sidebar.title("System Info")
201
- st.sidebar.info("Model: MoE Neural Network")
202
- st.sidebar.write("Architecture: 4 Experts (Temporal, Behavioral, Frequency, Spatial)")
 
13
  st.markdown("""
14
  <style>
15
  .main { background-color: #0e1117; }
16
+ .stButton>button { width: 100%; border-radius: 20px; background: linear-gradient(45deg, #ff4b4b, #ff8f8f); color: white; border: none; }
17
+ .fraud-card { padding: 20px; border-radius: 15px; background-color: #1e2130; border-left: 5px solid #ff4b4b; margin-bottom: 20px; }
18
+ .legit-card { padding: 20px; border-radius: 15px; background-color: #1e2130; border-left: 5px solid #00ffcc; margin-bottom: 20px; }
 
19
  </style>
20
  """, unsafe_allow_html=True)
21
 
 
102
  h_final = (weights[:, 0:1]*h1 + weights[:, 1:2]*h2 + weights[:, 2:3]*h3 + weights[:, 3:4]*h4)
103
  return self.classifier(h_final)
104
 
105
+ def haversine(lat1, lon1, lat2, lon2):
106
+ r = 6371
107
+ phi1, phi2 = np.radians(lat1), np.radians(lat2)
108
+ dphi = np.radians(lat2 - lat1)
109
+ dlambda = np.radians(lon2 - lon1)
110
+ a = np.sin(dphi/2)**2 + np.cos(phi1)*np.cos(phi2)*np.sin(dlambda/2)**2
111
+ return 2 * r * np.arcsin(np.sqrt(a))
112
+
113
  REPO_ID = "rocky250/FinTech"
114
 
115
  @st.cache_resource
116
  def load_assets():
117
+ w_p = hf_hub_download(repo_id=REPO_ID, filename="proposed_moe_model.pth")
118
+ c_p = hf_hub_download(repo_id=REPO_ID, filename="config.json")
119
+ e_p = hf_hub_download(repo_id=REPO_ID, filename="label_encoders.joblib")
120
+ s_p = hf_hub_download(repo_id=REPO_ID, filename="scaler.joblib")
121
+ with open(c_p, 'r') as f:
122
  config = json.load(f)
123
  model = MoEFraudModel(config['cat_dims'], config['num_cols_map'], config['embed_dim'])
124
+ model.load_state_dict(torch.load(w_p, map_location=torch.device('cpu')))
125
  model.eval()
126
+ return model, joblib.load(e_p), joblib.load(s_p), config
127
 
128
  model, encoders, scaler, config = load_assets()
129
 
130
+ st.title("FinTech Fraud Analysis")
131
+ st.markdown("### MoE Transaction Verifier")
132
 
133
+ with st.form("transaction_form"):
134
+ c1, c2, c3 = st.columns(3)
 
135
  with c1:
136
+ st.subheader("Personal & Card")
137
+ first = st.text_input("First Name", "Jeff")
138
+ last = st.text_input("Last Name", "Elliott")
 
 
 
 
 
 
 
 
 
139
  gender = st.selectbox("Gender", ["M", "F"])
140
+ dob = st.text_input("Date of Birth (DD-MM-YYYY)", "19-03-1968")
141
+ cc_num = st.text_input("CC Number", "2.29E+15")
142
+ job = st.text_input("Job", "Mechanical engineer")
143
+ with c2:
144
+ st.subheader("Transaction Details")
145
+ trans_dt = st.text_input("Transaction Time (DD-MM-YYYY HH:MM)", "21-06-2020 12:14")
146
+ merchant = st.text_input("Merchant", "fraud_Kirlin and Sons")
147
+ category = st.text_input("Category", "personal_care")
148
+ amt = st.number_input("Amount ($)", value=2.86)
149
+ u_time = st.number_input("Unix Time", value=1371816865)
150
+ with c3:
151
+ st.subheader("Geography")
152
+ street = st.text_input("Street", "351 Darlene Green")
153
+ city = st.text_input("City", "Columbia")
154
+ state = st.text_input("State", "SC")
155
+ zip_v = st.text_input("Zip Code", "29209")
156
+ city_pop = st.number_input("City Population", value=333497)
157
+ lat = st.number_input("Lat", value=33.9659)
158
+ lon = st.number_input("Long", value=-80.9355)
159
+ m_lat = st.number_input("Merch Lat", value=33.986391)
160
+ m_lon = st.number_input("Merch Long", value=-81.200714)
161
+
162
+ submit = st.form_submit_button("ANALYZE TRANSACTION")
163
 
164
  if submit:
165
  try:
166
  dt_obj = datetime.strptime(trans_dt, "%d-%m-%Y %H:%M")
167
+ dob_obj = datetime.strptime(dob, "%d-%m-%Y")
168
+ age = (dt_obj - dob_obj).days // 365
169
+ dist = haversine(lat, lon, m_lat, m_lon)
170
 
171
  num_feats = [
172
+ amt, dt_obj.hour, dt_obj.weekday(), 1 if dt_obj.weekday() >= 5 else 0, u_time,
173
+ city_pop, age, 3600, 50.0, 15.0, 150.0, 0.01, 500, lat, lon, m_lat, m_lon, dist, 0
 
 
174
  ]
175
 
176
+ cat_feats = [category, cc_num, gender, job, city, state, zip_v, merchant]
 
177
 
178
  num_input = scaler.transform([num_feats])
 
179
  cat_encoded = []
180
+ for i, col in enumerate(config['cat_dims'].keys()):
181
+ val = str(cat_feats[i])
182
+ if val in encoders[col].classes_:
183
+ cat_encoded.append(encoders[col].transform([val])[0])
184
  else:
185
  cat_encoded.append(0)
186
 
187
  with torch.no_grad():
188
  cat_t = torch.tensor([cat_encoded], dtype=torch.long)
189
  num_t = torch.tensor(num_input, dtype=torch.float32)
190
+ prob = torch.sigmoid(model(cat_t, num_t)).item()
 
191
 
192
+ st.divider()
193
  if prob > 0.5:
194
+ st.markdown(f'<div class="fraud-card"><h2>High Risk Detected!</h2><p>Confidence: {prob*100:.2f}%</p></div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
195
  else:
196
+ st.markdown(f'<div class="legit-card"><h2>Transaction Safe</h2><p>Confidence: {(1-prob)*100:.2f}%</p></div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
197
 
198
  except Exception as e:
199
+ st.error(f"Error: {e}")
200
 
201
+ st.sidebar.info("MoE model utilizing 4 expert networks for fraud detection.")