File size: 779 Bytes
2939a15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from fastapi import APIRouter
from app.models.prediction_models import PredictionRequest, PredictionResponse
from typing import List
from app.utils.data_preparation import load_model_and_data
from app.utils.validations import check_country_code, check_valid_ids

# Router holding the prediction endpoint defined below; presumably mounted
# by the application elsewhere (e.g. via include_router) — confirm in app setup.
router: APIRouter = APIRouter()

# Loaded a single time, when this module is imported, and reused by every
# call to `predict`. `df` is looked up by invoice id through `.loc`, and
# `model` must expose a `predict` method that accepts rows of `df`.
df, model = load_model_and_data()


@router.post("/", response_model=List[PredictionResponse])
def predict(request: PredictionRequest):
    """Return one model prediction per requested invoice id.

    The request is validated first (country code, then the invoice ids
    against the loaded ``df``). The matching rows are then selected from
    ``df`` by label and scored with ``model.predict``.

    Args:
        request: Parsed request body; ``request.invoiceId`` is treated as a
            sequence of invoice ids.

    Returns:
        A list of ``{"invoiceId": ..., "prediction": float}`` dicts, one per
        requested id and in request order, serialized via
        ``PredictionResponse``. An empty id list yields an empty list.

    Raises:
        Whatever ``check_country_code`` / ``check_valid_ids`` raise for
        invalid input (their behavior is defined outside this module).
    """
    check_country_code(request)
    check_valid_ids(request, df)

    invoice_ids = request.invoiceId

    # Nothing to score: answer with an empty list instead of handing an empty
    # frame to the model, which many estimators (e.g. scikit-learn) reject
    # with an error that would surface as a 500.
    if not invoice_ids:
        return []

    # Label-based selection returns rows in the order the ids were requested.
    # NOTE(review): this assumes df's index is unique; duplicate index labels
    # would return extra rows and misalign the zip below — confirm in
    # load_model_and_data.
    prediction_data = df.loc[invoice_ids]
    predictions = model.predict(prediction_data)

    return [
        {"invoiceId": invoice_id, "prediction": float(prediction)}
        for invoice_id, prediction in zip(invoice_ids, predictions)
    ]