File size: 3,860 Bytes
5c61354
 
 
 
b53ee19
 
 
5c61354
b53ee19
5c61354
 
 
b53ee19
 
 
5c61354
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b53ee19
5c61354
 
 
 
 
 
 
b53ee19
8e5f262
 
 
 
b53ee19
 
 
 
 
8e5f262
b53ee19
 
 
 
 
 
5c61354
b53ee19
8e5f262
5c61354
 
b53ee19
5c61354
 
 
 
 
b53ee19
5c61354
 
 
 
 
 
 
b53ee19
 
 
 
 
 
 
 
 
 
 
 
5c61354
 
 
 
 
 
b53ee19
5c61354
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from fastapi import APIRouter, HTTPException
from app.schemas.request import PredictionRequest, BatchPredictionRequest
from app.schemas.response import PredictionResponse, BatchPredictionResponse
from app.utils.model_loader import model_loader
from app.utils.metrics import prediction_counter, prediction_duration
from monitoring.model_monitoring.prediction_logger import PredictionLogger
from pathlib import Path
import pandas as pd
import time

# All prediction endpoints below are mounted under the /predict prefix and
# grouped under the "prediction" tag.
router = APIRouter(tags=["prediction"], prefix="/predict")

# Shared prediction logger used by both endpoints. The directory is a relative
# path, so it is presumably resolved against the process working directory —
# confirm against PredictionLogger.
_PREDICTIONS_LOG_DIR = Path("monitoring/predictions")
prediction_logger = PredictionLogger(_PREDICTIONS_LOG_DIR)


# Request-schema field names that differ from their original column names
# only in having underscores where the original names have spaces.
_UNDERSCORED_FIELDS = (
    "Chest_pain_type",
    "FBS_over_120",
    "EKG_results",
    "Max_HR",
    "Exercise_angina",
    "ST_depression",
    "Slope_of_ST",
    "Number_of_vessels_fluro",
)
# e.g. "Max_HR" -> "Max HR"; built once at import rather than on every call.
_ORIGINAL_COLUMN_NAMES = {field: field.replace("_", " ") for field in _UNDERSCORED_FIELDS}


def convert_to_original_columns(data_dict):
    """Return a copy of *data_dict* with known fields renamed to their original column names.

    Only the fields listed in ``_UNDERSCORED_FIELDS`` are renamed (underscores
    become spaces); every other key, such as ``Age`` or ``id``, is kept as-is.
    Values and key order are preserved; the input dict is not modified.
    """
    return {_ORIGINAL_COLUMN_NAMES.get(field, field): value for field, value in data_dict.items()}


def add_interaction_features(df):
    """Add the ``id_x_Age`` column (``id`` times ``Age``) to *df* in place.

    *df* itself is modified and is also returned, so callers may use either
    the argument or the return value.
    """
    # NOTE(review): this multiplies the ``id`` column by ``Age``; presumably it
    # mirrors a feature the model was trained with — confirm it is intended.
    ids = df["id"]
    ages = df["Age"]
    df["id_x_Age"] = ids * ages
    return df


@router.post("/", response_model=PredictionResponse)
async def predict_single(request: PredictionRequest):
    """Predict the outcome for a single record.

    \f
    The request fields are renamed with ``convert_to_original_columns``, the
    ``id_x_Age`` feature is added, and the pipeline is run on a one-row
    DataFrame. The prediction (with the last-class probability as metadata) is
    logged via ``prediction_logger``, and the success counter and duration
    metric are updated.

    Args:
        request: The validated single-record payload.

    Returns:
        PredictionResponse with the prediction and the full probability row
        (``None`` when the pipeline returns no probabilities).

    Raises:
        HTTPException: status 500 for any failure; the detail is the error text.
    """
    # perf_counter is monotonic, so latency deltas are unaffected by wall-clock
    # adjustments that can skew time.time() differences.
    start_time = time.perf_counter()
    try:
        pipeline = model_loader.get_pipeline()
        input_dict = convert_to_original_columns(request.model_dump())
        df = add_interaction_features(pd.DataFrame([input_dict]))
        # NOTE(review): this synchronous call runs inside an async handler; if
        # predict is CPU-heavy it blocks the event loop — consider a plain `def`
        # endpoint or a threadpool.
        result = pipeline.predict(df)

        prediction = result["predictions"][0]
        # Single lookup, and a None/length test instead of truthiness, so an
        # array-like result does not raise "truth value ... is ambiguous".
        probabilities = result.get("probabilities")
        has_proba = probabilities is not None and len(probabilities) > 0
        proba_row = probabilities[0] if has_proba else None
        # proba_row is [p_class0, p_class1, ...]; log the positive-class
        # (last) probability as a scalar.
        proba_scalar = float(proba_row[-1]) if proba_row is not None else None

        # Log prediction
        prediction_logger.log_prediction(
            input_data=input_dict,
            prediction=int(prediction),
            model_version="v1",
            metadata={"probability": proba_scalar},
        )

        # Update metrics
        prediction_counter.labels(model_version="v1", status="success").inc()
        prediction_duration.observe(time.perf_counter() - start_time)

        return PredictionResponse(
            prediction=prediction,
            probability=proba_row,
        )
    except Exception as e:
        prediction_counter.labels(model_version="v1", status="error").inc()
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/batch", response_model=BatchPredictionResponse)
async def predict_batch(request: BatchPredictionRequest):
    """Predict outcomes for a batch of records.

    \f
    Every record is renamed with ``convert_to_original_columns``, the
    ``id_x_Age`` feature is added, and the pipeline is run once on the whole
    DataFrame. Each prediction is logged via ``prediction_logger`` with the
    same ``{"probability": ...}`` metadata shape as ``predict_single``. The
    success counter is incremented by the number of predictions.

    Args:
        request: The validated batch payload; ``request.data`` holds the records.

    Returns:
        BatchPredictionResponse with the pipeline's predictions, probabilities
        (possibly ``None``) and ``num_samples``.

    Raises:
        HTTPException: status 500 for any failure, including a pipeline that
            returns a different number of predictions than inputs.
    """
    # Monotonic clock for the latency metric (see predict_single).
    start_time = time.perf_counter()
    try:
        pipeline = model_loader.get_pipeline()
        data_list = [convert_to_original_columns(item.model_dump()) for item in request.data]
        df = add_interaction_features(pd.DataFrame(data_list))
        # NOTE(review): synchronous call inside an async handler; may block the
        # event loop for large batches.
        result = pipeline.predict(df)

        predictions = result["predictions"]
        probabilities = result.get("probabilities")
        # zip() silently truncates on a length mismatch, which would log
        # predictions against the wrong inputs; fail loudly instead.
        if len(predictions) != len(data_list):
            raise ValueError(
                f"pipeline returned {len(predictions)} predictions "
                f"for {len(data_list)} inputs"
            )
        # Attach per-row probabilities only when there is one row per input.
        if probabilities is not None and len(probabilities) == len(predictions):
            proba_rows = probabilities
        else:
            proba_rows = [None] * len(predictions)

        # Log batch predictions
        for input_data, prediction, proba_row in zip(data_list, predictions, proba_rows):
            proba_scalar = float(proba_row[-1]) if proba_row is not None else None
            prediction_logger.log_prediction(
                input_data=input_data,
                prediction=int(prediction),
                model_version="v1",
                metadata={"probability": proba_scalar},
            )

        # Update metrics (successes are counted per prediction)
        prediction_counter.labels(model_version="v1", status="success").inc(len(predictions))
        prediction_duration.observe(time.perf_counter() - start_time)

        return BatchPredictionResponse(
            predictions=predictions,
            probabilities=probabilities,
            num_samples=result["num_samples"],
        )
    except Exception as e:
        # NOTE(review): errors are counted once per request while successes are
        # counted per prediction; confirm which unit this counter should use.
        prediction_counter.labels(model_version="v1", status="error").inc()
        raise HTTPException(status_code=500, detail=str(e)) from e