Update app.py
Browse files
app.py
CHANGED
|
@@ -1,56 +1,48 @@
|
|
|
|
|
| 1 |
from fastapi import FastAPI
|
| 2 |
-
from pydantic import BaseModel
|
| 3 |
import joblib
|
| 4 |
import numpy as np
|
| 5 |
|
| 6 |
-
# 1. Load the
|
| 7 |
-
# Ensure 'iris_model.pkl' is uploaded to the same folder in Hugging Face
|
| 8 |
model = joblib.load("iris_model.pkl")
|
| 9 |
|
| 10 |
-
# 2. Define the
|
| 11 |
-
#
|
| 12 |
-
class_names = {
|
| 13 |
-
1: "Iris-setosa",
|
| 14 |
-
2: "Iris-versicolor",
|
| 15 |
-
3: "Iris-virginica"
|
| 16 |
-
}
|
| 17 |
|
| 18 |
-
# 3.
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
PetalLengthCm: float
|
| 23 |
-
PetalWidthCm: float
|
| 24 |
-
|
| 25 |
-
# 4. Initialize FastAPI
|
| 26 |
-
app = FastAPI()
|
| 27 |
-
|
| 28 |
-
# 5. Define the Home Route (Health Check)
|
| 29 |
-
@app.get("/")
|
| 30 |
-
def home():
|
| 31 |
-
return {"message": "Iris Species Prediction API is Live!"}
|
| 32 |
-
|
| 33 |
-
# 6. Define the Prediction Route
|
| 34 |
-
@app.post("/predict")
|
| 35 |
-
def predict_species(data: IrisInput):
|
| 36 |
-
# Extract features from the input object
|
| 37 |
-
features = np.array([[
|
| 38 |
-
data.SepalLengthCm,
|
| 39 |
-
data.SepalWidthCm,
|
| 40 |
-
data.PetalLengthCm,
|
| 41 |
-
data.PetalWidthCm
|
| 42 |
-
]])
|
| 43 |
|
| 44 |
-
#
|
| 45 |
prediction = model.predict(features)
|
| 46 |
-
|
| 47 |
-
# The model returns an array (e.g., [1]), so we take the first item
|
| 48 |
predicted_class = int(prediction[0])
|
| 49 |
|
| 50 |
-
#
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
from fastapi import FastAPI
|
|
|
|
| 3 |
import joblib
|
| 4 |
import numpy as np
|
| 5 |
|
| 6 |
+
# 1. Load the model you saved from your notebook
|
|
|
|
| 7 |
model = joblib.load("iris_model.pkl")
|
| 8 |
|
| 9 |
+
# 2. Define the mapping from your notebook's label encoding
|
| 10 |
+
# Iris-setosa: 1, Iris-versicolor: 2, Iris-virginica: 3
|
| 11 |
+
class_names = {1: "Iris-setosa", 2: "Iris-versicolor", 3: "Iris-virginica"}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
# 3. Create the prediction function for the UI
|
| 14 |
+
def predict_iris(sepal_l, sepal_w, petal_l, petal_w):
|
| 15 |
+
# Prepare the input array for the Logistic Regression model
|
| 16 |
+
features = np.array([[sepal_l, sepal_w, petal_l, petal_w]])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
+
# Get the numerical prediction
|
| 19 |
prediction = model.predict(features)
|
|
|
|
|
|
|
| 20 |
predicted_class = int(prediction[0])
|
| 21 |
|
| 22 |
+
# Return the species name
|
| 23 |
+
return class_names.get(predicted_class, "Unknown")
|
| 24 |
+
|
| 25 |
+
# 4. Set up the Gradio Interface
|
| 26 |
+
interface = gr.Interface(
|
| 27 |
+
fn=predict_iris,
|
| 28 |
+
inputs=[
|
| 29 |
+
gr.Slider(4.0, 8.0, label="Sepal Length (cm)"),
|
| 30 |
+
gr.Slider(2.0, 4.5, label="Sepal Width (cm)"),
|
| 31 |
+
gr.Slider(1.0, 7.0, label="Petal Length (cm)"),
|
| 32 |
+
gr.Slider(0.1, 2.5, label="Petal Width (cm)"),
|
| 33 |
+
],
|
| 34 |
+
outputs=gr.Textbox(label="Predicted Species"),
|
| 35 |
+
title="Iris Species Classifier",
|
| 36 |
+
description="Slide the values to predict if the flower is Setosa, Versicolor, or Virginica.",
|
| 37 |
+
theme="soft"
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# 5. Initialize FastAPI and mount the Gradio UI
|
| 41 |
+
app = FastAPI()
|
| 42 |
+
|
| 43 |
+
@app.get("/health")
|
| 44 |
+
def health_check():
|
| 45 |
+
return {"status": "online"}
|
| 46 |
+
|
| 47 |
+
# This mounts the UI to the root "/" path
|
| 48 |
+
app = gr.mount_gradio_app(app, interface, path="/")
|