Sina Media Lab commited on
Commit
f9d5b98
·
1 Parent(s): 4e49afd

api change 4

Browse files
Files changed (1) hide show
  1. main.py +9 -1
main.py CHANGED
@@ -9,6 +9,7 @@ app = FastAPI(
9
  version="1.0.0"
10
  )
11
 
 
12
  try:
13
  model, target_names = joblib.load("iris_knn.pkl")
14
  except:
@@ -38,7 +39,14 @@ def predict_iris(data: IrisData):
38
  ]])
39
 
40
  pred = model.predict(arr)[0]
 
 
 
 
 
 
41
  return {
42
  "input": data.dict(),
43
- "predicted_class": str(target_names[pred])
 
44
  }
 
9
  version="1.0.0"
10
  )
11
 
12
+ # Load model & class names
13
  try:
14
  model, target_names = joblib.load("iris_knn.pkl")
15
  except:
 
39
  ]])
40
 
41
  pred = model.predict(arr)[0]
42
+ proba = model.predict_proba(arr)[0]
43
+
44
+ probability_dict = {
45
+ str(target_names[i]): float(proba[i]) for i in range(len(target_names))
46
+ }
47
+
48
  return {
49
  "input": data.dict(),
50
+ "predicted_class": str(target_names[pred]),
51
+ "class_probabilities": probability_dict
52
  }