"""Gradio demo for the user–driver collusion fraud model (ONNX, CPU only).

Loads an ONNX export of the collusion classifier, locates its two-column
class-probability output, and serves a small web form that scores a single
transaction and maps the class-1 probability to a risk score and bucket.
"""

import onnxruntime as ort
import gradio as gr
import numpy as np

# ------------------------------------------------
# Configuration
# ------------------------------------------------
MODEL_PATH = "collusion_xgb_model.onnx"

# Feature order MUST match training; also used to name missing inputs.
FEATURE_NAMES = (
    "amount",
    "user_txn_count",
    "driver_txn_count",
    "user_driver_pair_count",
    "hour",
    "day_of_week",
)

# Probability cut-offs for the product-facing risk buckets.
HIGH_RISK_THRESHOLD = 0.05
MEDIUM_RISK_THRESHOLD = 0.01

# ------------------------------------------------
# Load ONNX model (CPU only)
# ------------------------------------------------
sess = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])

# Resolved once instead of on every prediction: the feature matrix is fed
# through the model's first declared input.
INPUT_NAME = sess.get_inputs()[0].name


# ------------------------------------------------
# Detect probability output tensor safely
# ------------------------------------------------
def _find_probability_output(outputs):
    """Return the index of the first output whose shape is (batch, 2), or None.

    Outputs that report no usable shape (e.g. non-tensor outputs) are treated
    as rank 0 via ``out.shape or ()`` so they are skipped rather than raising
    ``TypeError`` on ``len(None)``.
    """
    for index, out in enumerate(outputs):
        shape = out.shape or ()
        if len(shape) == 2 and shape[1] == 2:
            return index
    return None


output_info = sess.get_outputs()
PROB_INDEX = _find_probability_output(output_info)

if PROB_INDEX is None:
    raise RuntimeError("❌ Probability output not found in ONNX model")

print(f"✅ Using output index {PROB_INDEX} as probability tensor")


# ------------------------------------------------
# Risk bucket logic (PRODUCT FRIENDLY)
# ------------------------------------------------
def risk_bucket(p, *, high=HIGH_RISK_THRESHOLD, medium=MEDIUM_RISK_THRESHOLD):
    """Map a class-1 probability to a product-facing risk label.

    Args:
        p: Fraud probability.
        high: Probabilities ``>= high`` are "HIGH RISK" (default 0.05).
        medium: Probabilities ``>= medium`` but below ``high`` are
            "MEDIUM RISK" (default 0.01).

    Returns:
        "HIGH RISK", "MEDIUM RISK" or "LOW RISK".
    """
    if p >= high:
        return "HIGH RISK"
    if p >= medium:
        return "MEDIUM RISK"
    return "LOW RISK"


# ------------------------------------------------
# Prediction function
# ------------------------------------------------
def predict(amount, user_txn, driver_txn, pair_count, hour, day):
    """Score a single transaction with the collusion model.

    Feature order MUST match training:
    [amount, user_txn_count, driver_txn_count,
     user_driver_pair_count, hour, day_of_week]

    Returns:
        Tuple ``(fraud_prob, risk_score, risk)``: the class-1 probability as
        a float, that probability times 1000 truncated to an int, and the
        label produced by ``risk_bucket``.

    Raises:
        gr.Error: if any field is empty, or if ``hour``/``day`` fall outside
            the ranges advertised in the UI (0–23 and 0–6).
    """
    values = (amount, user_txn, driver_txn, pair_count, hour, day)

    # gr.Number yields None for an empty box; float(None) would otherwise
    # surface as an opaque TypeError in the UI.
    missing = [name for name, v in zip(FEATURE_NAMES, values) if v is None]
    if missing:
        raise gr.Error(f"Missing value(s) for: {', '.join(missing)}")

    features = [float(v) for v in values]

    # Reject values outside the ranges the UI labels promise (this also
    # rejects NaN, since every comparison with NaN is False).
    if not 0 <= features[4] <= 23:
        raise gr.Error("Hour must be between 0 and 23.")
    if not 0 <= features[5] <= 6:
        raise gr.Error("Day of week must be between 0 (Mon) and 6 (Sun).")

    X = np.array([features], dtype=np.float32)  # shape (1, 6)

    outputs = sess.run(None, {INPUT_NAME: X})
    probs = outputs[PROB_INDEX]
    fraud_prob = float(probs[0][1])  # class-1 probability

    # Convert to interpretable risk
    risk = risk_bucket(fraud_prob)
    risk_score = int(fraud_prob * 1000)  # scaled (truncated) score for visibility

    return fraud_prob, risk_score, risk


# ------------------------------------------------
# Gradio UI
# ------------------------------------------------
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Number(label="Amount"),
        gr.Number(label="User Txn Count"),
        gr.Number(label="Driver Txn Count"),
        gr.Number(label="User–Driver Pair Count"),
        gr.Number(label="Hour (0–23)"),
        gr.Number(label="Day of Week (0=Mon, 6=Sun)"),
    ],
    outputs=[
        gr.Number(label="Fraud Probability"),
        gr.Number(label="Risk Score (0–1000)"),
        gr.Text(label="Risk Level"),
    ],
    title="Collusion Fraud Detection (ONNX)",
    description=(
        "This model detects potential user–driver collusion. "
        "Fraud probability is a ranking signal; risk score and bucket are used for decisions."
    ),
)

if __name__ == "__main__":
    demo.launch()