Spaces:
Sleeping
Sleeping
| import os | |
| import uvicorn | |
| import pandas as pd | |
| from pydantic import BaseModel | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from dotenv import load_dotenv | |
| import joblib | |
| import io | |
# Markdown overview of the API; this string is passed to FastAPI as the
# application description.
description = """
# API for rental pricing optimization
The goal of this API is to serve data to optimize rental pricing.
## Machine-Learning
Where you can:
* `/predict` rental price per day of a car
* `/predict_batch` rental price per day for a batch of cars
Check out documentation for more information on each endpoint.
"""

# Tag metadata handed to FastAPI through its `openapi_tags` argument.
tags_metadata = [
    {
        "name": "Predictions",
        "description": "Endpoints that uses our Machine Learning model for rental pricing optimization",
    },
]
# Load variables from a local .env file (if present) into the environment
# before MODEL_PATH is looked up below.
load_dotenv()

# Path of the serialized model; can be overridden with the MODEL_PATH
# environment variable.
MODEL = os.getenv("MODEL_PATH", "models/cars_sharing_lr_model.joblib")

# Loaded once at import time and reused by the prediction functions below.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- MODEL_PATH should only ever point at a trusted artifact.
model = joblib.load(MODEL)
# FastAPI application object, configured with the title, description,
# version, contact details and tag metadata given below.
app = FastAPI(
    title="🚗 Car rental pricing optimization",
    description=description,
    version="1.0",
    contact={
        "name": "Olivier",
        "url": "https://huggingface.co/Olivier-52",
    },
    openapi_tags=tags_metadata,)
# NOTE(review): no route decorator is visible on this function in this view
# of the file -- confirm it is actually registered on the FastAPI app.
def index():
    """Return a fixed greeting that points the caller at the docs page.

    Takes no parameters. Intended as a quick check that the API is up.

    Returns:
        str: The greeting message.
    """
    greeting = "Hello world! Go to /docs to try the API."
    return greeting
class PredictionFeatures(BaseModel):
    """Request body describing a single car for a price prediction.

    Each field name is used verbatim as a DataFrame column name when the
    features are handed to the model by ``predict``.
    """

    # Categorical / numeric description of the car.
    model_key: str
    mileage: int
    engine_power: int
    fuel: str
    paint_color: str
    car_type: str
    # Boolean equipment / option flags.
    private_parking_available: bool
    has_gps: bool
    has_air_conditioning: bool
    automatic_car: bool
    has_getaround_connect: bool
    has_speed_regulator: bool
    winter_tires: bool
# NOTE(review): no route decorator is visible on this function in this view
# of the file -- confirm it is actually registered on the FastAPI app.
def predict(features: PredictionFeatures):
    """Estimate the daily rental price for one car.

    The incoming features are laid out as a single-row DataFrame (one column
    per feature, in the order listed below) and passed to the loaded model.

    Args:
        features (PredictionFeatures): Description of the car to price.

    Returns:
        dict: ``{"prediction": <float>}`` holding the model output for the car.
    """
    # Column order handed to the model; kept explicit so it does not depend
    # on how the request model happens to be declared.
    column_order = (
        "model_key",
        "mileage",
        "engine_power",
        "fuel",
        "paint_color",
        "car_type",
        "private_parking_available",
        "has_gps",
        "has_air_conditioning",
        "automatic_car",
        "has_getaround_connect",
        "has_speed_regulator",
        "winter_tires",
    )
    single_row = {
        column: [getattr(features, column)] for column in column_order
    }
    frame = pd.DataFrame(single_row)

    # The model returns one value per row; there is exactly one row here.
    predicted_price = model.predict(frame)[0]
    return {"prediction": float(predicted_price)}
# NOTE(review): no route decorator is visible on this function in this view
# of the file -- confirm it is actually registered on the FastAPI app.
async def predict_batch(file: UploadFile = File(...)):
    """Predict the rental price per day for a batch of cars.

    The uploaded CSV is read into a DataFrame, checked for the required
    feature columns, scored with the loaded model, and returned as a CSV
    with an extra ``rental_price_per_day`` column.

    Args:
        file (UploadFile): A CSV file containing the features of the cars
            to predict.

    Returns:
        StreamingResponse: A CSV attachment (``predicted_prices.csv``) with
        the input rows plus the predicted rental price per day.

    Raises:
        HTTPException: 400 if required columns are missing, or if reading
            the file or running the prediction fails.
    """
    # NOTE(review): "model_key" is passed to the model by `predict` but is
    # not listed here -- confirm whether batch input should require it too.
    required_columns = {
        "mileage", "engine_power", "fuel", "paint_color", "car_type",
        "private_parking_available", "has_gps", "has_air_conditioning",
        "automatic_car", "has_getaround_connect", "has_speed_regulator", "winter_tires"
    }
    try:
        df = pd.read_csv(file.file)

        missing_columns = required_columns - set(df.columns)
        if missing_columns:
            # Sorted so the message is stable from one run to the next.
            raise HTTPException(
                status_code=400,
                detail=f"Missing required columns in the CSV file: {sorted(missing_columns)}"
            )

        df["rental_price_per_day"] = model.predict(df)

        # With no target buffer, to_csv returns the CSV text directly.
        csv_text = df.to_csv(index=False)
        return StreamingResponse(
            iter([csv_text]),
            media_type="text/csv",
            headers={"Content-Disposition": "attachment; filename=predicted_prices.csv"}
        )
    except HTTPException:
        # Already a well-formed API error (e.g. missing columns); do not let
        # the generic handler below re-wrap it and mangle its detail.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=400, detail=f"Error processing file: {str(e)}"
        ) from e
# When this file is executed directly, serve the app with uvicorn on
# localhost:8000.
if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=8000)