# Sahithi27's picture
# Update app.py
# aa855a4 verified
import onnxruntime as ort
import gradio as gr
import numpy as np
# ------------------------------------------------
# Load ONNX model (CPU only)
# ------------------------------------------------
# One shared inference session for the whole app, created at import time and
# restricted to the CPU execution provider.
sess = ort.InferenceSession(
    "collusion_xgb_model.onnx",
    providers=["CPUExecutionProvider"],
)
# ------------------------------------------------
# Detect probability output tensor safely
# ------------------------------------------------
# Pick the first model output that is a 2-D tensor with exactly two columns;
# that output is used as the probability tensor. Startup is aborted if the
# model exposes no such output.
output_info = sess.get_outputs()
PROB_INDEX = next(
    (
        position
        for position, node in enumerate(output_info)
        if len(node.shape) == 2 and node.shape[1] == 2
    ),
    None,
)
if PROB_INDEX is None:
    raise RuntimeError("❌ Probability output not found in ONNX model")

print(f"✅ Using output index {PROB_INDEX} as probability tensor")
# ------------------------------------------------
# Risk bucket logic (PRODUCT FRIENDLY)
# ------------------------------------------------
def risk_bucket(p, *, high_threshold=0.05, medium_threshold=0.01):
    """Map a fraud probability to a coarse risk label.

    Args:
        p: Fraud probability to classify.
        high_threshold: Smallest probability labelled "HIGH RISK"
            (inclusive). Defaults to 0.05.
        medium_threshold: Smallest probability labelled "MEDIUM RISK"
            (inclusive). Defaults to 0.01.

    Returns:
        "HIGH RISK" if ``p >= high_threshold``, "MEDIUM RISK" if
        ``p >= medium_threshold``, otherwise "LOW RISK".

    Raises:
        ValueError: If ``medium_threshold`` exceeds ``high_threshold``, which
            would make the "MEDIUM RISK" bucket unreachable.
    """
    if medium_threshold > high_threshold:
        raise ValueError(
            "medium_threshold must not be greater than high_threshold"
        )

    # Guard clauses, checked from the most severe bucket downwards.
    if p >= high_threshold:
        return "HIGH RISK"
    if p >= medium_threshold:
        return "MEDIUM RISK"
    return "LOW RISK"
# ------------------------------------------------
# Prediction function
# ------------------------------------------------
def predict(amount, user_txn, driver_txn, pair_count, hour, day):
    """Score one transaction with the ONNX model and derive its risk outputs.

    Feature order MUST match training:
    [amount, user_txn_count, driver_txn_count, user_driver_pair_count,
    hour, day_of_week]

    Args:
        amount: Transaction amount.
        user_txn: User transaction count.
        driver_txn: Driver transaction count.
        pair_count: User-driver pair count.
        hour: Hour of day.
        day: Day of week.

    Returns:
        A tuple ``(fraud_prob, risk_score, risk)``: the class-1 probability
        as a float, that probability scaled by 1000 and truncated to an int,
        and the label returned by ``risk_bucket``.

    Raises:
        gr.Error: If any input is ``None`` (a Gradio ``Number`` field left
            blank is delivered as ``None``), so the UI shows which fields are
            missing instead of an opaque ``float(None)`` TypeError.
    """
    # Pairs are listed in the exact order the model expects its features.
    named_features = (
        ("Amount", amount),
        ("User Txn Count", user_txn),
        ("Driver Txn Count", driver_txn),
        ("User-Driver Pair Count", pair_count),
        ("Hour", hour),
        ("Day of Week", day),
    )

    missing = [label for label, value in named_features if value is None]
    if missing:
        raise gr.Error("Please provide a value for: " + ", ".join(missing))

    # Single-row batch, float32 as in the original implementation.
    X = np.array(
        [[float(value) for _, value in named_features]],
        dtype=np.float32,
    )

    outputs = sess.run(None, {sess.get_inputs()[0].name: X})
    probs = outputs[PROB_INDEX]
    fraud_prob = float(probs[0][1])  # class-1 probability

    # Convert to interpretable risk
    risk = risk_bucket(fraud_prob)
    # Scaled score for visibility; int() truncates, so the score never rounds
    # up past a bucket boundary the probability has not actually reached.
    risk_score = int(fraud_prob * 1000)
    return fraud_prob, risk_score, risk
# ------------------------------------------------
# Gradio UI
# ------------------------------------------------
# Build the interface first, then launch it: six numeric inputs in the same
# order as predict()'s parameters, and three outputs matching its return tuple.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Number(label=caption)
        for caption in (
            "Amount",
            "User Txn Count",
            "Driver Txn Count",
            "User–Driver Pair Count",
            "Hour (0–23)",
            "Day of Week (0=Mon, 6=Sun)",
        )
    ],
    outputs=[
        gr.Number(label="Fraud Probability"),
        gr.Number(label="Risk Score (0–1000)"),
        gr.Text(label="Risk Level"),
    ],
    title="Collusion Fraud Detection (ONNX)",
    description=(
        "This model detects potential user–driver collusion. "
        "Fraud probability is a ranking signal; risk score and bucket are used for decisions."
    ),
)
demo.launch()