YashChowdhary committed on
Commit
f1b0880
·
verified ·
1 Parent(s): 1915ba3

Update app.py

Browse files

Fix 1: Light/Dark Mode Compatibility
Fix 2: Readable Pie Charts - Problem: Percentages and labels were hard to read

Files changed (1) hide show
  1. app.py +255 -226
app.py CHANGED
@@ -28,31 +28,53 @@ from xgboost import XGBClassifier
28
  from lightgbm import LGBMClassifier
29
  from imblearn.over_sampling import SMOTE
30
 
31
- # Set up matplotlib for dark mode compatibility
32
- plt.rcParams['figure.facecolor'] = '#1a1a2e'
33
- plt.rcParams['axes.facecolor'] = '#16213e'
34
- plt.rcParams['axes.edgecolor'] = '#e0e0e0'
35
- plt.rcParams['axes.labelcolor'] = '#e0e0e0'
36
- plt.rcParams['text.color'] = '#e0e0e0'
37
- plt.rcParams['xtick.color'] = '#e0e0e0'
38
- plt.rcParams['ytick.color'] = '#e0e0e0'
39
- plt.rcParams['grid.color'] = '#3a3a5c'
40
- plt.rcParams['legend.facecolor'] = '#1a1a2e'
41
- plt.rcParams['legend.edgecolor'] = '#e0e0e0'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  # ============================================================================
44
  # DATA LOADING AND PREPROCESSING
45
  # ============================================================================
46
 
47
  def load_and_prepare_data():
48
- """
49
- Load the train and test datasets.
50
- The data is already preprocessed and one-hot encoded.
51
- """
52
  train_df = pd.read_csv('train.csv')
53
  test_df = pd.read_csv('test.csv')
54
 
55
- # Separate features and target
56
  X_train = train_df.drop('fraud', axis=1)
57
  y_train = train_df['fraud']
58
  X_test = test_df.drop('fraud', axis=1)
@@ -62,10 +84,7 @@ def load_and_prepare_data():
62
 
63
 
64
  def apply_smote(X_train, y_train):
65
- """
66
- Apply SMOTE to handle class imbalance.
67
- Fraud cases are rare (~3%), so we oversample the minority class.
68
- """
69
  smote = SMOTE(random_state=42)
70
  X_resampled, y_resampled = smote.fit_resample(X_train, y_train)
71
  return X_resampled, y_resampled
@@ -76,10 +95,7 @@ def apply_smote(X_train, y_train):
76
  # ============================================================================
77
 
78
  def get_models():
79
- """
80
- Define the 4 models we'll compare.
81
- Each model is tuned for imbalanced fraud detection.
82
- """
83
  models = {
84
  'XGBoost': XGBClassifier(
85
  n_estimators=100,
@@ -119,20 +135,20 @@ def get_models():
119
  # ============================================================================
120
 
121
  def train_model(model, X_train, y_train):
122
- """Train a single model and return the fitted model."""
123
  model.fit(X_train, y_train)
124
  return model
125
 
126
 
127
  def evaluate_model(model, X_test, y_test):
128
- """Get predictions and probabilities from a trained model."""
129
  y_pred = model.predict(X_test)
130
  y_proba = model.predict_proba(X_test)[:, 1]
131
  return y_pred, y_proba
132
 
133
 
134
  def get_metrics(y_test, y_pred, y_proba):
135
- """Calculate all relevant metrics for fraud detection."""
136
  metrics = {
137
  'Accuracy': accuracy_score(y_test, y_pred),
138
  'Precision': precision_score(y_test, y_pred, zero_division=0),
@@ -144,7 +160,7 @@ def get_metrics(y_test, y_pred, y_proba):
144
 
145
 
146
  def find_optimal_threshold(y_test, y_proba):
147
- """Find the optimal classification threshold using F1 score."""
148
  thresholds = np.arange(0.1, 0.9, 0.01)
149
  f1_scores = []
150
 
@@ -161,124 +177,138 @@ def find_optimal_threshold(y_test, y_proba):
161
 
162
 
163
  # ============================================================================
164
- # VISUALIZATION FUNCTIONS (Dark Mode Compatible)
165
  # ============================================================================
166
 
167
  def plot_precision_recall_curve(y_test, y_proba, model_name):
168
- """Plot Precision-Recall curve with dark mode colors."""
169
- precision, recall, thresholds = precision_recall_curve(y_test, y_proba)
 
170
  pr_auc = auc(recall, precision)
171
 
172
- fig, ax = plt.subplots(figsize=(8, 6))
173
- ax.plot(recall, precision, '#00d4ff', linewidth=2, label=f'{model_name} (AUC = {pr_auc:.3f})')
174
- ax.fill_between(recall, precision, alpha=0.3, color='#00d4ff')
 
 
175
 
176
  # Baseline
177
  baseline = y_test.mean()
178
- ax.axhline(y=baseline, color='#ff6b6b', linestyle='--', label=f'Baseline = {baseline:.3f}')
 
179
 
180
- ax.set_xlabel('Recall (Fraud Detection Rate)', fontsize=12)
181
- ax.set_ylabel('Precision (True Fraud Rate)', fontsize=12)
182
- ax.set_title(f'Precision-Recall Curve: {model_name}', fontsize=14, fontweight='bold')
183
- ax.legend(loc='best', facecolor='#1a1a2e', edgecolor='#e0e0e0')
184
  ax.set_xlim([0, 1])
185
  ax.set_ylim([0, 1])
186
- ax.grid(True, alpha=0.3)
187
 
188
  plt.tight_layout()
189
  return fig
190
 
191
 
192
  def plot_roc_curve(y_test, y_proba, model_name):
193
- """Plot ROC curve with dark mode colors."""
194
- fpr, tpr, thresholds = roc_curve(y_test, y_proba)
 
195
  roc_auc = auc(fpr, tpr)
196
 
197
- fig, ax = plt.subplots(figsize=(8, 6))
198
- ax.plot(fpr, tpr, '#00d4ff', linewidth=2, label=f'{model_name} (AUC = {roc_auc:.3f})')
199
- ax.fill_between(fpr, tpr, alpha=0.3, color='#00d4ff')
200
- ax.plot([0, 1], [0, 1], '#ff6b6b', linestyle='--', label='Random Classifier')
 
 
 
201
 
202
- ax.set_xlabel('False Positive Rate', fontsize=12)
203
- ax.set_ylabel('True Positive Rate (Recall)', fontsize=12)
204
- ax.set_title(f'ROC Curve: {model_name}', fontsize=14, fontweight='bold')
205
- ax.legend(loc='lower right', facecolor='#1a1a2e', edgecolor='#e0e0e0')
206
  ax.set_xlim([0, 1])
207
  ax.set_ylim([0, 1])
208
- ax.grid(True, alpha=0.3)
209
 
210
  plt.tight_layout()
211
  return fig
212
 
213
 
214
  def plot_confusion_matrix(y_test, y_pred, model_name):
215
- """Plot confusion matrix heatmap with dark mode colors."""
 
216
  cm = confusion_matrix(y_test, y_pred)
217
 
218
- fig, ax = plt.subplots(figsize=(8, 6))
219
 
220
- # Custom colormap for dark mode
221
- cmap = sns.color_palette("Blues", as_cmap=True)
222
-
223
- sns.heatmap(cm, annot=True, fmt='d', cmap=cmap, ax=ax,
224
  xticklabels=['Legitimate', 'Fraud'],
225
  yticklabels=['Legitimate', 'Fraud'],
226
- annot_kws={'size': 16, 'color': 'white'},
227
- cbar_kws={'label': 'Count'})
 
228
 
229
- ax.set_xlabel('Predicted Label', fontsize=12)
230
- ax.set_ylabel('True Label', fontsize=12)
231
- ax.set_title(f'Confusion Matrix: {model_name}', fontsize=14, fontweight='bold')
232
 
233
- # Add summary text
234
  tn, fp, fn, tp = cm.ravel()
235
- text = f"TN: {tn} | FP: {fp}\nFN: {fn} | TP: {tp}"
236
- ax.text(1.35, 0.5, text, transform=ax.transAxes, fontsize=10,
237
- verticalalignment='center',
238
- bbox=dict(boxstyle='round', facecolor='#2d2d44', edgecolor='#e0e0e0', alpha=0.8))
239
 
240
  plt.tight_layout()
241
  return fig
242
 
243
 
244
  def plot_feature_importance(model, feature_names, model_name):
245
- """Plot top 15 most important features with dark mode colors."""
 
246
  fig, ax = plt.subplots(figsize=(10, 8))
247
 
248
- # Get feature importances based on model type
249
  if hasattr(model, 'feature_importances_'):
250
  importances = model.feature_importances_
251
  elif hasattr(model, 'coef_'):
252
  importances = np.abs(model.coef_[0])
253
  else:
254
  ax.text(0.5, 0.5, 'Feature importance not available',
255
- ha='center', va='center', fontsize=14, color='#e0e0e0')
 
256
  return fig
257
 
258
- # Create dataframe and sort
259
  importance_df = pd.DataFrame({
260
  'Feature': feature_names,
261
  'Importance': importances
262
  }).sort_values('Importance', ascending=True).tail(15)
263
 
264
- # Gradient colors
265
- colors = plt.cm.Blues(np.linspace(0.4, 0.9, len(importance_df)))
266
- ax.barh(importance_df['Feature'], importance_df['Importance'], color=colors)
 
 
 
 
 
267
 
268
- ax.set_xlabel('Importance Score', fontsize=12)
269
- ax.set_title(f'Top 15 Feature Importances: {model_name}', fontsize=14, fontweight='bold')
270
- ax.grid(True, alpha=0.3, axis='x')
271
 
272
  plt.tight_layout()
273
  return fig
274
 
275
 
276
  def plot_threshold_analysis(y_test, y_proba, model_name):
277
- """Plot threshold analysis with dark mode colors."""
 
278
  thresholds = np.arange(0.05, 0.95, 0.01)
279
- precisions = []
280
- recalls = []
281
- f1_scores = []
282
 
283
  for thresh in thresholds:
284
  y_pred_thresh = (y_proba >= thresh).astype(int)
@@ -291,46 +321,104 @@ def plot_threshold_analysis(y_test, y_proba, model_name):
291
 
292
  fig, ax = plt.subplots(figsize=(10, 6))
293
 
294
- ax.plot(thresholds, precisions, '#00d4ff', linewidth=2, label='Precision')
295
- ax.plot(thresholds, recalls, '#00ff88', linewidth=2, label='Recall')
296
- ax.plot(thresholds, f1_scores, '#ff6b6b', linewidth=2, label='F1 Score')
 
 
 
 
297
 
298
- ax.axvline(x=best_threshold, color='#ffd700', linestyle='--',
299
- label=f'Optimal Threshold = {best_threshold:.2f}')
300
- ax.axvline(x=0.5, color='#888888', linestyle=':', alpha=0.7, label='Default (0.5)')
301
 
302
- ax.set_xlabel('Classification Threshold', fontsize=12)
303
- ax.set_ylabel('Score', fontsize=12)
304
- ax.set_title(f'Threshold Analysis: {model_name}', fontsize=14, fontweight='bold')
305
- ax.legend(loc='best', facecolor='#1a1a2e', edgecolor='#e0e0e0')
306
  ax.set_xlim([0, 1])
307
  ax.set_ylim([0, 1])
308
- ax.grid(True, alpha=0.3)
309
 
310
  plt.tight_layout()
311
  return fig
312
 
313
 
314
  def plot_class_distribution(train_df, test_df):
315
- """Plot class distribution with dark mode colors."""
316
- fig, axes = plt.subplots(1, 2, figsize=(12, 5))
 
317
 
318
- colors = ['#00ff88', '#ff6b6b']
319
- explode = (0, 0.1)
320
 
321
  # Training data
322
- train_counts = train_df['fraud'].value_counts()
323
- axes[0].pie(train_counts, labels=['Legitimate', 'Fraud'], autopct='%1.1f%%',
324
- colors=colors, explode=explode, shadow=True, startangle=90,
325
- textprops={'color': '#e0e0e0', 'fontsize': 11})
326
- axes[0].set_title('Training Data Distribution', fontsize=14, fontweight='bold')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
 
328
  # Test data
329
- test_counts = test_df['fraud'].value_counts()
330
- axes[1].pie(test_counts, labels=['Legitimate', 'Fraud'], autopct='%1.1f%%',
331
- colors=colors, explode=explode, shadow=True, startangle=90,
332
- textprops={'color': '#e0e0e0', 'fontsize': 11})
333
- axes[1].set_title('Test Data Distribution', fontsize=14, fontweight='bold')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
 
335
  fig.suptitle('Class Imbalance in Fraud Detection Dataset', fontsize=16, fontweight='bold', y=1.02)
336
  plt.tight_layout()
@@ -338,40 +426,42 @@ def plot_class_distribution(train_df, test_df):
338
 
339
 
340
  def plot_model_comparison(all_metrics):
341
- """Bar chart comparing all 4 models with dark mode colors."""
 
342
  fig, ax = plt.subplots(figsize=(12, 6))
343
 
344
- models = list(all_metrics.keys())
345
  metrics = ['Accuracy', 'Precision', 'Recall', 'F1 Score', 'ROC AUC']
346
 
347
  x = np.arange(len(metrics))
348
  width = 0.2
349
 
350
- colors = ['#00d4ff', '#00ff88', '#ff6b6b', '#ffd700']
351
 
352
- for i, model in enumerate(models):
353
  values = [all_metrics[model][m] for m in metrics]
354
- bars = ax.bar(x + i*width, values, width, label=model, color=colors[i])
 
355
 
356
- # Add value labels on bars
357
  for bar, v in zip(bars, values):
358
- ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
359
- f'{v:.2f}', ha='center', va='bottom', fontsize=8, color='#e0e0e0')
360
 
361
- ax.set_ylabel('Score', fontsize=12)
362
- ax.set_title('Model Performance Comparison', fontsize=14, fontweight='bold')
363
  ax.set_xticks(x + width * 1.5)
364
- ax.set_xticklabels(metrics)
365
- ax.legend(loc='upper left', bbox_to_anchor=(1, 1), facecolor='#1a1a2e', edgecolor='#e0e0e0')
366
  ax.set_ylim([0, 1.15])
367
- ax.grid(True, alpha=0.3, axis='y')
368
 
369
  plt.tight_layout()
370
  return fig
371
 
372
 
373
  # ============================================================================
374
- # LOAD DATA AND TRAIN MODELS AT STARTUP
375
  # ============================================================================
376
 
377
  print("Loading data...")
@@ -380,7 +470,7 @@ X_train, X_test, y_train, y_test, train_df, test_df = load_and_prepare_data()
380
  print("Applying SMOTE to handle class imbalance...")
381
  X_train_balanced, y_train_balanced = apply_smote(X_train, y_train)
382
 
383
- print("Training models (this may take a moment)...")
384
  models = get_models()
385
  trained_models = {}
386
  all_metrics = {}
@@ -399,12 +489,12 @@ print("Models trained successfully!")
399
 
400
 
401
  # ============================================================================
402
- # GRADIO INTERFACE FUNCTIONS
403
  # ============================================================================
404
 
405
  def get_data_overview():
406
- """Return dataset summary."""
407
- summary = f"""
408
  ## Dataset Overview
409
 
410
  ### Training Data
@@ -419,17 +509,16 @@ def get_data_overview():
419
 
420
  ### Features
421
  - **Number of Features:** {X_train.shape[1]}
422
- - **Feature Types:** All numeric (pre-processed and one-hot encoded)
423
 
424
  ### Class Imbalance Handling
425
  - Applied **SMOTE** (Synthetic Minority Over-sampling Technique)
426
  - Training samples after SMOTE: {len(X_train_balanced):,}
427
  """
428
- return summary
429
 
430
 
431
  def update_model_display(model_name):
432
- """Update metrics display when model is selected."""
433
  metrics = all_metrics[model_name]
434
  y_pred = all_predictions[model_name]
435
  y_proba = all_probabilities[model_name]
@@ -437,7 +526,7 @@ def update_model_display(model_name):
437
  best_thresh, best_f1, _, _ = find_optimal_threshold(y_test, y_proba)
438
 
439
  metrics_text = f"""
440
- ## {model_name} Performance Metrics
441
 
442
  | Metric | Score |
443
  |--------|-------|
@@ -460,7 +549,7 @@ def update_model_display(model_name):
460
 
461
 
462
  def get_selected_plot(model_name, plot_type):
463
- """Generate selected plot for chosen model."""
464
  y_proba = all_probabilities[model_name]
465
  y_pred = all_predictions[model_name]
466
 
@@ -478,53 +567,31 @@ def get_selected_plot(model_name, plot_type):
478
 
479
 
480
  def get_comparison_results():
481
- """Generate comparison table and plot."""
482
  comparison_df = pd.DataFrame(all_metrics).T.round(4)
483
-
484
  best_models = comparison_df.idxmax()
485
 
486
- summary = "## Model Comparison Summary\n\n"
487
- summary += "| Metric | Best Model | Score |\n|--------|------------|-------|\n"
488
  for metric in comparison_df.columns:
489
  best = best_models[metric]
490
  score = comparison_df.loc[best, metric]
491
- summary += f"| {metric} | {best} | {score:.4f} |\n"
492
 
493
  return comparison_df.to_markdown(), summary, plot_model_comparison(all_metrics)
494
 
495
 
496
  def update_threshold_plot(model_name):
497
- """Update threshold analysis plot."""
498
- y_proba = all_probabilities[model_name]
499
- return plot_threshold_analysis(y_test, y_proba, model_name)
500
 
501
 
502
- # ============================================================================
503
- # GRADIO UI LAYOUT
504
- # ============================================================================
505
-
506
- # Use a dark-friendly theme
507
- with gr.Blocks(
508
- title="Auto Insurance Fraud Detection",
509
- theme=gr.themes.Base(
510
- primary_hue="blue",
511
- secondary_hue="slate",
512
- neutral_hue="slate",
513
- ).set(
514
- body_background_fill="#0f0f1a",
515
- body_background_fill_dark="#0f0f1a",
516
- block_background_fill="#1a1a2e",
517
- block_background_fill_dark="#1a1a2e",
518
- border_color_primary="#3a3a5c",
519
- border_color_primary_dark="#3a3a5c",
520
- )
521
- ) as demo:
522
 
523
  gr.Markdown("""
524
  # πŸš— Auto Insurance Claims Fraud Detection
525
 
526
- This application demonstrates machine learning models for detecting fraudulent auto insurance claims.
527
- The models are trained on historical claims data and can predict whether a claim is likely fraudulent.
528
 
529
  **Models:** XGBoost | LightGBM | Random Forest | Logistic Regression
530
  """)
@@ -533,7 +600,7 @@ with gr.Blocks(
533
  # Tab 1: Data Overview
534
  with gr.TabItem("πŸ“Š Data Overview"):
535
  gr.Markdown(get_data_overview())
536
- dist_plot = gr.Plot(value=plot_class_distribution(train_df, test_df))
537
 
538
  # Tab 2: Model Evaluation
539
  with gr.TabItem("🎯 Model Evaluation"):
@@ -562,69 +629,42 @@ with gr.Blocks(
562
  plot = get_selected_plot(model_name, plot_type)
563
  return metrics, report, plot
564
 
565
- model_selector.change(
566
- fn=update_all,
567
- inputs=[model_selector, plot_selector],
568
- outputs=[metrics_display, report_display, plot_display]
569
- )
570
- plot_selector.change(
571
- fn=update_all,
572
- inputs=[model_selector, plot_selector],
573
- outputs=[metrics_display, report_display, plot_display]
574
- )
575
-
576
- demo.load(
577
- fn=update_all,
578
- inputs=[model_selector, plot_selector],
579
- outputs=[metrics_display, report_display, plot_display]
580
- )
581
 
582
- # Tab 3: Model Comparison
583
  with gr.TabItem("πŸ“ˆ Compare Models"):
584
- gr.Markdown("## All Models Performance Comparison")
585
-
586
  comparison_table, comparison_summary, comparison_plot = get_comparison_results()
587
-
588
  gr.Markdown(comparison_summary)
589
  gr.Markdown(comparison_table)
590
  gr.Plot(value=comparison_plot)
591
 
592
- # Tab 4: Threshold Analysis
593
  with gr.TabItem("βš–οΈ Threshold Optimization"):
594
  gr.Markdown("""
595
- ## Finding the Optimal Classification Threshold
596
 
597
- In fraud detection, the default 0.5 threshold often isn't optimal.
598
- We need to balance catching frauds (recall) vs. false alarms (precision).
599
- The optimal threshold maximizes F1 score.
600
  """)
601
 
602
- thresh_model = gr.Dropdown(
603
- choices=list(models.keys()),
604
- value="XGBoost",
605
- label="Select Model for Threshold Analysis"
606
- )
607
-
608
  thresh_plot = gr.Plot()
609
 
610
- thresh_model.change(
611
- fn=update_threshold_plot,
612
- inputs=[thresh_model],
613
- outputs=[thresh_plot]
614
- )
615
 
616
- demo.load(
617
- fn=update_threshold_plot,
618
- inputs=[thresh_model],
619
- outputs=[thresh_plot]
620
- )
621
-
622
- # Optimal thresholds table
623
- thresh_summary = "### Optimal Thresholds by Model\n\n| Model | Optimal Threshold | F1 at Optimal |\n|-------|-------------------|---------------|\n"
624
  for name in models.keys():
625
- opt_thresh, opt_f1, _, _ = find_optimal_threshold(y_test, all_probabilities[name])
626
- thresh_summary += f"| {name} | {opt_thresh:.2f} | {opt_f1:.4f} |\n"
627
-
628
  gr.Markdown(thresh_summary)
629
 
630
  # Tab 5: About
@@ -633,30 +673,19 @@ with gr.Blocks(
633
  ## About This Project
634
 
635
  ### Business Context
636
- Auto insurance fraud costs the industry billions of dollars annually.
637
- This project builds ML models to flag potentially fraudulent claims.
638
 
639
- ### Technical Approach
640
- 1. **Data Preparation:** 46 features describing claims and customers
641
- 2. **Class Imbalance:** ~3% fraud rate, handled with SMOTE
642
- 3. **Model Training:** Four algorithms compared
643
- 4. **Evaluation:** Precision-Recall focus due to imbalance
644
- 5. **Threshold Optimization:** Find optimal cutoff for business needs
645
-
646
- ### Models Used
647
  - **XGBoost:** Gradient boosting, excellent for tabular data
648
- - **LightGBM:** Fast, memory efficient gradient boosting
649
- - **Random Forest:** Robust ensemble of decision trees
650
- - **Logistic Regression:** Interpretable linear baseline
651
 
652
  ### Key Metrics
653
- - **Precision:** Of flagged claims, how many are actually fraud
654
- - **Recall:** Of actual frauds, how many did we catch
655
- - **F1 Score:** Harmonic mean balancing both metrics
656
- - **ROC AUC:** Overall discrimination ability
657
  """)
658
 
659
-
660
- # Launch
661
  if __name__ == "__main__":
662
  demo.launch()
 
28
  from lightgbm import LGBMClassifier
29
  from imblearn.over_sampling import SMOTE
30
 
31
+
32
# ============================================================================
# PLOT STYLE CONFIGURATION
# White backgrounds keep every chart readable in both light and dark UI modes.
# ============================================================================

def setup_plot_style():
    """Apply a white-background matplotlib style with dark-grey text."""
    ink = '#333333'    # text / axis color
    grey = '#cccccc'   # grid / legend border color
    style = {
        'figure.facecolor': 'white',
        'axes.facecolor': 'white',
        'legend.facecolor': 'white',
        'axes.edgecolor': ink,
        'axes.labelcolor': ink,
        'text.color': ink,
        'xtick.color': ink,
        'ytick.color': ink,
        'grid.color': grey,
        'legend.edgecolor': grey,
        'grid.alpha': 0.5,
        'font.size': 11,
        'axes.titlesize': 14,
        'axes.labelsize': 12,
    }
    plt.rcParams.update(style)

setup_plot_style()

# Color palette -- vibrant hues that stay legible on a white background.
COLORS = {
    'primary': '#2563eb',  # Blue
    'success': '#16a34a',  # Green
    'danger': '#dc2626',   # Red
    'warning': '#f59e0b',  # Amber
    'purple': '#9333ea',   # Purple
    'cyan': '#0891b2',     # Cyan
}
67
+
68
 
69
  # ============================================================================
70
  # DATA LOADING AND PREPROCESSING
71
  # ============================================================================
72
 
73
  def load_and_prepare_data():
74
+ """Load the train and test datasets."""
 
 
 
75
  train_df = pd.read_csv('train.csv')
76
  test_df = pd.read_csv('test.csv')
77
 
 
78
  X_train = train_df.drop('fraud', axis=1)
79
  y_train = train_df['fraud']
80
  X_test = test_df.drop('fraud', axis=1)
 
84
 
85
 
86
def apply_smote(X_train, y_train):
    """Oversample the minority (fraud) class with SMOTE.

    Returns the resampled ``(features, labels)`` pair produced by
    ``fit_resample``.
    """
    # Fixed seed keeps the synthetic samples reproducible across runs.
    sampler = SMOTE(random_state=42)
    return sampler.fit_resample(X_train, y_train)
 
95
  # ============================================================================
96
 
97
  def get_models():
98
+ """Define the 4 models for comparison."""
 
 
 
99
  models = {
100
  'XGBoost': XGBClassifier(
101
  n_estimators=100,
 
135
  # ============================================================================
136
 
137
def train_model(model, X_train, y_train):
    """Fit *model* on the training data and return the fitted estimator.

    The estimator is mutated in place by ``fit``; it is returned as well so
    callers can use the function in a single expression.
    """
    model.fit(X_train, y_train)
    return model
141
 
142
 
143
def evaluate_model(model, X_test, y_test):
    """Return ``(hard predictions, positive-class probabilities)``.

    Probability column 1 of ``predict_proba`` is taken as the fraud score
    of a binary classifier. ``y_test`` is unused here but kept so all
    evaluation helpers share a uniform signature.
    """
    predictions = model.predict(X_test)
    fraud_scores = model.predict_proba(X_test)[:, 1]
    return predictions, fraud_scores
148
 
149
 
150
  def get_metrics(y_test, y_pred, y_proba):
151
+ """Calculate evaluation metrics."""
152
  metrics = {
153
  'Accuracy': accuracy_score(y_test, y_pred),
154
  'Precision': precision_score(y_test, y_pred, zero_division=0),
 
160
 
161
 
162
  def find_optimal_threshold(y_test, y_proba):
163
+ """Find optimal threshold using F1 score."""
164
  thresholds = np.arange(0.1, 0.9, 0.01)
165
  f1_scores = []
166
 
 
177
 
178
 
179
  # ============================================================================
180
+ # VISUALIZATION FUNCTIONS
181
  # ============================================================================
182
 
183
def plot_precision_recall_curve(y_test, y_proba, model_name):
    """Draw the precision-recall curve for one model.

    Shades the area under the curve and marks the random baseline
    (the dataset's positive rate) for reference. Returns the figure.
    """
    setup_plot_style()
    precision, recall, _ = precision_recall_curve(y_test, y_proba)
    area = auc(recall, precision)
    # A no-skill classifier's precision equals the overall fraud rate.
    base_rate = y_test.mean()

    fig, ax = plt.subplots(figsize=(9, 6))

    ax.plot(recall, precision, color=COLORS['primary'], linewidth=2.5,
            label=f'{model_name} (AUC = {area:.3f})')
    ax.fill_between(recall, precision, alpha=0.2, color=COLORS['primary'])
    ax.axhline(y=base_rate, color=COLORS['danger'], linestyle='--', linewidth=2,
               label=f'Random Baseline = {base_rate:.3f}')

    ax.set_xlabel('Recall (Fraud Detection Rate)', fontweight='bold')
    ax.set_ylabel('Precision (True Fraud Rate)', fontweight='bold')
    ax.set_title(f'Precision-Recall Curve: {model_name}',
                 fontsize=15, fontweight='bold', pad=15)
    ax.legend(loc='upper right', fontsize=11, framealpha=0.95)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.grid(True, alpha=0.4)

    plt.tight_layout()
    return fig
210
 
211
 
212
def plot_roc_curve(y_test, y_proba, model_name):
    """Draw the ROC curve for one model.

    Shades the area under the curve and overlays the diagonal that a
    random classifier would trace. Returns the figure.
    """
    setup_plot_style()
    fpr, tpr, _ = roc_curve(y_test, y_proba)
    area = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(9, 6))

    ax.plot(fpr, tpr, color=COLORS['primary'], linewidth=2.5,
            label=f'{model_name} (AUC = {area:.3f})')
    ax.fill_between(fpr, tpr, alpha=0.2, color=COLORS['primary'])
    # Chance-level reference: the 45-degree diagonal.
    ax.plot([0, 1], [0, 1], color=COLORS['danger'], linestyle='--', linewidth=2,
            label='Random Classifier')

    ax.set_xlabel('False Positive Rate', fontweight='bold')
    ax.set_ylabel('True Positive Rate (Recall)', fontweight='bold')
    ax.set_title(f'ROC Curve: {model_name}',
                 fontsize=15, fontweight='bold', pad=15)
    ax.legend(loc='lower right', fontsize=11, framealpha=0.95)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.grid(True, alpha=0.4)

    plt.tight_layout()
    return fig
236
 
237
 
238
def plot_confusion_matrix(y_test, y_pred, model_name):
    """Render the confusion matrix as a heatmap plus a TN/FP/FN/TP aside.

    Cell annotations show raw counts; a monospace box to the right of the
    axes spells out each quadrant. Returns the figure.
    """
    setup_plot_style()
    matrix = confusion_matrix(y_test, y_pred)
    class_labels = ['Legitimate', 'Fraud']

    fig, ax = plt.subplots(figsize=(9, 7))

    # 'Blues' keeps the bold count annotations readable at any cell value.
    sns.heatmap(
        matrix, annot=True, fmt='d', cmap='Blues', ax=ax,
        xticklabels=class_labels,
        yticklabels=class_labels,
        annot_kws={'size': 18, 'fontweight': 'bold'},
        linewidths=2, linecolor='white',
        cbar_kws={'label': 'Count', 'shrink': 0.8},
    )

    ax.set_xlabel('Predicted Label', fontweight='bold', fontsize=12)
    ax.set_ylabel('True Label', fontweight='bold', fontsize=12)
    ax.set_title(f'Confusion Matrix: {model_name}',
                 fontsize=15, fontweight='bold', pad=15)

    # Spell out the four quadrants beside the plot.
    tn, fp, fn, tp = matrix.ravel()
    cell_text = f"True Neg: {tn:,}\nFalse Pos: {fp:,}\nFalse Neg: {fn:,}\nTrue Pos: {tp:,}"
    ax.text(1.25, 0.5, cell_text, transform=ax.transAxes, fontsize=11,
            verticalalignment='center', fontfamily='monospace',
            bbox=dict(boxstyle='round,pad=0.5', facecolor='#f0f0f0', edgecolor='#cccccc'))

    plt.tight_layout()
    return fig
266
 
267
 
268
def plot_feature_importance(model, feature_names, model_name):
    """Horizontal bar chart of the 15 strongest features for *model*.

    Tree ensembles expose ``feature_importances_``; linear models expose
    ``coef_`` (absolute values are used). Models with neither attribute
    get a placeholder figure instead. Returns the figure.
    """
    setup_plot_style()
    fig, ax = plt.subplots(figsize=(10, 8))

    if hasattr(model, 'feature_importances_'):
        scores = model.feature_importances_
    elif hasattr(model, 'coef_'):
        scores = np.abs(model.coef_[0])
    else:
        # No importance attribute on this model type; show a notice instead.
        ax.text(0.5, 0.5, 'Feature importance not available',
                ha='center', va='center', fontsize=14)
        ax.set_facecolor('white')
        return fig

    # Keep only the 15 largest scores; ascending order so barh stacks upward.
    ranked = pd.DataFrame({
        'Feature': feature_names,
        'Importance': scores
    }).sort_values('Importance', ascending=True).tail(15)

    # Gradient blue bars, lightest at the bottom.
    shades = plt.cm.Blues(np.linspace(0.4, 0.85, len(ranked)))
    bars = ax.barh(ranked['Feature'], ranked['Importance'],
                   color=shades, edgecolor='#333333', linewidth=0.5)

    # Annotate each bar with its numeric score.
    for bar, score in zip(bars, ranked['Importance']):
        ax.text(bar.get_width() + 0.001, bar.get_y() + bar.get_height() / 2,
                f'{score:.3f}', va='center', fontsize=9)

    ax.set_xlabel('Importance Score', fontweight='bold')
    ax.set_title(f'Top 15 Feature Importances: {model_name}',
                 fontsize=15, fontweight='bold', pad=15)
    ax.grid(True, alpha=0.4, axis='x')

    plt.tight_layout()
    return fig
305
 
306
 
307
  def plot_threshold_analysis(y_test, y_proba, model_name):
308
+ """Plot threshold analysis."""
309
+ setup_plot_style()
310
  thresholds = np.arange(0.05, 0.95, 0.01)
311
+ precisions, recalls, f1_scores = [], [], []
 
 
312
 
313
  for thresh in thresholds:
314
  y_pred_thresh = (y_proba >= thresh).astype(int)
 
321
 
322
  fig, ax = plt.subplots(figsize=(10, 6))
323
 
324
+ ax.plot(thresholds, precisions, color=COLORS['primary'], linewidth=2.5, label='Precision')
325
+ ax.plot(thresholds, recalls, color=COLORS['success'], linewidth=2.5, label='Recall')
326
+ ax.plot(thresholds, f1_scores, color=COLORS['danger'], linewidth=2.5, label='F1 Score')
327
+
328
+ ax.axvline(x=best_threshold, color=COLORS['warning'], linestyle='--', linewidth=2,
329
+ label=f'Optimal = {best_threshold:.2f}')
330
+ ax.axvline(x=0.5, color='#888888', linestyle=':', linewidth=1.5, label='Default (0.5)')
331
 
332
+ # Mark optimal point
333
+ ax.scatter([best_threshold], [f1_scores[best_idx]], color=COLORS['warning'], s=100, zorder=5)
 
334
 
335
+ ax.set_xlabel('Classification Threshold', fontweight='bold')
336
+ ax.set_ylabel('Score', fontweight='bold')
337
+ ax.set_title(f'Threshold Analysis: {model_name}', fontsize=15, fontweight='bold', pad=15)
338
+ ax.legend(loc='center right', fontsize=11, framealpha=0.95)
339
  ax.set_xlim([0, 1])
340
  ax.set_ylim([0, 1])
341
+ ax.grid(True, alpha=0.4)
342
 
343
  plt.tight_layout()
344
  return fig
345
 
346
 
347
  def plot_class_distribution(train_df, test_df):
348
+ """Plot class distribution with clear, readable labels."""
349
+ setup_plot_style()
350
+ fig, axes = plt.subplots(1, 2, figsize=(14, 6))
351
 
352
+ colors = [COLORS['success'], COLORS['danger']]
353
+ explode = (0, 0.08)
354
 
355
  # Training data
356
+ train_fraud = train_df['fraud'].sum()
357
+ train_legit = len(train_df) - train_fraud
358
+ train_sizes = [train_legit, train_fraud]
359
+ train_pct = [train_legit/len(train_df)*100, train_fraud/len(train_df)*100]
360
+
361
+ wedges1, texts1, autotexts1 = axes[0].pie(
362
+ train_sizes,
363
+ explode=explode,
364
+ colors=colors,
365
+ autopct='%1.1f%%',
366
+ startangle=90,
367
+ shadow=False,
368
+ wedgeprops={'edgecolor': 'white', 'linewidth': 2}
369
+ )
370
+
371
+ # Style the percentage text
372
+ for autotext in autotexts1:
373
+ autotext.set_color('white')
374
+ autotext.set_fontsize(14)
375
+ autotext.set_fontweight('bold')
376
+
377
+ axes[0].set_title('Training Data Distribution', fontsize=14, fontweight='bold', pad=10)
378
+
379
+ # Add legend with counts
380
+ axes[0].legend(
381
+ wedges1,
382
+ [f'Legitimate: {train_legit:,} ({train_pct[0]:.1f}%)',
383
+ f'Fraud: {train_fraud:,} ({train_pct[1]:.1f}%)'],
384
+ loc='lower center',
385
+ bbox_to_anchor=(0.5, -0.15),
386
+ fontsize=11,
387
+ framealpha=0.95
388
+ )
389
 
390
  # Test data
391
+ test_fraud = test_df['fraud'].sum()
392
+ test_legit = len(test_df) - test_fraud
393
+ test_sizes = [test_legit, test_fraud]
394
+ test_pct = [test_legit/len(test_df)*100, test_fraud/len(test_df)*100]
395
+
396
+ wedges2, texts2, autotexts2 = axes[1].pie(
397
+ test_sizes,
398
+ explode=explode,
399
+ colors=colors,
400
+ autopct='%1.1f%%',
401
+ startangle=90,
402
+ shadow=False,
403
+ wedgeprops={'edgecolor': 'white', 'linewidth': 2}
404
+ )
405
+
406
+ for autotext in autotexts2:
407
+ autotext.set_color('white')
408
+ autotext.set_fontsize(14)
409
+ autotext.set_fontweight('bold')
410
+
411
+ axes[1].set_title('Test Data Distribution', fontsize=14, fontweight='bold', pad=10)
412
+
413
+ axes[1].legend(
414
+ wedges2,
415
+ [f'Legitimate: {test_legit:,} ({test_pct[0]:.1f}%)',
416
+ f'Fraud: {test_fraud:,} ({test_pct[1]:.1f}%)'],
417
+ loc='lower center',
418
+ bbox_to_anchor=(0.5, -0.15),
419
+ fontsize=11,
420
+ framealpha=0.95
421
+ )
422
 
423
  fig.suptitle('Class Imbalance in Fraud Detection Dataset', fontsize=16, fontweight='bold', y=1.02)
424
  plt.tight_layout()
 
426
 
427
 
428
def plot_model_comparison(all_metrics):
    """Grouped bar chart comparing every trained model across key metrics.

    Parameters
    ----------
    all_metrics : dict[str, dict[str, float]]
        Maps model name -> metric name -> score. Each inner dict must contain
        'Accuracy', 'Precision', 'Recall', 'F1 Score', and 'ROC AUC'.

    Returns
    -------
    matplotlib.figure.Figure
        The comparison chart (caller is responsible for display/close).
    """
    setup_plot_style()
    fig, ax = plt.subplots(figsize=(12, 6))

    models_list = list(all_metrics.keys())
    metrics = ['Accuracy', 'Precision', 'Recall', 'F1 Score', 'ROC AUC']

    x = np.arange(len(metrics))
    width = 0.2

    colors = [COLORS['primary'], COLORS['success'], COLORS['danger'], COLORS['purple']]

    for i, model in enumerate(models_list):
        values = [all_metrics[model][m] for m in metrics]
        # Cycle the palette so a 5th+ model reuses colors instead of raising IndexError.
        bars = ax.bar(x + i * width, values, width, label=model,
                      color=colors[i % len(colors)],
                      edgecolor='white', linewidth=0.5)

        # Print each score above its bar for quick reading.
        for bar, v in zip(bars, values):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02,
                    f'{v:.2f}', ha='center', va='bottom', fontsize=9, fontweight='bold')

    ax.set_ylabel('Score', fontweight='bold')
    ax.set_title('Model Performance Comparison', fontsize=15, fontweight='bold', pad=15)
    # Center tick labels under each metric's group of bars for any model count
    # (equals the previous hard-coded x + width * 1.5 when there are 4 models).
    ax.set_xticks(x + width * (len(models_list) - 1) / 2)
    ax.set_xticklabels(metrics, fontweight='bold')
    ax.legend(loc='upper right', fontsize=10, framealpha=0.95)
    ax.set_ylim([0, 1.15])
    ax.grid(True, alpha=0.4, axis='y')

    plt.tight_layout()
    return fig
461
 
462
 
463
  # ============================================================================
464
+ # LOAD DATA AND TRAIN MODELS
465
  # ============================================================================
466
 
467
  print("Loading data...")
 
470
  print("Applying SMOTE to handle class imbalance...")
471
  X_train_balanced, y_train_balanced = apply_smote(X_train, y_train)
472
 
473
+ print("Training models...")
474
  models = get_models()
475
  trained_models = {}
476
  all_metrics = {}
 
489
 
490
 
491
  # ============================================================================
492
+ # GRADIO INTERFACE
493
  # ============================================================================
494
 
495
  def get_data_overview():
496
+ """Dataset summary."""
497
+ return f"""
498
  ## Dataset Overview
499
 
500
  ### Training Data
 
509
 
510
  ### Features
511
  - **Number of Features:** {X_train.shape[1]}
512
+ - **Feature Types:** All numeric (pre-processed)
513
 
514
  ### Class Imbalance Handling
515
  - Applied **SMOTE** (Synthetic Minority Over-sampling Technique)
516
  - Training samples after SMOTE: {len(X_train_balanced):,}
517
  """
 
518
 
519
 
520
  def update_model_display(model_name):
521
+ """Update metrics when model is selected."""
522
  metrics = all_metrics[model_name]
523
  y_pred = all_predictions[model_name]
524
  y_proba = all_probabilities[model_name]
 
526
  best_thresh, best_f1, _, _ = find_optimal_threshold(y_test, y_proba)
527
 
528
  metrics_text = f"""
529
+ ## {model_name} Performance
530
 
531
  | Metric | Score |
532
  |--------|-------|
 
549
 
550
 
551
  def get_selected_plot(model_name, plot_type):
552
+ """Generate selected plot."""
553
  y_proba = all_probabilities[model_name]
554
  y_pred = all_predictions[model_name]
555
 
 
567
 
568
 
569
def get_comparison_results():
    """Assemble the cross-model comparison artifacts.

    Returns a tuple of (markdown metrics table, best-model-per-metric markdown
    summary, comparison bar-chart figure), all derived from the module-level
    `all_metrics` dictionary.
    """
    scores = pd.DataFrame(all_metrics).T.round(4)

    # Winning model name for each metric column.
    winners = scores.idxmax()

    parts = ["## Best Model by Metric\n\n| Metric | Best Model | Score |\n|--------|------------|-------|\n"]
    for metric in scores.columns:
        champion = winners[metric]
        parts.append(f"| {metric} | **{champion}** | {scores.loc[champion, metric]:.4f} |\n")
    summary = "".join(parts)

    return scores.to_markdown(), summary, plot_model_comparison(all_metrics)
581
 
582
 
583
def update_threshold_plot(model_name):
    """Regenerate the threshold-analysis figure for the selected model."""
    probabilities = all_probabilities[model_name]
    return plot_threshold_analysis(y_test, probabilities, model_name)
 
586
 
587
 
588
+ # Build UI
589
+ with gr.Blocks(title="Auto Insurance Fraud Detection", theme=gr.themes.Soft()) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
590
 
591
  gr.Markdown("""
592
  # πŸš— Auto Insurance Claims Fraud Detection
593
 
594
+ Machine learning models for detecting fraudulent auto insurance claims.
 
595
 
596
  **Models:** XGBoost | LightGBM | Random Forest | Logistic Regression
597
  """)
 
600
  # Tab 1: Data Overview
601
  with gr.TabItem("πŸ“Š Data Overview"):
602
  gr.Markdown(get_data_overview())
603
+ gr.Plot(value=plot_class_distribution(train_df, test_df))
604
 
605
  # Tab 2: Model Evaluation
606
  with gr.TabItem("🎯 Model Evaluation"):
 
629
  plot = get_selected_plot(model_name, plot_type)
630
  return metrics, report, plot
631
 
632
+ model_selector.change(fn=update_all, inputs=[model_selector, plot_selector],
633
+ outputs=[metrics_display, report_display, plot_display])
634
+ plot_selector.change(fn=update_all, inputs=[model_selector, plot_selector],
635
+ outputs=[metrics_display, report_display, plot_display])
636
+ demo.load(fn=update_all, inputs=[model_selector, plot_selector],
637
+ outputs=[metrics_display, report_display, plot_display])
 
 
 
 
 
 
 
 
 
 
638
 
639
+ # Tab 3: Compare Models
640
  with gr.TabItem("πŸ“ˆ Compare Models"):
 
 
641
  comparison_table, comparison_summary, comparison_plot = get_comparison_results()
642
+ gr.Markdown("## All Models Performance Comparison")
643
  gr.Markdown(comparison_summary)
644
  gr.Markdown(comparison_table)
645
  gr.Plot(value=comparison_plot)
646
 
647
+ # Tab 4: Threshold
648
  with gr.TabItem("βš–οΈ Threshold Optimization"):
649
  gr.Markdown("""
650
+ ## Finding the Optimal Threshold
651
 
652
+ The default 0.5 threshold often isn't optimal for imbalanced data.
653
+ We balance **Recall** (catching frauds) vs **Precision** (avoiding false alarms).
 
654
  """)
655
 
656
+ thresh_model = gr.Dropdown(choices=list(models.keys()), value="XGBoost",
657
+ label="Select Model")
 
 
 
 
658
  thresh_plot = gr.Plot()
659
 
660
+ thresh_model.change(fn=update_threshold_plot, inputs=[thresh_model], outputs=[thresh_plot])
661
+ demo.load(fn=update_threshold_plot, inputs=[thresh_model], outputs=[thresh_plot])
 
 
 
662
 
663
+ # Thresholds table
664
+ thresh_summary = "### Optimal Thresholds\n\n| Model | Threshold | F1 Score |\n|-------|-----------|----------|\n"
 
 
 
 
 
 
665
  for name in models.keys():
666
+ opt_t, opt_f1, _, _ = find_optimal_threshold(y_test, all_probabilities[name])
667
+ thresh_summary += f"| {name} | {opt_t:.2f} | {opt_f1:.4f} |\n"
 
668
  gr.Markdown(thresh_summary)
669
 
670
  # Tab 5: About
 
673
  ## About This Project
674
 
675
  ### Business Context
676
+ Auto insurance fraud costs billions annually. This tool flags potentially fraudulent claims.
 
677
 
678
+ ### Models
 
 
 
 
 
 
 
679
  - **XGBoost:** Gradient boosting, excellent for tabular data
680
+ - **LightGBM:** Fast, memory-efficient gradient boosting
681
+ - **Random Forest:** Robust ensemble method
682
+ - **Logistic Regression:** Interpretable baseline
683
 
684
  ### Key Metrics
685
+ - **Precision:** Of flagged claims, how many are actually fraud?
686
+ - **Recall:** Of actual frauds, how many did we catch?
687
+ - **F1 Score:** Balance of precision and recall
 
688
  """)
689
 
 
 
690
  if __name__ == "__main__":
691
  demo.launch()