Spaces:
Build error
Build error
added confidences to predictions
Browse files
app.py
CHANGED
|
@@ -19,7 +19,7 @@ def train_iris_model(algorithm):
|
|
| 19 |
if algorithm == 'KNN':
|
| 20 |
model = KNeighborsClassifier()
|
| 21 |
elif algorithm == 'SVM':
|
| 22 |
-
model = SVC()
|
| 23 |
elif algorithm == "logistic regression":
|
| 24 |
model = LogisticRegression()
|
| 25 |
elif algorithm == 'Random Forest':
|
|
@@ -28,6 +28,7 @@ def train_iris_model(algorithm):
|
|
| 28 |
model = AdaBoostClassifier()
|
| 29 |
elif algorithm == 'Decision tree':
|
| 30 |
model = DecisionTreeClassifier()
|
|
|
|
| 31 |
model.fit(X_train, y_train)
|
| 32 |
|
| 33 |
return model
|
|
@@ -39,7 +40,12 @@ def predict_iris_species(model, input_data):
|
|
| 39 |
# Make predictions using the trained model
|
| 40 |
prediction = model.predict(input_data)
|
| 41 |
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
def main():
|
| 45 |
st.title("Iris Species Prediction App")
|
|
@@ -50,7 +56,6 @@ def main():
|
|
| 50 |
# Train the model based on user's choice
|
| 51 |
trained_model = train_iris_model(algorithm)
|
| 52 |
|
| 53 |
-
|
| 54 |
st.sidebar.header("User Input")
|
| 55 |
sepal_length = st.sidebar.slider("Sepal Length", 0.0, 10.0, 5.0)
|
| 56 |
sepal_width = st.sidebar.slider("Sepal Width", 0.0, 10.0, 5.0)
|
|
@@ -58,7 +63,7 @@ def main():
|
|
| 58 |
petal_width = st.sidebar.slider("Petal Width", 0.0, 10.0, 5.0)
|
| 59 |
|
| 60 |
input_values = [sepal_length, sepal_width, petal_length, petal_width]
|
| 61 |
-
prediction_result = predict_iris_species(trained_model, input_values)
|
| 62 |
|
| 63 |
species_mapping = {0: 'Iris-setosa', 1: 'Iris-virginica', 2: 'Iris-versicolor'}
|
| 64 |
predicted_species = species_mapping.get(prediction_result[0], 'Unknown')
|
|
@@ -72,6 +77,9 @@ def main():
|
|
| 72 |
st.subheader("Prediction:")
|
| 73 |
st.success(f"Predicted Species: {predicted_species}")
|
| 74 |
|
|
|
|
|
|
|
|
|
|
| 75 |
# Display relevant images based on prediction
|
| 76 |
if predicted_species == 'Iris-setosa':
|
| 77 |
st.image('setosa_image.jpg', caption='Iris-setosa', use_column_width=True)
|
|
|
|
| 19 |
if algorithm == 'KNN':
|
| 20 |
model = KNeighborsClassifier()
|
| 21 |
elif algorithm == 'SVM':
|
| 22 |
+
model = SVC(probability=True)
|
| 23 |
elif algorithm == "logistic regression":
|
| 24 |
model = LogisticRegression()
|
| 25 |
elif algorithm == 'Random Forest':
|
|
|
|
| 28 |
model = AdaBoostClassifier()
|
| 29 |
elif algorithm == 'Decision tree':
|
| 30 |
model = DecisionTreeClassifier()
|
| 31 |
+
|
| 32 |
model.fit(X_train, y_train)
|
| 33 |
|
| 34 |
return model
|
|
|
|
| 40 |
# Make predictions using the trained model
|
| 41 |
prediction = model.predict(input_data)
|
| 42 |
|
| 43 |
+
# Check if the model has a predict_proba method
|
| 44 |
+
if hasattr(model, 'predict_proba'):
|
| 45 |
+
confidence = model.predict_proba(input_data).max()
|
| 46 |
+
return prediction, confidence
|
| 47 |
+
else:
|
| 48 |
+
return prediction, None
|
| 49 |
|
| 50 |
def main():
|
| 51 |
st.title("Iris Species Prediction App")
|
|
|
|
| 56 |
# Train the model based on user's choice
|
| 57 |
trained_model = train_iris_model(algorithm)
|
| 58 |
|
|
|
|
| 59 |
st.sidebar.header("User Input")
|
| 60 |
sepal_length = st.sidebar.slider("Sepal Length", 0.0, 10.0, 5.0)
|
| 61 |
sepal_width = st.sidebar.slider("Sepal Width", 0.0, 10.0, 5.0)
|
|
|
|
| 63 |
petal_width = st.sidebar.slider("Petal Width", 0.0, 10.0, 5.0)
|
| 64 |
|
| 65 |
input_values = [sepal_length, sepal_width, petal_length, petal_width]
|
| 66 |
+
prediction_result, confidence = predict_iris_species(trained_model, input_values)
|
| 67 |
|
| 68 |
species_mapping = {0: 'Iris-setosa', 1: 'Iris-virginica', 2: 'Iris-versicolor'}
|
| 69 |
predicted_species = species_mapping.get(prediction_result[0], 'Unknown')
|
|
|
|
| 77 |
st.subheader("Prediction:")
|
| 78 |
st.success(f"Predicted Species: {predicted_species}")
|
| 79 |
|
| 80 |
+
if confidence is not None:
|
| 81 |
+
st.info(f"Confidence of prediction: {confidence * 100:.2f}%")
|
| 82 |
+
|
| 83 |
# Display relevant images based on prediction
|
| 84 |
if predicted_species == 'Iris-setosa':
|
| 85 |
st.image('setosa_image.jpg', caption='Iris-setosa', use_column_width=True)
|