# Olivier-52
# Update app.py
# 91defc5
import os
import uvicorn
import pandas as pd
from pydantic import BaseModel
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from dotenv import load_dotenv
import joblib
import io
# Markdown text handed to the FastAPI constructor as the API description.
# Spelled as adjacent literals so every line break in the text is explicit.
description = (
    "\n"
    "# API for rental pricing optimization\n"
    "The goal of this API is to serve data to optimize rental pricing.\n"
    "## Machine-Learning\n"
    "Where you can:\n"
    "* `/predict` rental price per day of a car\n"
    "* `/predict_batch` rental price per day for a batch of cars\n"
    "Check out documentation for more information on each endpoint.\n"
)
# Tag descriptions handed to the FastAPI constructor as `openapi_tags`;
# the "Predictions" name matches the `tags=[...]` used by the route decorators.
tags_metadata = [
    dict(
        name="Predictions",
        description="Endpoints that uses our Machine Learning model for rental pricing optimization",
    ),
]
# Load variables from a .env file into the environment before MODEL_PATH is read.
load_dotenv()

# MODEL_PATH, when set, overrides the default relative path to the model file.
MODEL = os.environ.get("MODEL_PATH", "models/cars_sharing_lr_model.joblib")

# Loaded once at import time; both prediction endpoints call this object.
model = joblib.load(MODEL)
# Application instance; every route decorator in this module registers on it.
app = FastAPI(
    title="🚗 Car rental pricing optimization",
    description=description,
    version="1.0",
    contact={"name": "Olivier", "url": "https://huggingface.co/Olivier-52"},
    openapi_tags=tags_metadata,
)
@app.get("/")
def index():
    """Greet the caller and point them at the interactive documentation.

    Takes no parameters; handy as a quick check that the API responds.

    Returns:
        str: A fixed greeting message.
    """
    greeting = "Hello world! Go to /docs to try the API."
    return greeting
# Request body schema for the `/predict` endpoint: the attributes of one car.
# `predict` copies every field into a same-named column of a one-row DataFrame
# that is passed to the model.
# NOTE(review): units and accepted values are not visible in this file —
# confirm them against the data the model was trained on.
class PredictionFeatures(BaseModel):
    model_key: str  # presumably the car make/brand identifier — TODO confirm
    mileage: int  # unit not stated here (presumably km) — TODO confirm
    engine_power: int  # unit not stated here — TODO confirm
    fuel: str
    paint_color: str
    car_type: str
    # Boolean flags for the car.
    private_parking_available: bool
    has_gps: bool
    has_air_conditioning: bool
    automatic_car: bool
    has_getaround_connect: bool
    has_speed_regulator: bool
    winter_tires: bool
@app.post("/predict", tags=["Predictions"])
def predict(features: PredictionFeatures):
    """Estimate the rental price per day for a single car.

    Args:
        features (PredictionFeatures): Attributes of the car to price.

    Returns:
        dict: ``{"prediction": <float>}`` holding the model's estimate.
    """
    # Column order handed to the model; mirrors the field order of the schema.
    column_order = (
        "model_key",
        "mileage",
        "engine_power",
        "fuel",
        "paint_color",
        "car_type",
        "private_parking_available",
        "has_gps",
        "has_air_conditioning",
        "automatic_car",
        "has_getaround_connect",
        "has_speed_regulator",
        "winter_tires",
    )
    # One-row frame: each column holds a single-element list with the field value.
    single_row = pd.DataFrame(
        {name: [getattr(features, name)] for name in column_order}
    )
    estimates = model.predict(single_row)
    # Convert to a plain float so the value is JSON-serializable.
    return {"prediction": float(estimates[0])}
@app.post("/predict_batch", tags=["Predictions"])
async def predict_batch(file: UploadFile = File(...)):
    """Predict the rental price per day of a batch of cars.

    Args:
        file (UploadFile): A CSV file containing the features of the cars to
            predict, one row per car.

    Returns:
        StreamingResponse: A CSV file (``predicted_prices.csv``) containing the
            uploaded rows plus a ``rental_price_per_day`` column.

    Raises:
        HTTPException: 400 when required columns are missing, or when the file
            cannot be parsed or scored by the model.
    """
    # NOTE(review): `/predict` also sends `model_key` to the model, but it is
    # not required here — confirm whether the model needs that column.
    required_columns = {
        "mileage", "engine_power", "fuel", "paint_color", "car_type",
        "private_parking_available", "has_gps", "has_air_conditioning",
        "automatic_car", "has_getaround_connect", "has_speed_regulator", "winter_tires"
    }
    try:
        df = pd.read_csv(file.file)
        missing_columns = required_columns - set(df.columns)
        if missing_columns:
            raise HTTPException(
                status_code=400,
                detail=f"Missing required columns in the CSV file: {missing_columns}"
            )
        df["rental_price_per_day"] = model.predict(df)
        # With no target buffer, to_csv returns the CSV text directly.
        csv_text = df.to_csv(index=False)
    except HTTPException:
        # Already a well-formed client error (e.g. missing columns): pass it
        # through untouched instead of re-wrapping it in the generic message.
        raise
    except Exception as e:
        # Parsing or scoring failed; report it as a client error and keep the
        # original exception as the cause for server-side debugging.
        raise HTTPException(
            status_code=400, detail=f"Error processing file: {str(e)}"
        ) from e
    return StreamingResponse(
        iter([csv_text]),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=predicted_prices.csv"}
    )
if __name__ == "__main__":
    # Executed as a script: serve the app with uvicorn on localhost:8000.
    uvicorn.run(app, port=8000, host="localhost")