viruthik commited on
Commit
c0ce7f5
·
1 Parent(s): 902f1ae

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ import pickle
3
+ from pydantic import BaseModel
4
+ import streamlit as st
5
+
6
+ app = FastAPI()
7
+
8
+ pickle_in = open("classifier.pkl", "rb")
9
+ classifier = pickle.load(pickle_in)
10
+
11
+
12
+ class Classe(BaseModel):
13
+ Sepal_Length: float
14
+ Sepal_Width: float
15
+ Petal_Length: float
16
+ Petal_Width: float
17
+
18
+
19
+ @app.get("/")
20
+ def index():
21
+ return {"hello": "FastAPI"}
22
+
23
+
24
+ @app.get('/{name}')
25
+ def get_name(name: str):
26
+ return {'message': f'hello, {name}'}
27
+
28
+
29
+ @app.post('/predict')
30
+ def predict_species(data: Classe):
31
+ Sepal_Length = data.Sepal_Length
32
+ Sepal_Width = data.Sepal_Width
33
+ Petal_Length = data.Petal_Length
34
+ Petal_Width = data.Petal_Width
35
+
36
+ prediction = classifier.predict([[Sepal_Length, Sepal_Width, Petal_Length, Petal_Width]])
37
+
38
+ if prediction[0] == 0:
39
+ species = "setosa"
40
+ elif prediction[0] == 1:
41
+ species = "virginica"
42
+ elif prediction[0] == 2:
43
+ species = "versicolor"
44
+ else:
45
+ species = "unknown"
46
+
47
+ return {'prediction': species}
48
+
49
+
50
+ if __name__ == "__main__":
51
+ import uvicorn
52
+ import subprocess
53
+
54
+ uvicorn_proc = subprocess.Popen(["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
55
+ st.title("Iris Species Prediction")
56
+ st.subheader("Enter the following parameters:")
57
+ sepal_length = st.number_input("Sepal Length", min_value=0.0, max_value=10.0, step=0.1, value=5.0)
58
+ sepal_width = st.number_input("Sepal Width", min_value=0.0, max_value=10.0, step=0.1, value=3.5)
59
+ petal_length = st.number_input("Petal Length", min_value=0.0, max_value=10.0, step=0.1, value=1.4)
60
+ petal_width = st.number_input("Petal Width", min_value=0.0, max_value=10.0, step=0.1, value=0.2)
61
+
62
+ submit = st.button("Predict")
63
+ if submit:
64
+ payload = {"Sepal_Length": sepal_length, "Sepal_Width": sepal_width, "Petal_Length": petal_length, "Petal_Width": petal_width}
65
+ prediction = st.empty()
66
+ with st.spinner("Predicting..."):
67
+ response = requests.post("http://localhost:8000/predict", json=payload)
68
+ if response.status_code == 200:
69
+ prediction_result = response.json()
70
+ prediction.success(f"Prediction: {prediction_result['prediction']}")
71
+ else:
72
+ prediction.error("Prediction failed.")
73
+
74
+ uvicorn_proc.kill()