File size: 3,743 Bytes
7537ea4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
import pickle
import numpy as np
import json

# Load the model and its reports once at startup. These files are presumably
# written by the (re)training pipeline — confirm against that pipeline.
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# model.pkl must only ever come from this project's own trusted pipeline;
# never point this at a pickle supplied by users or an untrusted source.
with open('model.pkl', 'rb') as f:
    model = pickle.load(f)

# JSON is UTF-8 by specification; pin the encoding so reading does not depend
# on the platform's locale default (e.g. cp1252 on Windows).
with open('metrics.json', encoding='utf-8') as f:
    metrics = json.load(f)

with open('retrain_report.json', encoding='utf-8') as f:
    retrain_report = json.load(f)


def predict(amount, hour, day_of_week, distance_from_home,
            distance_from_last_txn, ratio_to_median, num_txn_last_24h,
            is_foreign, merchant_risk_score, card_age_months):
    """Score a single transaction with the loaded model.

    The arguments arrive from the UI components in this exact order and are
    packed into one feature row in the same order (assumed to match the
    model's training columns — confirm against the training script).
    ``is_foreign`` is coerced to ``0``/``1`` before scoring.

    Returns:
        A dict for ``gr.Label`` mapping ``'Legitimate'`` to column 0 and
        ``'Fraud'`` to column 1 of ``model.predict_proba`` for this row.
    """
    row = [
        amount,
        hour,
        day_of_week,
        distance_from_home,
        distance_from_last_txn,
        ratio_to_median,
        num_txn_last_24h,
        int(is_foreign),
        merchant_risk_score,
        card_age_months,
    ]
    # predict_proba expects a 2-D batch; take the single result row.
    probabilities = model.predict_proba(np.array([row]))[0]
    return {
        'Legitimate': float(probabilities[0]),
        'Fraud': float(probabilities[1]),
    }


def get_model_info(metrics_data=None, report=None):
    """Render the current model's metrics and last retrain decision as text.

    Args:
        metrics_data: Parsed metrics document. Defaults to the module-level
            ``metrics`` loaded from ``metrics.json``. Keys read (all
            optional): ``'timestamp'`` and ``'metrics'``, a dict with
            ``'f1'``, ``'roc_auc'``, ``'precision'`` and ``'recall'``.
        report: Parsed retrain report. Defaults to the module-level
            ``retrain_report`` loaded from ``retrain_report.json``. Keys read
            (all optional): ``'decision'``, ``'reason'`` and ``'drift'``, a
            dict with ``'dataset_drift'`` and ``'drifted_features'``.

    Returns:
        A newline-separated summary for a ``gr.Textbox``. Missing or null
        fields are shown as ``'N/A'`` (missing drifted features as ``[]``);
        numeric metrics are rounded to four decimal places.
    """
    # Resolve defaults lazily: the Gradio button calls this with no
    # arguments, and the globals are only read when nothing was passed.
    if metrics_data is None:
        metrics_data = metrics
    if report is None:
        report = retrain_report

    def as_text(value):
        # Absent keys and JSON null both render as N/A rather than 'None'
        # (and never crash string concatenation on non-str values).
        return 'N/A' if value is None else str(value)

    def as_metric(value):
        # round() raises on null/str; only round genuine numbers.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return str(round(value, 4))
        return as_text(value)

    # `or {}` also covers keys that are present but null in the JSON.
    scores = metrics_data.get('metrics') or {}
    drift = report.get('drift') or {}

    lines = [
        'Last updated: ' + as_text(metrics_data.get('timestamp')),
        'F1 Score: ' + as_metric(scores.get('f1')),
        'ROC AUC: ' + as_metric(scores.get('roc_auc')),
        'Precision: ' + as_metric(scores.get('precision')),
        'Recall: ' + as_metric(scores.get('recall')),
        '',
        'Retrain Decision: ' + as_text(report.get('decision')),
        'Reason: ' + as_text(report.get('reason')),
        'Drift Detected: ' + as_text(drift.get('dataset_drift')),
        'Drifted Features: ' + str(drift.get('drifted_features') or []),
    ]
    return '\n'.join(lines)


# Two-tab UI: "Predict" scores a single transaction, "Model Info" shows the
# current metrics and last retrain decision.
with gr.Blocks() as demo:
    gr.Markdown('# Fraud Detection System (Auto-Retrained)')
    gr.Markdown('This model is automatically retrained when data drift is detected.')

    with gr.Tab('Predict'):
        with gr.Row():
            with gr.Column():
                amount = gr.Number(value=50.0, label='Transaction Amount ($)')
                hour = gr.Slider(0, 23, value=14, step=1, label='Hour of Day')
                day = gr.Slider(0, 6, value=3, step=1, label='Day of Week (0=Mon)')
                dist_home = gr.Number(value=10.0, label='Distance from Home (km)')
                dist_last = gr.Number(value=5.0, label='Distance from Last Txn (km)')
            with gr.Column():
                ratio = gr.Number(value=1.0, label='Ratio to Median Spending')
                n_txn = gr.Slider(0, 20, value=3, step=1, label='Txns in Last 24h')
                foreign = gr.Checkbox(value=False, label='Foreign Transaction')
                merchant = gr.Slider(0, 1, value=0.2, step=0.05, label='Merchant Risk Score')
                card_age = gr.Slider(1, 120, value=36, step=1, label='Card Age (months)')

        # Order must match predict()'s positional parameters. Defined once and
        # shared by the button handler and the examples table so the two
        # cannot silently drift apart.
        txn_inputs = [amount, hour, day, dist_home, dist_last,
                      ratio, n_txn, foreign, merchant, card_age]

        predict_btn = gr.Button('Analyze Transaction', variant='primary')
        output = gr.Label(num_top_classes=2, label='Result')

        predict_btn.click(
            fn=predict,
            inputs=txn_inputs,
            outputs=output
        )

        # Each example row lists values in the same order as txn_inputs.
        gr.Examples(
            examples=[
                [25.0, 14, 2, 5.0, 3.0, 0.8, 2, False, 0.1, 48],
                [500.0, 3, 5, 100.0, 80.0, 5.0, 10, True, 0.7, 6],
                [1200.0, 2, 6, 200.0, 150.0, 8.0, 15, True, 0.9, 3],
            ],
            inputs=txn_inputs,
        )

    with gr.Tab('Model Info'):
        info_btn = gr.Button('Show Model Info')
        info_output = gr.Textbox(label='Model Details', lines=12)
        info_btn.click(fn=get_model_info, outputs=info_output)

# Only start the server when run as a script, so importing this module
# (e.g. from tests or other tooling) does not launch the app as a side effect.
if __name__ == '__main__':
    demo.launch()