Update src/streamlit_app.py
Browse files- src/streamlit_app.py +59 -71
src/streamlit_app.py
CHANGED
|
@@ -13,9 +13,10 @@ st.set_page_config(page_title="FinTech Fraud Guard", layout="wide")
|
|
| 13 |
st.markdown("""
|
| 14 |
<style>
|
| 15 |
.main { background-color: #0e1117; }
|
| 16 |
-
.stButton>button { width: 100%; border-radius: 20px; background: linear-gradient(45deg, #ff4b4b, #ff8f8f); color: white; border: none; }
|
| 17 |
-
.fraud-card { padding:
|
| 18 |
-
.legit-card { padding:
|
|
|
|
| 19 |
</style>
|
| 20 |
""", unsafe_allow_html=True)
|
| 21 |
|
|
@@ -94,16 +95,12 @@ class MoEFraudModel(nn.Module):
|
|
| 94 |
def forward(self, cat_input, num_input):
|
| 95 |
all_embs = [self.embeddings[c](cat_input[:, i]) for i, c in enumerate(self.cat_cols)]
|
| 96 |
global_features = torch.cat([torch.cat(all_embs, dim=1), num_input], dim=1)
|
| 97 |
-
|
| 98 |
weights = self.gating_network(global_features)
|
| 99 |
-
|
| 100 |
h1 = self.expert1(self.get_features(cat_input, num_input, self.e1_cols_cat, self.e1_cols_num))
|
| 101 |
h2 = self.expert2(self.get_features(cat_input, num_input, self.e2_cols_cat, self.e2_cols_num))
|
| 102 |
h3 = self.expert3(self.get_features(cat_input, num_input, self.e3_cols_cat, self.e3_cols_num))
|
| 103 |
h4 = self.expert4(self.get_features(cat_input, num_input, self.e4_cols_cat, self.e4_cols_num))
|
| 104 |
-
|
| 105 |
h_final = (weights[:, 0:1]*h1 + weights[:, 1:2]*h2 + weights[:, 2:3]*h3 + weights[:, 3:4]*h4)
|
| 106 |
-
|
| 107 |
return self.classifier(h_final)
|
| 108 |
|
| 109 |
REPO_ID = "rocky250/FinTech"
|
|
@@ -114,82 +111,64 @@ def load_assets():
|
|
| 114 |
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
|
| 115 |
encoders_path = hf_hub_download(repo_id=REPO_ID, filename="label_encoders.joblib")
|
| 116 |
scaler_path = hf_hub_download(repo_id=REPO_ID, filename="scaler.joblib")
|
| 117 |
-
|
| 118 |
with open(config_path, 'r') as f:
|
| 119 |
config = json.load(f)
|
| 120 |
-
|
| 121 |
model = MoEFraudModel(config['cat_dims'], config['num_cols_map'], config['embed_dim'])
|
| 122 |
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
|
| 123 |
model.eval()
|
| 124 |
-
|
| 125 |
return model, joblib.load(encoders_path), joblib.load(scaler_path), config
|
| 126 |
|
| 127 |
model, encoders, scaler, config = load_assets()
|
| 128 |
|
| 129 |
-
st.title("FinTech Fraud
|
| 130 |
-
st.markdown("
|
| 131 |
|
| 132 |
-
with st.form("
|
| 133 |
-
|
| 134 |
|
| 135 |
-
with
|
| 136 |
-
st.subheader("
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
gender = st.selectbox("Gender", ["M", "F"])
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
st.
|
| 144 |
-
|
| 145 |
-
merchant = st.text_input("Merchant", "fraud_Kirlin and Sons")
|
| 146 |
-
category = st.text_input("Category", "personal_care")
|
| 147 |
-
amt = st.number_input("Amount ($)", value=2.86)
|
| 148 |
-
|
| 149 |
-
with col3:
|
| 150 |
-
st.subheader("Location Info")
|
| 151 |
-
street = st.text_input("Street", "351 Darlene Green")
|
| 152 |
-
city = st.text_input("City", "Birmingham")
|
| 153 |
-
state = st.text_input("State", "AL")
|
| 154 |
-
zip_code = st.text_input("Zip Code", "35201")
|
| 155 |
|
| 156 |
-
submit = st.form_submit_button("
|
| 157 |
|
| 158 |
if submit:
|
| 159 |
try:
|
| 160 |
dt_obj = datetime.strptime(trans_dt, "%d-%m-%Y %H:%M")
|
| 161 |
|
| 162 |
num_feats = [
|
| 163 |
-
amt,
|
| 164 |
-
dt_obj.
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
dt_obj.timestamp(),
|
| 168 |
-
50000,
|
| 169 |
-
35,
|
| 170 |
-
3600,
|
| 171 |
-
50.0,
|
| 172 |
-
10.0,
|
| 173 |
-
100.0,
|
| 174 |
-
0.02,
|
| 175 |
-
1000,
|
| 176 |
-
33.5,
|
| 177 |
-
-86.8,
|
| 178 |
-
33.6,
|
| 179 |
-
-86.9,
|
| 180 |
-
15.5,
|
| 181 |
-
0
|
| 182 |
]
|
| 183 |
|
| 184 |
-
|
|
|
|
| 185 |
|
| 186 |
num_input = scaler.transform([num_feats])
|
| 187 |
|
| 188 |
cat_encoded = []
|
| 189 |
-
for i,
|
| 190 |
-
val = str(
|
| 191 |
-
if val in encoders[
|
| 192 |
-
cat_encoded.append(encoders[
|
| 193 |
else:
|
| 194 |
cat_encoded.append(0)
|
| 195 |
|
|
@@ -199,21 +178,30 @@ if submit:
|
|
| 199 |
logits = model(cat_t, num_t)
|
| 200 |
prob = torch.sigmoid(logits).item()
|
| 201 |
|
| 202 |
-
st.
|
| 203 |
if prob > 0.5:
|
| 204 |
-
st.markdown(f"""
|
| 205 |
-
<
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
| 209 |
else:
|
| 210 |
-
st.markdown(f"""
|
| 211 |
-
<
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
except Exception as e:
|
| 217 |
-
st.error(f"Error
|
| 218 |
|
| 219 |
-
st.sidebar.
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
st.markdown("""
|
| 14 |
<style>
|
| 15 |
.main { background-color: #0e1117; }
|
| 16 |
+
.stButton>button { width: 100%; border-radius: 20px; background: linear-gradient(45deg, #ff4b4b, #ff8f8f); color: white; border: none; font-weight: bold; }
|
| 17 |
+
.fraud-card { padding: 25px; border-radius: 15px; background-color: #1e2130; border-left: 8px solid #ff4b4b; margin-bottom: 20px; box-shadow: 0 4px 15px rgba(255, 75, 75, 0.2); }
|
| 18 |
+
.legit-card { padding: 25px; border-radius: 15px; background-color: #1e2130; border-left: 8px solid #00ffcc; margin-bottom: 20px; box-shadow: 0 4px 15px rgba(0, 255, 204, 0.1); }
|
| 19 |
+
.stTextInput>div>div>input, .stNumberInput>div>div>input { background-color: #262730; color: white; border-radius: 10px; }
|
| 20 |
</style>
|
| 21 |
""", unsafe_allow_html=True)
|
| 22 |
|
|
|
|
| 95 |
def forward(self, cat_input, num_input):
|
| 96 |
all_embs = [self.embeddings[c](cat_input[:, i]) for i, c in enumerate(self.cat_cols)]
|
| 97 |
global_features = torch.cat([torch.cat(all_embs, dim=1), num_input], dim=1)
|
|
|
|
| 98 |
weights = self.gating_network(global_features)
|
|
|
|
| 99 |
h1 = self.expert1(self.get_features(cat_input, num_input, self.e1_cols_cat, self.e1_cols_num))
|
| 100 |
h2 = self.expert2(self.get_features(cat_input, num_input, self.e2_cols_cat, self.e2_cols_num))
|
| 101 |
h3 = self.expert3(self.get_features(cat_input, num_input, self.e3_cols_cat, self.e3_cols_num))
|
| 102 |
h4 = self.expert4(self.get_features(cat_input, num_input, self.e4_cols_cat, self.e4_cols_num))
|
|
|
|
| 103 |
h_final = (weights[:, 0:1]*h1 + weights[:, 1:2]*h2 + weights[:, 2:3]*h3 + weights[:, 3:4]*h4)
|
|
|
|
| 104 |
return self.classifier(h_final)
|
| 105 |
|
| 106 |
REPO_ID = "rocky250/FinTech"
|
|
|
|
| 111 |
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
|
| 112 |
encoders_path = hf_hub_download(repo_id=REPO_ID, filename="label_encoders.joblib")
|
| 113 |
scaler_path = hf_hub_download(repo_id=REPO_ID, filename="scaler.joblib")
|
|
|
|
| 114 |
with open(config_path, 'r') as f:
|
| 115 |
config = json.load(f)
|
|
|
|
| 116 |
model = MoEFraudModel(config['cat_dims'], config['num_cols_map'], config['embed_dim'])
|
| 117 |
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
|
| 118 |
model.eval()
|
|
|
|
| 119 |
return model, joblib.load(encoders_path), joblib.load(scaler_path), config
|
| 120 |
|
| 121 |
model, encoders, scaler, config = load_assets()
|
| 122 |
|
| 123 |
+
st.title("FinTech Mixture-of-Experts Fraud Guard")
|
| 124 |
+
st.markdown("---")
|
| 125 |
|
| 126 |
+
with st.form("main_form"):
|
| 127 |
+
c1, c2 = st.columns(2)
|
| 128 |
|
| 129 |
+
with c1:
|
| 130 |
+
st.subheader("Transaction Details")
|
| 131 |
+
trans_dt = st.text_input("Transaction Date & Time", value="21-06-2020 12:14")
|
| 132 |
+
amt = st.number_input("Transaction Amount ($)", value=2.86, format="%.2f")
|
| 133 |
+
cc_num = st.text_input("Credit Card Number", value="2291160000000000")
|
| 134 |
+
merchant = st.text_input("Merchant Name", value="fraud_Kirlin and Sons")
|
| 135 |
+
category = st.text_input("Category", value="personal_care")
|
| 136 |
+
|
| 137 |
+
with c2:
|
| 138 |
+
st.subheader("👤 Cardholder Info")
|
| 139 |
+
col_a, col_b = st.columns(2)
|
| 140 |
+
first_name = col_a.text_input("First Name", value="Jeff")
|
| 141 |
+
last_name = col_b.text_input("Last Name", value="Elliott")
|
| 142 |
gender = st.selectbox("Gender", ["M", "F"])
|
| 143 |
+
street = st.text_input("Street Address", value="351 Darlene Green")
|
| 144 |
+
city = st.text_input("City", value="Birmingham")
|
| 145 |
+
state = st.text_input("State", value="AL")
|
| 146 |
+
zip_code = st.text_input("Zip Code", value="35201")
|
| 147 |
+
job = st.text_input("Job Title", value="Software Engineer")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
+
submit = st.form_submit_button("RUN FRAUD ANALYSIS")
|
| 150 |
|
| 151 |
if submit:
|
| 152 |
try:
|
| 153 |
dt_obj = datetime.strptime(trans_dt, "%d-%m-%Y %H:%M")
|
| 154 |
|
| 155 |
num_feats = [
|
| 156 |
+
amt, dt_obj.hour, dt_obj.weekday(),
|
| 157 |
+
1 if dt_obj.weekday() >= 5 else 0, dt_obj.timestamp(),
|
| 158 |
+
50000, 35, 3600, 50.0, 10.0, 100.0,
|
| 159 |
+
0.02, 1000, 33.5, -86.8, 33.6, -86.9, 15.5, 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
]
|
| 161 |
|
| 162 |
+
cat_order = ['category', 'cc_num', 'gender', 'job', 'city', 'state', 'zip', 'merchant']
|
| 163 |
+
cat_vals = [category, cc_num, gender, job, city, state, zip_code, merchant]
|
| 164 |
|
| 165 |
num_input = scaler.transform([num_feats])
|
| 166 |
|
| 167 |
cat_encoded = []
|
| 168 |
+
for i, col_name in enumerate(cat_order):
|
| 169 |
+
val = str(cat_vals[i])
|
| 170 |
+
if val in encoders[col_name].classes_:
|
| 171 |
+
cat_encoded.append(encoders[col_name].transform([val])[0])
|
| 172 |
else:
|
| 173 |
cat_encoded.append(0)
|
| 174 |
|
|
|
|
| 178 |
logits = model(cat_t, num_t)
|
| 179 |
prob = torch.sigmoid(logits).item()
|
| 180 |
|
| 181 |
+
st.markdown("### Analysis Result")
|
| 182 |
if prob > 0.5:
|
| 183 |
+
st.markdown(f"""
|
| 184 |
+
<div class="fraud-card">
|
| 185 |
+
<h2 style="color:#ff4b4b; margin-top:0;">High Risk Detected</h2>
|
| 186 |
+
<p style="font-size:1.2rem;">Fraud Probability: <strong>{prob*100:.2f}%</strong></p>
|
| 187 |
+
<hr style="border:0.5px solid #ff4b4b33;">
|
| 188 |
+
<p>The MoE Gating Network has flagged this transaction as highly suspicious based on merchant patterns and transaction timing.</p>
|
| 189 |
+
</div>
|
| 190 |
+
""", unsafe_allow_html=True)
|
| 191 |
else:
|
| 192 |
+
st.markdown(f"""
|
| 193 |
+
<div class="legit-card">
|
| 194 |
+
<h2 style="color:#00ffcc; margin-top:0;">Transaction Verified</h2>
|
| 195 |
+
<p style="font-size:1.2rem;">Fraud Probability: <strong>{prob*100:.2f}%</strong></p>
|
| 196 |
+
<hr style="border:0.5px solid #00ffcc33;">
|
| 197 |
+
<p>Our experts have analyzed the spatial and behavioral features and determined this transaction is likely legitimate.</p>
|
| 198 |
+
</div>
|
| 199 |
+
""", unsafe_allow_html=True)
|
| 200 |
|
| 201 |
except Exception as e:
|
| 202 |
+
st.error(f"Validation Error: Ensure Date Format is DD-MM-YYYY HH:MM. Raw Error: {e}")
|
| 203 |
|
| 204 |
+
st.sidebar.title("System Info")
|
| 205 |
+
st.sidebar.info("Model: MoE Neural Network")
|
| 206 |
+
st.sidebar.write(f"Repository: {REPO_ID}")
|
| 207 |
+
st.sidebar.write("Architecture: 4 Experts (Temporal, Behavioral, Frequency, Spatial)")
|