LovnishVerma commited on
Commit
8bbc1dc
·
verified ·
1 Parent(s): 9c72171

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -9
app.py CHANGED
@@ -5,7 +5,7 @@ from sklearn.datasets import load_iris
5
  from sklearn.ensemble import RandomForestClassifier
6
 
7
  app = Flask(__name__)
8
- CORS(app) # Allow frontend calls from GitHub Pages
9
 
10
  # --- Train or load model
11
  try:
@@ -39,7 +39,13 @@ def home():
39
  }</pre>
40
  <p>Response:</p>
41
  <pre>{
42
- "prediction": "setosa"
 
 
 
 
 
 
43
  }</pre>
44
  """
45
 
@@ -54,15 +60,26 @@ def predict():
54
  data["petal_width"]
55
  ]
56
 
57
- prediction = model.predict([features])[0] # could be int or string
 
 
58
 
59
- # If prediction is already a string (class label)
60
- if isinstance(prediction, str):
61
- return jsonify({"prediction": prediction})
62
-
63
- # If prediction is numeric (0, 1, 2)
64
  target_names = load_iris().target_names
65
- return jsonify({"prediction": target_names[int(prediction)]})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
 
68
  if __name__ == "__main__":
 
5
  from sklearn.ensemble import RandomForestClassifier
6
 
7
  app = Flask(__name__)
8
+ CORS(app)
9
 
10
  # --- Train or load model
11
  try:
 
39
  }</pre>
40
  <p>Response:</p>
41
  <pre>{
42
+ "prediction": "setosa",
43
+ "confidence": 0.98,
44
+ "probabilities": {
45
+ "setosa": 0.98,
46
+ "versicolor": 0.01,
47
+ "virginica": 0.01
48
+ }
49
  }</pre>
50
  """
51
 
 
60
  data["petal_width"]
61
  ]
62
 
63
+ # Predict class and probabilities
64
+ prediction_idx = model.predict([features])[0]
65
+ probs = model.predict_proba([features])[0]
66
 
 
 
 
 
 
67
  target_names = load_iris().target_names
68
+ prediction_label = target_names[int(prediction_idx)]
69
+
70
+ # Build probability dict
71
+ probabilities = {
72
+ target_names[i]: float(probs[i])
73
+ for i in range(len(target_names))
74
+ }
75
+
76
+ confidence = float(max(probs))
77
+
78
+ return jsonify({
79
+ "prediction": prediction_label,
80
+ "confidence": confidence,
81
+ "probabilities": probabilities
82
+ })
83
 
84
 
85
  if __name__ == "__main__":