Sina Media Lab commited on
Commit
4e49afd
·
1 Parent(s): 7fa72b3

api change 4

Browse files
Files changed (1) hide show
  1. main.py +11 -24
main.py CHANGED
@@ -1,41 +1,35 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from joblib import load
4
  import numpy as np
 
5
 
6
  app = FastAPI(
7
- title="Iris Species Prediction API",
8
- description="An API to predict Iris species with class probabilities.",
9
- version="1.1.0"
10
  )
11
 
12
- # Load model at startup
13
  try:
14
- model = load("iris_knn.pkl")
15
- target_names = ["setosa", "versicolor", "virginica"]
16
- except FileNotFoundError:
17
  model = None
18
  target_names = []
19
 
20
-
21
  class IrisData(BaseModel):
22
  sepal_length: float
23
  sepal_width: float
24
  petal_length: float
25
  petal_width: float
26
 
27
-
28
  @app.get("/")
29
- def read_root():
30
- return {"message": "Welcome! Use POST /predict to classify Iris data."}
31
-
32
 
33
  @app.post("/predict")
34
  def predict_iris(data: IrisData):
35
  if model is None:
36
- return {"error": "Model not found on server."}
37
 
38
- # Convert input
39
  arr = np.array([[
40
  data.sepal_length,
41
  data.sepal_width,
@@ -43,15 +37,8 @@ def predict_iris(data: IrisData):
43
  data.petal_width
44
  ]])
45
 
46
- # Predictions
47
- pred_idx = int(model.predict(arr)[0])
48
- probas = model.predict_proba(arr)[0]
49
-
50
  return {
51
  "input": data.dict(),
52
- "predicted_class_index": pred_idx,
53
- "predicted_class_name": target_names[pred_idx],
54
- "class_probabilities": {
55
- target_names[i]: float(probas[i]) for i in range(len(target_names))
56
- }
57
  }
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
 
3
  import numpy as np
4
+ import joblib
5
 
6
  app = FastAPI(
7
+ title="Iris KNN Prediction API",
8
+ description="API for predicting Iris species using KNN model",
9
+ version="1.0.0"
10
  )
11
 
 
12
  try:
13
+ model, target_names = joblib.load("iris_knn.pkl")
14
+ except:
 
15
  model = None
16
  target_names = []
17
 
 
18
  class IrisData(BaseModel):
19
  sepal_length: float
20
  sepal_width: float
21
  petal_length: float
22
  petal_width: float
23
 
 
24
  @app.get("/")
25
+ def root():
26
+ return {"message": "Iris KNN API Running! Visit /docs to test the API."}
 
27
 
28
  @app.post("/predict")
29
  def predict_iris(data: IrisData):
30
  if model is None:
31
+ return {"error": "Model not found on server"}
32
 
 
33
  arr = np.array([[
34
  data.sepal_length,
35
  data.sepal_width,
 
37
  data.petal_width
38
  ]])
39
 
40
+ pred = model.predict(arr)[0]
 
 
 
41
  return {
42
  "input": data.dict(),
43
+ "predicted_class": str(target_names[pred])
 
 
 
 
44
  }