MalikShehram commited on
Commit
efe2189
·
verified ·
1 Parent(s): 7707501

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -45
app.py CHANGED
@@ -1,56 +1,48 @@
 
1
  from fastapi import FastAPI
2
- from pydantic import BaseModel
3
  import joblib
4
  import numpy as np
5
 
6
- # 1. Load the trained model
7
- # Ensure 'iris_model.pkl' is uploaded to the same folder in Hugging Face
8
  model = joblib.load("iris_model.pkl")
9
 
10
- # 2. Define the class mapping based on your notebook encoding
11
- # In your notebook: Iris-setosa=1, Iris-versicolor=2, Iris-virginica=3
12
- class_names = {
13
- 1: "Iris-setosa",
14
- 2: "Iris-versicolor",
15
- 3: "Iris-virginica"
16
- }
17
 
18
- # 3. Define the input data format using Pydantic
19
- class IrisInput(BaseModel):
20
- SepalLengthCm: float
21
- SepalWidthCm: float
22
- PetalLengthCm: float
23
- PetalWidthCm: float
24
-
25
- # 4. Initialize FastAPI
26
- app = FastAPI()
27
-
28
- # 5. Define the Home Route (Health Check)
29
- @app.get("/")
30
- def home():
31
- return {"message": "Iris Species Prediction API is Live!"}
32
-
33
- # 6. Define the Prediction Route
34
- @app.post("/predict")
35
- def predict_species(data: IrisInput):
36
- # Extract features from the input object
37
- features = np.array([[
38
- data.SepalLengthCm,
39
- data.SepalWidthCm,
40
- data.PetalLengthCm,
41
- data.PetalWidthCm
42
- ]])
43
 
44
- # Make the prediction using the loaded model
45
  prediction = model.predict(features)
46
-
47
- # The model returns an array (e.g., [1]), so we take the first item
48
  predicted_class = int(prediction[0])
49
 
50
- # Map the number back to the name
51
- species_name = class_names.get(predicted_class, "Unknown")
52
-
53
- return {
54
- "predicted_class": predicted_class,
55
- "species_name": species_name
56
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
  from fastapi import FastAPI
 
3
  import joblib
4
  import numpy as np
5
 
6
+ # 1. Load the model you saved from your notebook
 
7
  model = joblib.load("iris_model.pkl")
8
 
9
+ # 2. Define the mapping from your notebook's label encoding
10
+ # Iris-setosa: 1, Iris-versicolor: 2, Iris-virginica: 3
11
+ class_names = {1: "Iris-setosa", 2: "Iris-versicolor", 3: "Iris-virginica"}
 
 
 
 
12
 
13
+ # 3. Create the prediction function for the UI
14
+ def predict_iris(sepal_l, sepal_w, petal_l, petal_w):
15
+ # Prepare the input array for the Logistic Regression model
16
+ features = np.array([[sepal_l, sepal_w, petal_l, petal_w]])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ # Get the numerical prediction
19
  prediction = model.predict(features)
 
 
20
  predicted_class = int(prediction[0])
21
 
22
+ # Return the species name
23
+ return class_names.get(predicted_class, "Unknown")
24
+
25
+ # 4. Set up the Gradio Interface
26
+ interface = gr.Interface(
27
+ fn=predict_iris,
28
+ inputs=[
29
+ gr.Slider(4.0, 8.0, label="Sepal Length (cm)"),
30
+ gr.Slider(2.0, 4.5, label="Sepal Width (cm)"),
31
+ gr.Slider(1.0, 7.0, label="Petal Length (cm)"),
32
+ gr.Slider(0.1, 2.5, label="Petal Width (cm)"),
33
+ ],
34
+ outputs=gr.Textbox(label="Predicted Species"),
35
+ title="Iris Species Classifier",
36
+ description="Slide the values to predict if the flower is Setosa, Versicolor, or Virginica.",
37
+ theme="soft"
38
+ )
39
+
40
+ # 5. Initialize FastAPI and mount the Gradio UI
41
+ app = FastAPI()
42
+
43
+ @app.get("/health")
44
+ def health_check():
45
+ return {"status": "online"}
46
+
47
+ # This mounts the UI to the root "/" path
48
+ app = gr.mount_gradio_app(app, interface, path="/")