File size: 2,129 Bytes
3d14250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from fastapi import FastAPI, HTTPException
from predict_helper import predict
from pydantic import BaseModel, Field
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI(title='Customer Segmentation', version='1.0')

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)


class BaseInput(BaseModel):
    Age                 :   int = Field(..., ge=18, le=100, description="Customer age between 18 and 100")
    Income              :   int = Field(..., ge=0, le=200000, description="Income between 0 and 200000")
    Total_Spendings     :   int = Field(..., ge=0, le=5000, description="Total spendings (sum of purchases)")
    NumWebPurchases     :   int = Field(..., ge=0, le=100, description="Number of web purchases")
    NumStorePurchases   :   int = Field(..., ge=0, le=100, description="Number of store purchases")
    NumWebVisitsMonth   :   int = Field(..., ge=0, le=50, description="Number of web visits per month")
    Recency             :   int = Field(..., ge=0, le=365, description="Recency (days since last purchase)")


class BaseOutput(BaseModel):
    cluster_id      : int
    cluster_name    : str
    description     : str
    recommendation  : str


@app.get('/')
def Status():
    return {'message' : 'The api server is live and working'}


@app.post('/predict', response_model=BaseOutput)
def predict_segment(input_data: BaseInput):
    try:
        result = predict(
            age=input_data.Age,
            income=input_data.Income,
            total_spending=input_data.Total_Spendings,
            num_web_purchases=input_data.NumWebPurchases,
            num_store_purchases=input_data.NumStorePurchases,
            num_web_visits=input_data.NumWebVisitsMonth,
            recency=input_data.Recency
        )
        return result
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error while predicting the output: {e}")