File size: 3,089 Bytes
1aa566a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
Prediction endpoint.

POST /predict           →  returns duration prediction + logs to monitor
POST /predict/feedback  →  submit delayed ground truth
"""
from __future__ import annotations

import json
import time
import uuid
from typing import TYPE_CHECKING

import pandas as pd
from fastapi import APIRouter, HTTPException, Request

from src.api.schemas import (
    PredictionRequest,
    PredictionResponse,
    FeedbackRequest,
    FeedbackResponse,
)
from src.utils.config import resolve, settings
from src.utils.logging_config import get_logger

# Router for this module; every route registered on it is served under "/predict".
router = APIRouter(prefix="/predict", tags=["Prediction"])
# Module-level logger, obtained from the project's logging helper for this module's name.
log = get_logger(__name__)

# Prediction log location, resolved from settings.api.prediction_log_path.
# predict() appends one JSON object per line to this file.
_PREDICTION_LOG = resolve(settings.api.prediction_log_path)


@router.post("", response_model=PredictionResponse)
async def predict(body: PredictionRequest, request: Request) -> PredictionResponse:
    """
    Predict taxi trip duration from input features.

    The prediction is:
    1. Logged to disk (for drift analysis)
    2. Registered in the performance monitor (awaiting ground truth)

    Args:
        body: Validated request payload; its dumped fields form the
            single-row feature frame handed to the preprocessor.
        request: Incoming request, used to reach the shared ``app.state``
            (model, preprocessor, monitor, model version, sample counter).

    Returns:
        A response carrying the generated request id, the predicted
        duration (rounded to 2 decimals, never below 1.0), the model
        version and a UTC timestamp.

    Raises:
        HTTPException: 503 when no model or preprocessor is loaded on the
            application state; 500 when the model returns a non-finite
            value (NaN or +/-inf).
    """
    import math  # function-scope import: ``math`` is not among the module-level imports

    app_state = request.app.state
    model = app_state.model
    preprocessor = app_state.preprocessor
    monitor = app_state.monitor

    # Both artefacts are needed to serve a prediction; a missing preprocessor
    # would otherwise surface as an AttributeError (HTTP 500) further down.
    if model is None or preprocessor is None:
        raise HTTPException(status_code=503, detail="No trained model available. Please train first.")

    features = body.model_dump()
    request_id = str(uuid.uuid4())
    timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    # Read once so the log entry and the response always report the same version.
    model_version = app_state.model_version

    # NOTE(review): transform/predict and the file append below run
    # synchronously inside this coroutine — consider offloading them if they
    # turn out to be slow enough to stall the event loop.
    X = preprocessor.transform(pd.DataFrame([features]))
    raw_prediction = float(model.predict(X)[0])

    # Reject NaN/inf before clamping: max(1.0, nan) evaluates to 1.0, which
    # would disguise a broken model output as a valid prediction, and inf
    # would be written to the log as the non-standard JSON token "Infinity".
    if not math.isfinite(raw_prediction):
        log.error("Non-finite prediction: request_id=%s  raw=%r", request_id, raw_prediction)
        raise HTTPException(status_code=500, detail="Model produced a non-finite prediction.")

    # Round to 2 decimals and enforce a floor of 1.0.
    prediction = max(1.0, round(raw_prediction, 2))

    # Log prediction (one JSON object per line, appended)
    log_entry = {
        "request_id": request_id,
        "timestamp": timestamp,
        "features": features,
        "prediction": prediction,
        "model_version": model_version,
    }
    with open(_PREDICTION_LOG, "a", encoding="utf-8") as fh:
        fh.write(json.dumps(log_entry) + "\n")

    # Register in performance monitor
    monitor.log_prediction(request_id, prediction, features)
    app_state.samples_since_last_retrain += 1

    log.debug("Prediction: request_id=%s  pred=%.2f min", request_id, prediction)

    return PredictionResponse(
        request_id=request_id,
        predicted_duration_min=prediction,
        model_version=model_version,
        timestamp=timestamp,
    )


@router.post("/feedback", response_model=FeedbackResponse)
async def submit_feedback(body: FeedbackRequest, request: Request) -> FeedbackResponse:
    """
    Record the delayed ground truth for an earlier prediction.

    This simulates labels that only become available hours after the
    prediction was served (e.g., once the trip has completed and the data
    has been reconciled).

    The actual duration is forwarded to the performance monitor on
    ``app.state``; whatever ``log_ground_truth`` returns is reported back as
    ``matched`` together with a human-readable message. No error is raised
    here when the request id is unknown.
    """
    perf_monitor = request.app.state.monitor
    was_matched = perf_monitor.log_ground_truth(body.request_id, body.actual_duration_min)

    # Same two messages as before, selected with a conditional expression.
    message = (
        f"Ground truth matched for request_id={body.request_id}"
        if was_matched
        else f"No pending prediction found for request_id={body.request_id}"
    )

    return FeedbackResponse(
        request_id=body.request_id,
        matched=was_matched,
        message=message,
    )