Sahithi27 commited on
Commit
09e497d
·
verified ·
1 Parent(s): 1f33bc4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -7
app.py CHANGED
@@ -2,7 +2,23 @@ import onnxruntime as ort
2
  import gradio as gr
3
  import numpy as np
4
 
5
- sess = ort.InferenceSession("collusion_xgb_model.onnx")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def predict(amount, user_txn, driver_txn, pair_count, hour, day):
8
  X = np.array(
@@ -12,11 +28,22 @@ def predict(amount, user_txn, driver_txn, pair_count, hour, day):
12
 
13
  outputs = sess.run(None, {sess.get_inputs()[0].name: X})
14
 
15
- # IMPORTANT:
16
- # outputs[0] = label (0 or 1)
17
- # outputs[1] = probability [p0, p1]
18
- probs = outputs[1]
19
-
20
- fraud_prob = probs[0][1] # class-1 (fraud)
21
 
22
  return float(fraud_prob)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import gradio as gr
3
  import numpy as np
4
 
5
+ # Load ONNX model
6
+ sess = ort.InferenceSession("collusion_xgb_model.onnx", providers=["CPUExecutionProvider"])
7
+
8
+ # Inspect outputs ONCE (safe)
9
+ output_info = sess.get_outputs()
10
+
11
+ # Find probability output index
12
+ PROB_INDEX = None
13
+ for i, out in enumerate(output_info):
14
+ if len(out.shape) == 2 and out.shape[1] == 2:
15
+ PROB_INDEX = i
16
+ break
17
+
18
+ if PROB_INDEX is None:
19
+ raise RuntimeError("Probability output not found in ONNX model")
20
+
21
+ print(f" Using output index {PROB_INDEX} as probability tensor")
22
 
23
  def predict(amount, user_txn, driver_txn, pair_count, hour, day):
24
  X = np.array(
 
28
 
29
  outputs = sess.run(None, {sess.get_inputs()[0].name: X})
30
 
31
+ probs = outputs[PROB_INDEX] # [p0, p1]
32
+ fraud_prob = probs[0][1] # class-1 probability
 
 
 
 
33
 
34
  return float(fraud_prob)
35
+
36
+ gr.Interface(
37
+ fn=predict,
38
+ inputs=[
39
+ gr.Number(label="Amount"),
40
+ gr.Number(label="User Txn Count"),
41
+ gr.Number(label="Driver Txn Count"),
42
+ gr.Number(label="User–Driver Pair Count"),
43
+ gr.Number(label="Hour"),
44
+ gr.Number(label="Day of Week"),
45
+ ],
46
+ outputs=gr.Number(label="Fraud Probability"),
47
+ title="Collusion Fraud Detection (ONNX)",
48
+ ).launch()
49
+