Sahithi27 commited on
Commit
1f33bc4
·
verified ·
1 Parent(s): 1056ebc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -20
app.py CHANGED
@@ -5,27 +5,18 @@ import numpy as np
5
  sess = ort.InferenceSession("collusion_xgb_model.onnx")
6
 
7
  def predict(amount, user_txn, driver_txn, pair_count, hour, day):
8
- X = np.array([[amount, user_txn, driver_txn, pair_count, hour, day]], dtype=np.float32)
 
 
 
 
9
  outputs = sess.run(None, {sess.get_inputs()[0].name: X})
10
- scores = outputs[0]
11
 
12
- if scores.ndim == 2 and scores.shape[1] == 2:
13
- fraud_prob = scores[0][1]
14
- else:
15
- fraud_prob = scores[0][0]
16
 
17
- return float(fraud_prob)
18
 
19
- gr.Interface(
20
- fn=predict,
21
- inputs=[
22
- gr.Number(label="Amount"),
23
- gr.Number(label="User Txn Count"),
24
- gr.Number(label="Driver Txn Count"),
25
- gr.Number(label="User-Driver Pair Count"),
26
- gr.Number(label="Hour"),
27
- gr.Number(label="Day of Week"),
28
- ],
29
- outputs=gr.Number(label="Fraud Probability"),
30
- title="Collusion Fraud Detection (ONNX)"
31
- ).launch()
 
5
  sess = ort.InferenceSession("collusion_xgb_model.onnx")
6
 
7
  def predict(amount, user_txn, driver_txn, pair_count, hour, day):
8
+ X = np.array(
9
+ [[amount, user_txn, driver_txn, pair_count, hour, day]],
10
+ dtype=np.float32
11
+ )
12
+
13
  outputs = sess.run(None, {sess.get_inputs()[0].name: X})
 
14
 
15
+ # IMPORTANT:
16
+ # outputs[0] = label (0 or 1)
17
+ # outputs[1] = probability [p0, p1]
18
+ probs = outputs[1]
19
 
20
+ fraud_prob = probs[0][1] # class-1 (fraud)
21
 
22
+ return float(fraud_prob)