Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI | |
| import pickle | |
| from pydantic import BaseModel | |
| import streamlit as st | |
| import requests | |
app = FastAPI()

# Load the trained model once, at import time. ``with`` closes the file handle as
# soon as loading finishes instead of leaving it open for the life of the process.
# The path is resolved relative to the current working directory.
# NOTE: pickle.load can execute arbitrary code embedded in the file. Only ever
# load a trusted, locally produced classifier.pkl, never a user-supplied file.
with open("classifier.pkl", "rb") as pickle_in:
    classifier = pickle.load(pickle_in)
class Classe(BaseModel):
    """Request body for the prediction endpoint: four iris flower measurements.

    pydantic validates the incoming JSON against these fields, so the field
    names below are exactly the keys a client must send.
    """

    # Field order matches the feature order predict_species passes to the model.
    Sepal_Length: float
    Sepal_Width: float
    Petal_Length: float
    Petal_Width: float
def index():
    """Return the constant greeting dictionary ``{"hello": "FastAPI"}``."""
    greeting = dict(hello="FastAPI")
    return greeting
def get_name(name: str):
    """Greet *name*, returning ``{'message': 'hello, <name>'}``."""
    reply = {}
    reply['message'] = f'hello, {name}'
    return reply
# Class code returned by the classifier -> species name reported to clients.
# NOTE(review): scikit-learn's load_iris (and a LabelEncoder over the usual
# "Iris-*" species strings) both encode 1 = versicolor and 2 = virginica. This
# table has those two swapped. Confirm against how classifier.pkl was trained.
_SPECIES_BY_CODE = {0: "setosa", 1: "virginica", 2: "versicolor"}


@app.post("/predict")
def predict_species(data: Classe):
    """Predict the iris species for one set of flower measurements.

    Registered as ``POST /predict`` because that is the URL the Streamlit client
    in this file posts to. Without a route decorator FastAPI never exposes this
    function, and every client request gets a 404.

    Args:
        data: The four measurements, parsed and validated from the JSON body.

    Returns:
        ``{"prediction": name}``, where *name* is "setosa", "virginica" or
        "versicolor", or "unknown" for any other class code the model returns.
    """
    # predict() is called on a batch containing this single row of features;
    # the first (only) element of the result is this sample's class code.
    features = [[data.Sepal_Length, data.Sepal_Width,
                 data.Petal_Length, data.Petal_Width]]
    prediction = classifier.predict(features)
    species = _SPECIES_BY_CODE.get(prediction[0], "unknown")
    return {'prediction': species}
def _server_is_up(base_url, timeout=0.5):
    """Return True if any HTTP server answers at *base_url*.

    Any response at all (even a 404) counts: we only need to know that
    something is listening, not that a particular route exists.
    """
    try:
        requests.get(base_url, timeout=timeout)
    except requests.exceptions.RequestException:
        return False
    return True


def _ensure_api_server(host="0.0.0.0", port=8000, startup_timeout=15.0):
    """Start this file's FastAPI ``app`` under uvicorn unless a server is already up.

    Streamlit re-executes the whole script on every widget interaction. A
    problem arises if the script spawns uvicorn on each run and kills it at the
    end of the run: the port is already taken on every run after the first, and
    the POST goes out before the new server is listening. Instead, the server
    is started at most once, reused by later runs, and terminated when this
    process exits.

    Args:
        host: Interface uvicorn binds to.
        port: Port uvicorn listens on; also probed on ``localhost``.
        startup_timeout: Seconds to wait for a freshly spawned server to answer.

    Returns:
        True once a server answers on *port*. False if the spawned uvicorn
        exited early or did not answer within *startup_timeout* seconds.
    """
    import atexit
    import os
    import subprocess
    import sys
    import time

    base_url = f"http://localhost:{port}"
    if _server_is_up(base_url):
        return True

    # Serve the ``app`` defined in this very file, whatever the file is named,
    # importing it from the script's own directory rather than the working dir.
    script_dir, script_file = os.path.split(os.path.abspath(__file__))
    module_name = os.path.splitext(script_file)[0]
    # Run uvicorn with this interpreter so it sees the same environment.
    # Output is inherited (visible in the console) rather than sent to an
    # unread PIPE, which could fill up and block the server.
    proc = subprocess.Popen(
        [sys.executable, "-m", "uvicorn", f"{module_name}:app",
         "--host", host, "--port", str(port), "--app-dir", script_dir]
    )
    atexit.register(proc.terminate)

    deadline = time.monotonic() + startup_timeout
    while time.monotonic() < deadline:
        if proc.poll() is not None:  # uvicorn exited early (import error, bad args, ...)
            return False
        if _server_is_up(base_url):
            return True
        time.sleep(0.2)
    return False


if __name__ == "__main__":
    # Launched via ``streamlit run``: a small UI that calls the FastAPI app above
    # over HTTP on localhost:8000.
    st.title("Iris Species Prediction")
    with st.spinner("Starting prediction server..."):
        api_ready = _ensure_api_server()
    if not api_ready:
        st.error("Prediction server could not be started; see the console log.")

    st.subheader("Enter the following parameters:")
    sepal_length = st.number_input("Sepal Length", min_value=0.0, max_value=10.0, step=0.1, value=5.0)
    sepal_width = st.number_input("Sepal Width", min_value=0.0, max_value=10.0, step=0.1, value=3.5)
    petal_length = st.number_input("Petal Length", min_value=0.0, max_value=10.0, step=0.1, value=1.4)
    petal_width = st.number_input("Petal Width", min_value=0.0, max_value=10.0, step=0.1, value=0.2)
    submit = st.button("Predict")
    if submit:
        payload = {"Sepal_Length": sepal_length, "Sepal_Width": sepal_width, "Petal_Length": petal_length, "Petal_Width": petal_width}
        prediction = st.empty()
        with st.spinner("Predicting..."):
            try:
                response = requests.post(
                    "http://localhost:8000/predict", json=payload, timeout=10
                )
            except requests.exceptions.RequestException as err:
                # No server listening, connection reset, or no reply within 10 s.
                prediction.error(f"Prediction failed. ({err})")
            else:
                if response.status_code == 200:
                    prediction_result = response.json()
                    prediction.success(f"Prediction: {prediction_result['prediction']}")
                else:
                    prediction.error(f"Prediction failed. (HTTP {response.status_code})")
    # The uvicorn server is deliberately NOT killed here: it must survive
    # Streamlit's rerun of this script; atexit stops it when the process exits.