Spaces:
Sleeping
Sleeping
File size: 2,426 Bytes
e941256 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import joblib
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, ValidationError
import pandas as pd
from helper import convert_to_month_name, transform_new_data
# FastAPI application instance; the /predict/ route below is registered on it.
app = FastAPI()

# Deserialize the trained artifacts once, at import time, so every request
# reuses the same objects. The paths are relative to the process's current
# working directory. `model` must expose `.predict()`; `encoder` is handed to
# `transform_new_data` when requests are encoded.
# NOTE(review): joblib.load is pickle-based and can execute arbitrary code
# while loading -- only point these paths at trusted, locally produced files.
model, encoder = (joblib.load(path) for path in ('model.pkl', 'encoder.pkl'))
class Item(BaseModel):
    """Pydantic schema for the JSON body of ``POST /predict/``.

    FastAPI validates the incoming request against these fields before the
    endpoint runs. Each field becomes one column of the single-row DataFrame
    that ``predict`` encodes and passes to the model.
    """

    MONATSZAHL: str   # one of the one-hot encoded input columns
    AUSPRAEGUNG: str  # one of the one-hot encoded input columns
    JAHR: int         # German "year"; also one-hot encoded, not treated numerically
    MONAT: str        # German "month"; mapped via convert_to_month_name before encoding
# Endpoint for inference
@app.post("/predict/")
async def predict(item: Item):
    """Return the model's prediction for a single request.

    FastAPI validates the JSON body against ``Item`` before this coroutine is
    called, so ``item`` is already well-formed here. FastAPI itself answers
    malformed bodies with a 422, so this function never sees them.

    Pipeline:
        1. Build a one-row DataFrame from the request fields.
        2. Map ``MONAT`` through ``convert_to_month_name``.
        3. Encode with ``transform_new_data`` and the loaded ``encoder``.
        4. Run ``model.predict`` on the encoded frame.

    Returns:
        ``{"prediction": [...]}``, which is the result of ``model.predict``
        converted with ``.tolist()``.

    Raises:
        HTTPException:
            400 if the ``MONAT`` conversion fails, or on an otherwise
            unhandled ``KeyError``.
            500 if encoding, prediction, or any other step fails.
    """
    # NOTE(review): everything below is synchronous, CPU-bound work (pandas
    # plus model inference) running directly inside an ``async def``, so it
    # blocks the event loop while it runs. Declaring this endpoint with plain
    # ``def`` would let FastAPI run it in a worker thread instead; consider
    # that if concurrent requests matter.
    try:
        # One-row DataFrame built from the validated request fields.
        input_df = pd.DataFrame([{
            "MONATSZAHL": item.MONATSZAHL,
            "AUSPRAEGUNG": item.AUSPRAEGUNG,
            "JAHR": item.JAHR,
            "MONAT": item.MONAT,
        }])

        # Convert 'MONAT' to a month name. A failure here is treated as bad
        # client input (400) rather than a server fault.
        try:
            input_df['MONAT'] = input_df['MONAT'].apply(convert_to_month_name)
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Error converting 'MONAT' to month name: {e}",
            ) from e

        # Encode the categorical columns with the encoder loaded at startup.
        try:
            transformed_df = transform_new_data(
                input_df,
                encoder,
                original_one_hot_columns=[
                    'MONATSZAHL', 'AUSPRAEGUNG', 'JAHR', 'MONAT',
                ],
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error transforming data: {e}",
            ) from e

        # Run inference on the encoded input.
        try:
            prediction = model.predict(transformed_df)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error during model prediction: {e}",
            ) from e

        # ``.tolist()`` makes the result JSON-serialisable. This presumably
        # assumes ``prediction`` is a NumPy array, which depends on the model.
        return {"prediction": prediction.tolist()}
    except HTTPException:
        # Propagate the specific errors raised above unchanged. HTTPException
        # is an Exception subclass, so without this clause the generic handler
        # below would swallow them and turn every 400 into a 500.
        raise
    except KeyError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Missing expected column: {e}",
        ) from e
    except Exception as e:
        # Last-resort boundary: any unexpected failure becomes a 500.
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {e}",
        ) from e
|