Sina Media Lab commited on
Commit
e16685b
·
1 Parent(s): 02bffb4

api change2

Browse files
Files changed (1) hide show
  1. main.py +24 -12
main.py CHANGED
@@ -1,35 +1,41 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
 
3
  import numpy as np
4
- import joblib
5
 
6
  app = FastAPI(
7
- title="Iris KNN Prediction API",
8
- description="API for predicting Iris species using KNN model",
9
- version="1.0.0"
10
  )
11
 
 
12
  try:
13
- model, target_names = joblib.load("iris_knn.pkl")
14
- except:
 
15
  model = None
16
  target_names = []
17
 
 
18
  class IrisData(BaseModel):
19
  sepal_length: float
20
  sepal_width: float
21
  petal_length: float
22
  petal_width: float
23
 
 
24
  @app.get("/")
25
- def root():
26
- return {"message": "Iris KNN API Running! Visit /docs to test the API."}
 
27
 
28
  @app.post("/predict")
29
  def predict_iris(data: IrisData):
30
  if model is None:
31
- return {"error": "Model not found on server"}
32
 
 
33
  arr = np.array([[
34
  data.sepal_length,
35
  data.sepal_width,
@@ -37,9 +43,15 @@ def predict_iris(data: IrisData):
37
  data.petal_width
38
  ]])
39
 
40
- pred = model.predict(arr)[0]
 
 
 
41
  return {
42
  "input": data.dict(),
43
- "predicted_class": str(target_names[pred])
 
 
 
 
44
  }
45
-
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ from joblib import load
4
  import numpy as np
 
5
 
6
  app = FastAPI(
7
+ title="Iris Species Prediction API",
8
+ description="An API to predict Iris species with class probabilities.",
9
+ version="1.1.0"
10
  )
11
 
12
+ # Load model at startup
13
  try:
14
+ model = load("iris_random_forest.joblib")
15
+ target_names = ["setosa", "versicolor", "virginica"]
16
+ except FileNotFoundError:
17
  model = None
18
  target_names = []
19
 
20
+
21
  class IrisData(BaseModel):
22
  sepal_length: float
23
  sepal_width: float
24
  petal_length: float
25
  petal_width: float
26
 
27
+
28
  @app.get("/")
29
+ def read_root():
30
+ return {"message": "Welcome! Use POST /predict to classify Iris data."}
31
+
32
 
33
  @app.post("/predict")
34
  def predict_iris(data: IrisData):
35
  if model is None:
36
+ return {"error": "Model not found on server."}
37
 
38
+ # Convert input
39
  arr = np.array([[
40
  data.sepal_length,
41
  data.sepal_width,
 
43
  data.petal_width
44
  ]])
45
 
46
+ # Predictions
47
+ pred_idx = int(model.predict(arr)[0])
48
+ probas = model.predict_proba(arr)[0]
49
+
50
  return {
51
  "input": data.dict(),
52
+ "predicted_class_index": pred_idx,
53
+ "predicted_class_name": target_names[pred_idx],
54
+ "class_probabilities": {
55
+ target_names[i]: float(probas[i]) for i in range(len(target_names))
56
+ }
57
  }