Abeshith commited on
Commit
5c61354
·
1 Parent(s): 7f1fbee

Add FastAPI application

Browse files
app/main.py CHANGED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import JSONResponse
4
+ from app.routers import health, predict, train
5
+ from mlpipeline.exception import MLPipelineException
6
+ import uvicorn
7
+
8
+ app = FastAPI(
9
+ title="AutoML MLOps API",
10
+ description="AutoML pipeline API for heart disease prediction",
11
+ version="1.0.0"
12
+ )
13
+
14
+ app.add_middleware(
15
+ CORSMiddleware,
16
+ allow_origins=["*"],
17
+ allow_credentials=True,
18
+ allow_methods=["*"],
19
+ allow_headers=["*"],
20
+ )
21
+
22
+ app.include_router(health.router)
23
+ app.include_router(predict.router)
24
+ app.include_router(train.router)
25
+
26
+
27
+ @app.exception_handler(MLPipelineException)
28
+ async def mlpipeline_exception_handler(request: Request, exc: MLPipelineException):
29
+ return JSONResponse(
30
+ status_code=500,
31
+ content={"error": str(exc)}
32
+ )
33
+
34
+
35
+ @app.get("/")
36
+ async def root():
37
+ return {
38
+ "message": "AutoML MLOps API",
39
+ "version": "1.0.0",
40
+ "endpoints": {
41
+ "health": "/health",
42
+ "predict": "/predict",
43
+ "train": "/train"
44
+ }
45
+ }
46
+
47
+
48
+ if __name__ == "__main__":
49
+ uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
app/routers/health.py CHANGED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+ from app.schemas.response import HealthResponse
3
+ from app.utils.model_loader import model_loader
4
+
5
+ router = APIRouter(prefix="/health", tags=["health"])
6
+
7
+
8
+ @router.get("/", response_model=HealthResponse)
9
+ async def health_check():
10
+ is_loaded = model_loader.is_loaded()
11
+ return HealthResponse(
12
+ status="healthy" if is_loaded else "degraded",
13
+ model_loaded=is_loaded,
14
+ message="Model loaded and ready" if is_loaded else "Model not loaded"
15
+ )
16
+
17
+
18
+ @router.get("/ready", response_model=HealthResponse)
19
+ async def readiness_check():
20
+ is_loaded = model_loader.is_loaded()
21
+ if not is_loaded:
22
+ return HealthResponse(
23
+ status="not_ready",
24
+ model_loaded=False,
25
+ message="Model not loaded"
26
+ )
27
+ return HealthResponse(
28
+ status="ready",
29
+ model_loaded=True,
30
+ message="Service ready"
31
+ )
app/routers/predict.py CHANGED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException
2
+ from app.schemas.request import PredictionRequest, BatchPredictionRequest
3
+ from app.schemas.response import PredictionResponse, BatchPredictionResponse
4
+ from app.utils.model_loader import model_loader
5
+ import pandas as pd
6
+
7
+ router = APIRouter(prefix="/predict", tags=["prediction"])
8
+
9
+
10
+ def convert_to_original_columns(data_dict):
11
+ mapping = {
12
+ "Chest_pain_type": "Chest pain type",
13
+ "FBS_over_120": "FBS over 120",
14
+ "EKG_results": "EKG results",
15
+ "Max_HR": "Max HR",
16
+ "Exercise_angina": "Exercise angina",
17
+ "ST_depression": "ST depression",
18
+ "Slope_of_ST": "Slope of ST",
19
+ "Number_of_vessels_fluro": "Number of vessels fluro"
20
+ }
21
+ return {mapping.get(k, k): v for k, v in data_dict.items()}
22
+
23
+
24
+ def add_interaction_features(df):
25
+ df['id_x_Age'] = df['id'] * df['Age']
26
+ return df
27
+
28
+
29
+ @router.post("/", response_model=PredictionResponse)
30
+ async def predict_single(request: PredictionRequest):
31
+ try:
32
+ pipeline = model_loader.get_pipeline()
33
+ input_dict = convert_to_original_columns(request.model_dump())
34
+ df = pd.DataFrame([input_dict])
35
+ df = add_interaction_features(df)
36
+ result = pipeline.predict(df)
37
+
38
+ return PredictionResponse(
39
+ prediction=result["predictions"][0],
40
+ probability=result.get("probabilities")[0] if result.get("probabilities") else None
41
+ )
42
+ except Exception as e:
43
+ raise HTTPException(status_code=500, detail=str(e))
44
+
45
+
46
+ @router.post("/batch", response_model=BatchPredictionResponse)
47
+ async def predict_batch(request: BatchPredictionRequest):
48
+ try:
49
+ pipeline = model_loader.get_pipeline()
50
+ data_list = [convert_to_original_columns(item.model_dump()) for item in request.data]
51
+ df = pd.DataFrame(data_list)
52
+ df = add_interaction_features(df)
53
+ result = pipeline.predict(df)
54
+
55
+ return BatchPredictionResponse(
56
+ predictions=result["predictions"],
57
+ probabilities=result.get("probabilities"),
58
+ num_samples=result["num_samples"]
59
+ )
60
+ except Exception as e:
61
+ raise HTTPException(status_code=500, detail=str(e))
app/routers/train.py CHANGED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException, BackgroundTasks
2
+ from app.schemas.request import TrainingRequest
3
+ from app.schemas.response import TrainingResponse
4
+ from app.utils.model_loader import model_loader
5
+ from mlpipeline.pipeline.training_pipeline import TrainingPipeline
6
+
7
+ router = APIRouter(prefix="/train", tags=["training"])
8
+
9
+
10
+ def run_training_task():
11
+ try:
12
+ pipeline = TrainingPipeline()
13
+ artifact = pipeline.run_pipeline()
14
+ model_loader.reload_model()
15
+ return artifact
16
+ except Exception as e:
17
+ raise e
18
+
19
+
20
+ @router.post("/", response_model=TrainingResponse)
21
+ async def trigger_training(request: TrainingRequest, background_tasks: BackgroundTasks):
22
+ try:
23
+ if request.force_retrain:
24
+ background_tasks.add_task(run_training_task)
25
+ return TrainingResponse(
26
+ status="training_started",
27
+ message="Training pipeline started in background"
28
+ )
29
+ else:
30
+ artifact = run_training_task()
31
+ return TrainingResponse(
32
+ status="completed",
33
+ message="Training completed successfully",
34
+ model_path=artifact.pushed_model_path
35
+ )
36
+ except Exception as e:
37
+ raise HTTPException(status_code=500, detail=str(e))
38
+
39
+
40
+ @router.post("/reload", response_model=TrainingResponse)
41
+ async def reload_model():
42
+ try:
43
+ model_loader.reload_model()
44
+ return TrainingResponse(
45
+ status="success",
46
+ message="Model reloaded successfully"
47
+ )
48
+ except Exception as e:
49
+ raise HTTPException(status_code=500, detail=str(e))
app/schemas/__init__.py ADDED
File without changes
app/schemas/request.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import List, Optional
3
+
4
+
5
+ class PredictionRequest(BaseModel):
6
+ id: int
7
+ Age: int
8
+ Sex: int
9
+ Chest_pain_type: int = Field(alias="Chest pain type")
10
+ BP: int
11
+ Cholesterol: int
12
+ FBS_over_120: int = Field(alias="FBS over 120")
13
+ EKG_results: int = Field(alias="EKG results")
14
+ Max_HR: int = Field(alias="Max HR")
15
+ Exercise_angina: int = Field(alias="Exercise angina")
16
+ ST_depression: float = Field(alias="ST depression")
17
+ Slope_of_ST: int = Field(alias="Slope of ST")
18
+ Number_of_vessels_fluro: int = Field(alias="Number of vessels fluro")
19
+ Thallium: int
20
+
21
+ class Config:
22
+ populate_by_name = True
23
+
24
+
25
+ class BatchPredictionRequest(BaseModel):
26
+ data: List[PredictionRequest]
27
+
28
+
29
+ class TrainingRequest(BaseModel):
30
+ force_retrain: Optional[bool] = False
app/schemas/response.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import List, Optional, Dict, Any
3
+
4
+
5
+ class PredictionResponse(BaseModel):
6
+ prediction: int
7
+ probability: Optional[List[float]] = None
8
+
9
+
10
+ class BatchPredictionResponse(BaseModel):
11
+ predictions: List[int]
12
+ probabilities: Optional[List[List[float]]] = None
13
+ num_samples: int
14
+
15
+
16
+ class HealthResponse(BaseModel):
17
+ status: str
18
+ model_loaded: bool
19
+ message: str
20
+
21
+
22
+ class TrainingResponse(BaseModel):
23
+ status: str
24
+ message: str
25
+ metrics: Optional[Dict[str, Any]] = None
26
+ model_path: Optional[str] = None
app/utils/__init__.py ADDED
File without changes
app/utils/model_loader.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mlpipeline.pipeline.prediction_pipeline import PredictionPipeline
2
+ from threading import Lock
3
+
4
+
5
+ class ModelLoader:
6
+ _instance = None
7
+ _lock = Lock()
8
+
9
+ def __new__(cls):
10
+ if cls._instance is None:
11
+ with cls._lock:
12
+ if cls._instance is None:
13
+ cls._instance = super().__new__(cls)
14
+ cls._instance.pipeline = None
15
+ return cls._instance
16
+
17
+ def get_pipeline(self) -> PredictionPipeline:
18
+ if self.pipeline is None:
19
+ self.pipeline = PredictionPipeline()
20
+ self.pipeline.load_model()
21
+ return self.pipeline
22
+
23
+ def reload_model(self):
24
+ self.pipeline = PredictionPipeline()
25
+ self.pipeline.load_model()
26
+ return self.pipeline
27
+
28
+ def is_loaded(self) -> bool:
29
+ return self.pipeline is not None and self.pipeline.model is not None
30
+
31
+
32
+ model_loader = ModelLoader()
test_api.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+
4
+ BASE_URL = "http://localhost:8000"
5
+
6
+
7
+ def test_health():
8
+ response = requests.get(f"{BASE_URL}/health/")
9
+ print(f"Health Check: {response.json()}")
10
+
11
+
12
+ def test_predict():
13
+ data = {
14
+ "id": 1,
15
+ "Age": 50,
16
+ "Sex": 1,
17
+ "Chest pain type": 2,
18
+ "BP": 120,
19
+ "Cholesterol": 200,
20
+ "FBS over 120": 0,
21
+ "EKG results": 0,
22
+ "Max HR": 150,
23
+ "Exercise angina": 0,
24
+ "ST depression": 1.0,
25
+ "Slope of ST": 1,
26
+ "Number of vessels fluro": 0,
27
+ "Thallium": 3
28
+ }
29
+ response = requests.post(f"{BASE_URL}/predict/", json=data)
30
+ print(f"Prediction: {response.json()}")
31
+
32
+
33
+ def test_batch_predict():
34
+ data = {
35
+ "data": [
36
+ {
37
+ "id": 1,
38
+ "Age": 50,
39
+ "Sex": 1,
40
+ "Chest pain type": 2,
41
+ "BP": 120,
42
+ "Cholesterol": 200,
43
+ "FBS over 120": 0,
44
+ "EKG results": 0,
45
+ "Max HR": 150,
46
+ "Exercise angina": 0,
47
+ "ST depression": 1.0,
48
+ "Slope of ST": 1,
49
+ "Number of vessels fluro": 0,
50
+ "Thallium": 3
51
+ },
52
+ {
53
+ "id": 2,
54
+ "Age": 60,
55
+ "Sex": 0,
56
+ "Chest pain type": 1,
57
+ "BP": 130,
58
+ "Cholesterol": 220,
59
+ "FBS over 120": 1,
60
+ "EKG results": 1,
61
+ "Max HR": 140,
62
+ "Exercise angina": 1,
63
+ "ST depression": 2.0,
64
+ "Slope of ST": 2,
65
+ "Number of vessels fluro": 1,
66
+ "Thallium": 6
67
+ }
68
+ ]
69
+ }
70
+ response = requests.post(f"{BASE_URL}/predict/batch", json=data)
71
+ print(f"Batch Prediction: {response.json()}")
72
+
73
+
74
+ if __name__ == "__main__":
75
+ print("Testing API Endpoints...")
76
+ print("-" * 50)
77
+
78
+ try:
79
+ test_health()
80
+ print("-" * 50)
81
+ test_predict()
82
+ print("-" * 50)
83
+ test_batch_predict()
84
+ except Exception as e:
85
+ print(f"Error: {e}")
test_train_endpoints.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+ BASE_URL = "http://localhost:8000"
4
+
5
+
6
+ def test_reload():
7
+ print("Testing model reload...")
8
+ response = requests.post(f"{BASE_URL}/train/reload")
9
+ print(f"Response: {response.json()}\n")
10
+
11
+
12
+ def test_health_after_load():
13
+ print("Testing health after model load...")
14
+ response = requests.get(f"{BASE_URL}/health/")
15
+ print(f"Response: {response.json()}\n")
16
+
17
+
18
+ def test_ready():
19
+ print("Testing readiness endpoint...")
20
+ response = requests.get(f"{BASE_URL}/health/ready")
21
+ print(f"Response: {response.json()}\n")
22
+
23
+
24
+ if __name__ == "__main__":
25
+ test_reload()
26
+ test_health_after_load()
27
+ test_ready()
28
+ print("All endpoint tests completed!")