Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,7 +5,7 @@ from sklearn.datasets import load_iris
|
|
| 5 |
from sklearn.ensemble import RandomForestClassifier
|
| 6 |
|
| 7 |
app = Flask(__name__)
|
| 8 |
-
CORS(app)
|
| 9 |
|
| 10 |
# --- Train or load model
|
| 11 |
try:
|
|
@@ -39,7 +39,13 @@ def home():
|
|
| 39 |
}</pre>
|
| 40 |
<p>Response:</p>
|
| 41 |
<pre>{
|
| 42 |
-
"prediction": "setosa"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
}</pre>
|
| 44 |
"""
|
| 45 |
|
|
@@ -54,15 +60,26 @@ def predict():
|
|
| 54 |
data["petal_width"]
|
| 55 |
]
|
| 56 |
|
| 57 |
-
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
# If prediction is already a string (class label)
|
| 60 |
-
if isinstance(prediction, str):
|
| 61 |
-
return jsonify({"prediction": prediction})
|
| 62 |
-
|
| 63 |
-
# If prediction is numeric (0, 1, 2)
|
| 64 |
target_names = load_iris().target_names
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
|
| 68 |
if __name__ == "__main__":
|
|
|
|
| 5 |
from sklearn.ensemble import RandomForestClassifier
|
| 6 |
|
| 7 |
app = Flask(__name__)
|
| 8 |
+
CORS(app)
|
| 9 |
|
| 10 |
# --- Train or load model
|
| 11 |
try:
|
|
|
|
| 39 |
}</pre>
|
| 40 |
<p>Response:</p>
|
| 41 |
<pre>{
|
| 42 |
+
"prediction": "setosa",
|
| 43 |
+
"confidence": 0.98,
|
| 44 |
+
"probabilities": {
|
| 45 |
+
"setosa": 0.98,
|
| 46 |
+
"versicolor": 0.01,
|
| 47 |
+
"virginica": 0.01
|
| 48 |
+
}
|
| 49 |
}</pre>
|
| 50 |
"""
|
| 51 |
|
|
|
|
| 60 |
data["petal_width"]
|
| 61 |
]
|
| 62 |
|
| 63 |
+
# Predict class and probabilities
|
| 64 |
+
prediction_idx = model.predict([features])[0]
|
| 65 |
+
probs = model.predict_proba([features])[0]
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
target_names = load_iris().target_names
|
| 68 |
+
prediction_label = target_names[int(prediction_idx)]
|
| 69 |
+
|
| 70 |
+
# Build probability dict
|
| 71 |
+
probabilities = {
|
| 72 |
+
target_names[i]: float(probs[i])
|
| 73 |
+
for i in range(len(target_names))
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
confidence = float(max(probs))
|
| 77 |
+
|
| 78 |
+
return jsonify({
|
| 79 |
+
"prediction": prediction_label,
|
| 80 |
+
"confidence": confidence,
|
| 81 |
+
"probabilities": probabilities
|
| 82 |
+
})
|
| 83 |
|
| 84 |
|
| 85 |
if __name__ == "__main__":
|