Spaces:
Running
Running
Add FastAPI application
Browse files- app/main.py +49 -0
- app/routers/health.py +31 -0
- app/routers/predict.py +61 -0
- app/routers/train.py +49 -0
- app/schemas/__init__.py +0 -0
- app/schemas/request.py +30 -0
- app/schemas/response.py +26 -0
- app/utils/__init__.py +0 -0
- app/utils/model_loader.py +32 -0
- test_api.py +85 -0
- test_train_endpoints.py +28 -0
app/main.py
CHANGED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, Request
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
from fastapi.responses import JSONResponse
|
| 4 |
+
from app.routers import health, predict, train
|
| 5 |
+
from mlpipeline.exception import MLPipelineException
|
| 6 |
+
import uvicorn
|
| 7 |
+
|
| 8 |
+
app = FastAPI(
|
| 9 |
+
title="AutoML MLOps API",
|
| 10 |
+
description="AutoML pipeline API for heart disease prediction",
|
| 11 |
+
version="1.0.0"
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
app.add_middleware(
|
| 15 |
+
CORSMiddleware,
|
| 16 |
+
allow_origins=["*"],
|
| 17 |
+
allow_credentials=True,
|
| 18 |
+
allow_methods=["*"],
|
| 19 |
+
allow_headers=["*"],
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
app.include_router(health.router)
|
| 23 |
+
app.include_router(predict.router)
|
| 24 |
+
app.include_router(train.router)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@app.exception_handler(MLPipelineException)
|
| 28 |
+
async def mlpipeline_exception_handler(request: Request, exc: MLPipelineException):
|
| 29 |
+
return JSONResponse(
|
| 30 |
+
status_code=500,
|
| 31 |
+
content={"error": str(exc)}
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@app.get("/")
|
| 36 |
+
async def root():
|
| 37 |
+
return {
|
| 38 |
+
"message": "AutoML MLOps API",
|
| 39 |
+
"version": "1.0.0",
|
| 40 |
+
"endpoints": {
|
| 41 |
+
"health": "/health",
|
| 42 |
+
"predict": "/predict",
|
| 43 |
+
"train": "/train"
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if __name__ == "__main__":
|
| 49 |
+
uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
|
app/routers/health.py
CHANGED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter
|
| 2 |
+
from app.schemas.response import HealthResponse
|
| 3 |
+
from app.utils.model_loader import model_loader
|
| 4 |
+
|
| 5 |
+
router = APIRouter(prefix="/health", tags=["health"])
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@router.get("/", response_model=HealthResponse)
|
| 9 |
+
async def health_check():
|
| 10 |
+
is_loaded = model_loader.is_loaded()
|
| 11 |
+
return HealthResponse(
|
| 12 |
+
status="healthy" if is_loaded else "degraded",
|
| 13 |
+
model_loaded=is_loaded,
|
| 14 |
+
message="Model loaded and ready" if is_loaded else "Model not loaded"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@router.get("/ready", response_model=HealthResponse)
|
| 19 |
+
async def readiness_check():
|
| 20 |
+
is_loaded = model_loader.is_loaded()
|
| 21 |
+
if not is_loaded:
|
| 22 |
+
return HealthResponse(
|
| 23 |
+
status="not_ready",
|
| 24 |
+
model_loaded=False,
|
| 25 |
+
message="Model not loaded"
|
| 26 |
+
)
|
| 27 |
+
return HealthResponse(
|
| 28 |
+
status="ready",
|
| 29 |
+
model_loaded=True,
|
| 30 |
+
message="Service ready"
|
| 31 |
+
)
|
app/routers/predict.py
CHANGED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException
|
| 2 |
+
from app.schemas.request import PredictionRequest, BatchPredictionRequest
|
| 3 |
+
from app.schemas.response import PredictionResponse, BatchPredictionResponse
|
| 4 |
+
from app.utils.model_loader import model_loader
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
router = APIRouter(prefix="/predict", tags=["prediction"])
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def convert_to_original_columns(data_dict):
|
| 11 |
+
mapping = {
|
| 12 |
+
"Chest_pain_type": "Chest pain type",
|
| 13 |
+
"FBS_over_120": "FBS over 120",
|
| 14 |
+
"EKG_results": "EKG results",
|
| 15 |
+
"Max_HR": "Max HR",
|
| 16 |
+
"Exercise_angina": "Exercise angina",
|
| 17 |
+
"ST_depression": "ST depression",
|
| 18 |
+
"Slope_of_ST": "Slope of ST",
|
| 19 |
+
"Number_of_vessels_fluro": "Number of vessels fluro"
|
| 20 |
+
}
|
| 21 |
+
return {mapping.get(k, k): v for k, v in data_dict.items()}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def add_interaction_features(df):
|
| 25 |
+
df['id_x_Age'] = df['id'] * df['Age']
|
| 26 |
+
return df
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@router.post("/", response_model=PredictionResponse)
|
| 30 |
+
async def predict_single(request: PredictionRequest):
|
| 31 |
+
try:
|
| 32 |
+
pipeline = model_loader.get_pipeline()
|
| 33 |
+
input_dict = convert_to_original_columns(request.model_dump())
|
| 34 |
+
df = pd.DataFrame([input_dict])
|
| 35 |
+
df = add_interaction_features(df)
|
| 36 |
+
result = pipeline.predict(df)
|
| 37 |
+
|
| 38 |
+
return PredictionResponse(
|
| 39 |
+
prediction=result["predictions"][0],
|
| 40 |
+
probability=result.get("probabilities")[0] if result.get("probabilities") else None
|
| 41 |
+
)
|
| 42 |
+
except Exception as e:
|
| 43 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@router.post("/batch", response_model=BatchPredictionResponse)
|
| 47 |
+
async def predict_batch(request: BatchPredictionRequest):
|
| 48 |
+
try:
|
| 49 |
+
pipeline = model_loader.get_pipeline()
|
| 50 |
+
data_list = [convert_to_original_columns(item.model_dump()) for item in request.data]
|
| 51 |
+
df = pd.DataFrame(data_list)
|
| 52 |
+
df = add_interaction_features(df)
|
| 53 |
+
result = pipeline.predict(df)
|
| 54 |
+
|
| 55 |
+
return BatchPredictionResponse(
|
| 56 |
+
predictions=result["predictions"],
|
| 57 |
+
probabilities=result.get("probabilities"),
|
| 58 |
+
num_samples=result["num_samples"]
|
| 59 |
+
)
|
| 60 |
+
except Exception as e:
|
| 61 |
+
raise HTTPException(status_code=500, detail=str(e))
|
app/routers/train.py
CHANGED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
| 2 |
+
from app.schemas.request import TrainingRequest
|
| 3 |
+
from app.schemas.response import TrainingResponse
|
| 4 |
+
from app.utils.model_loader import model_loader
|
| 5 |
+
from mlpipeline.pipeline.training_pipeline import TrainingPipeline
|
| 6 |
+
|
| 7 |
+
router = APIRouter(prefix="/train", tags=["training"])
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def run_training_task():
|
| 11 |
+
try:
|
| 12 |
+
pipeline = TrainingPipeline()
|
| 13 |
+
artifact = pipeline.run_pipeline()
|
| 14 |
+
model_loader.reload_model()
|
| 15 |
+
return artifact
|
| 16 |
+
except Exception as e:
|
| 17 |
+
raise e
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@router.post("/", response_model=TrainingResponse)
|
| 21 |
+
async def trigger_training(request: TrainingRequest, background_tasks: BackgroundTasks):
|
| 22 |
+
try:
|
| 23 |
+
if request.force_retrain:
|
| 24 |
+
background_tasks.add_task(run_training_task)
|
| 25 |
+
return TrainingResponse(
|
| 26 |
+
status="training_started",
|
| 27 |
+
message="Training pipeline started in background"
|
| 28 |
+
)
|
| 29 |
+
else:
|
| 30 |
+
artifact = run_training_task()
|
| 31 |
+
return TrainingResponse(
|
| 32 |
+
status="completed",
|
| 33 |
+
message="Training completed successfully",
|
| 34 |
+
model_path=artifact.pushed_model_path
|
| 35 |
+
)
|
| 36 |
+
except Exception as e:
|
| 37 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@router.post("/reload", response_model=TrainingResponse)
|
| 41 |
+
async def reload_model():
|
| 42 |
+
try:
|
| 43 |
+
model_loader.reload_model()
|
| 44 |
+
return TrainingResponse(
|
| 45 |
+
status="success",
|
| 46 |
+
message="Model reloaded successfully"
|
| 47 |
+
)
|
| 48 |
+
except Exception as e:
|
| 49 |
+
raise HTTPException(status_code=500, detail=str(e))
|
app/schemas/__init__.py
ADDED
|
File without changes
|
app/schemas/request.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class PredictionRequest(BaseModel):
|
| 6 |
+
id: int
|
| 7 |
+
Age: int
|
| 8 |
+
Sex: int
|
| 9 |
+
Chest_pain_type: int = Field(alias="Chest pain type")
|
| 10 |
+
BP: int
|
| 11 |
+
Cholesterol: int
|
| 12 |
+
FBS_over_120: int = Field(alias="FBS over 120")
|
| 13 |
+
EKG_results: int = Field(alias="EKG results")
|
| 14 |
+
Max_HR: int = Field(alias="Max HR")
|
| 15 |
+
Exercise_angina: int = Field(alias="Exercise angina")
|
| 16 |
+
ST_depression: float = Field(alias="ST depression")
|
| 17 |
+
Slope_of_ST: int = Field(alias="Slope of ST")
|
| 18 |
+
Number_of_vessels_fluro: int = Field(alias="Number of vessels fluro")
|
| 19 |
+
Thallium: int
|
| 20 |
+
|
| 21 |
+
class Config:
|
| 22 |
+
populate_by_name = True
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class BatchPredictionRequest(BaseModel):
|
| 26 |
+
data: List[PredictionRequest]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class TrainingRequest(BaseModel):
|
| 30 |
+
force_retrain: Optional[bool] = False
|
app/schemas/response.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import List, Optional, Dict, Any
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class PredictionResponse(BaseModel):
|
| 6 |
+
prediction: int
|
| 7 |
+
probability: Optional[List[float]] = None
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BatchPredictionResponse(BaseModel):
|
| 11 |
+
predictions: List[int]
|
| 12 |
+
probabilities: Optional[List[List[float]]] = None
|
| 13 |
+
num_samples: int
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class HealthResponse(BaseModel):
|
| 17 |
+
status: str
|
| 18 |
+
model_loaded: bool
|
| 19 |
+
message: str
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class TrainingResponse(BaseModel):
|
| 23 |
+
status: str
|
| 24 |
+
message: str
|
| 25 |
+
metrics: Optional[Dict[str, Any]] = None
|
| 26 |
+
model_path: Optional[str] = None
|
app/utils/__init__.py
ADDED
|
File without changes
|
app/utils/model_loader.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from mlpipeline.pipeline.prediction_pipeline import PredictionPipeline
|
| 2 |
+
from threading import Lock
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ModelLoader:
|
| 6 |
+
_instance = None
|
| 7 |
+
_lock = Lock()
|
| 8 |
+
|
| 9 |
+
def __new__(cls):
|
| 10 |
+
if cls._instance is None:
|
| 11 |
+
with cls._lock:
|
| 12 |
+
if cls._instance is None:
|
| 13 |
+
cls._instance = super().__new__(cls)
|
| 14 |
+
cls._instance.pipeline = None
|
| 15 |
+
return cls._instance
|
| 16 |
+
|
| 17 |
+
def get_pipeline(self) -> PredictionPipeline:
|
| 18 |
+
if self.pipeline is None:
|
| 19 |
+
self.pipeline = PredictionPipeline()
|
| 20 |
+
self.pipeline.load_model()
|
| 21 |
+
return self.pipeline
|
| 22 |
+
|
| 23 |
+
def reload_model(self):
|
| 24 |
+
self.pipeline = PredictionPipeline()
|
| 25 |
+
self.pipeline.load_model()
|
| 26 |
+
return self.pipeline
|
| 27 |
+
|
| 28 |
+
def is_loaded(self) -> bool:
|
| 29 |
+
return self.pipeline is not None and self.pipeline.model is not None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
model_loader = ModelLoader()
|
test_api.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
BASE_URL = "http://localhost:8000"
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def test_health():
|
| 8 |
+
response = requests.get(f"{BASE_URL}/health/")
|
| 9 |
+
print(f"Health Check: {response.json()}")
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def test_predict():
|
| 13 |
+
data = {
|
| 14 |
+
"id": 1,
|
| 15 |
+
"Age": 50,
|
| 16 |
+
"Sex": 1,
|
| 17 |
+
"Chest pain type": 2,
|
| 18 |
+
"BP": 120,
|
| 19 |
+
"Cholesterol": 200,
|
| 20 |
+
"FBS over 120": 0,
|
| 21 |
+
"EKG results": 0,
|
| 22 |
+
"Max HR": 150,
|
| 23 |
+
"Exercise angina": 0,
|
| 24 |
+
"ST depression": 1.0,
|
| 25 |
+
"Slope of ST": 1,
|
| 26 |
+
"Number of vessels fluro": 0,
|
| 27 |
+
"Thallium": 3
|
| 28 |
+
}
|
| 29 |
+
response = requests.post(f"{BASE_URL}/predict/", json=data)
|
| 30 |
+
print(f"Prediction: {response.json()}")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def test_batch_predict():
|
| 34 |
+
data = {
|
| 35 |
+
"data": [
|
| 36 |
+
{
|
| 37 |
+
"id": 1,
|
| 38 |
+
"Age": 50,
|
| 39 |
+
"Sex": 1,
|
| 40 |
+
"Chest pain type": 2,
|
| 41 |
+
"BP": 120,
|
| 42 |
+
"Cholesterol": 200,
|
| 43 |
+
"FBS over 120": 0,
|
| 44 |
+
"EKG results": 0,
|
| 45 |
+
"Max HR": 150,
|
| 46 |
+
"Exercise angina": 0,
|
| 47 |
+
"ST depression": 1.0,
|
| 48 |
+
"Slope of ST": 1,
|
| 49 |
+
"Number of vessels fluro": 0,
|
| 50 |
+
"Thallium": 3
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"id": 2,
|
| 54 |
+
"Age": 60,
|
| 55 |
+
"Sex": 0,
|
| 56 |
+
"Chest pain type": 1,
|
| 57 |
+
"BP": 130,
|
| 58 |
+
"Cholesterol": 220,
|
| 59 |
+
"FBS over 120": 1,
|
| 60 |
+
"EKG results": 1,
|
| 61 |
+
"Max HR": 140,
|
| 62 |
+
"Exercise angina": 1,
|
| 63 |
+
"ST depression": 2.0,
|
| 64 |
+
"Slope of ST": 2,
|
| 65 |
+
"Number of vessels fluro": 1,
|
| 66 |
+
"Thallium": 6
|
| 67 |
+
}
|
| 68 |
+
]
|
| 69 |
+
}
|
| 70 |
+
response = requests.post(f"{BASE_URL}/predict/batch", json=data)
|
| 71 |
+
print(f"Batch Prediction: {response.json()}")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
if __name__ == "__main__":
|
| 75 |
+
print("Testing API Endpoints...")
|
| 76 |
+
print("-" * 50)
|
| 77 |
+
|
| 78 |
+
try:
|
| 79 |
+
test_health()
|
| 80 |
+
print("-" * 50)
|
| 81 |
+
test_predict()
|
| 82 |
+
print("-" * 50)
|
| 83 |
+
test_batch_predict()
|
| 84 |
+
except Exception as e:
|
| 85 |
+
print(f"Error: {e}")
|
test_train_endpoints.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
|
| 3 |
+
BASE_URL = "http://localhost:8000"
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_reload():
|
| 7 |
+
print("Testing model reload...")
|
| 8 |
+
response = requests.post(f"{BASE_URL}/train/reload")
|
| 9 |
+
print(f"Response: {response.json()}\n")
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def test_health_after_load():
|
| 13 |
+
print("Testing health after model load...")
|
| 14 |
+
response = requests.get(f"{BASE_URL}/health/")
|
| 15 |
+
print(f"Response: {response.json()}\n")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def test_ready():
|
| 19 |
+
print("Testing readiness endpoint...")
|
| 20 |
+
response = requests.get(f"{BASE_URL}/health/ready")
|
| 21 |
+
print(f"Response: {response.json()}\n")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if __name__ == "__main__":
|
| 25 |
+
test_reload()
|
| 26 |
+
test_health_after_load()
|
| 27 |
+
test_ready()
|
| 28 |
+
print("All endpoint tests completed!")
|