File size: 2,958 Bytes
fef0885
78a46ba
fef0885
 
 
 
 
 
78a46ba
fef0885
 
21b385b
fef0885
 
8d2abde
fef0885
8d2abde
fef0885
 
8d2abde
 
 
fef0885
 
 
 
 
 
 
78a46ba
 
 
 
 
 
 
 
fef0885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78a46ba
fef0885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d2abde
 
 
 
 
fef0885
8d2abde
 
 
fef0885
8d2abde
 
fef0885
 
 
 
 
 
78a46ba
8d2abde
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from fastapi import FastAPI,HTTPException
from fastapi.middleware.cors import CORSMiddleware
from typing import Literal,List
import uvicorn
from pydantic import BaseModel
import pandas as pd
import os
import pickle
from transformers import log_transform

# Setup: artifact paths are resolved against the current working directory.
SRC = os.path.abspath('.')


def _load_pickle(path):
    # Open the file at `path` in binary mode and return the unpickled object.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Preprocessing pipeline (its transform() is applied to request data).
pipeline_path = os.path.join(SRC, 'pipeline.pkl')
pipeline = _load_pickle(pipeline_path)

# Trained model (its predict() / predict_proba() produce the response).
model_path = os.path.join(SRC, 'rfc_model.pkl')
model = _load_pickle(model_path)

# Application object; title, description and version are passed straight to
# the FastAPI constructor.
app = FastAPI(
    title='Income Classification FastAPI',
    description=(
        'A FastAPI service to classify individuals based on income level '
        'using a trained machine learning model.'
    ),
    version='1.0.0',
)

# CORS: every origin, method and header is allowed, with credentials enabled.
# NOTE(review): FastAPI's CORS documentation states that wildcard values should
# not be combined with allow_credentials=True. The policy is left as-is here;
# confirm whether credentialed cross-origin requests are really needed and, if
# so, list the allowed origins explicitly.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

class IncomePredictionInput(BaseModel):
    # Request body for the classification endpoint. The endpoint converts an
    # instance into a single-row DataFrame, so the field names below become
    # the column names passed to the preprocessing pipeline.
    age: int
    gender: str
    education: str
    worker_class: str
    marital_status: str
    race: str
    is_hispanic: str
    employment_commitment: str
    employment_stat: int
    wage_per_hour: int
    working_week_per_year: int
    industry_code: int
    industry_code_main: str
    occupation_code: int
    occupation_code_main: str
    total_employed: int
    household_summary: str
    vet_benefit: int
    tax_status: str
    gains: int
    losses: int
    stocks_status: int
    citizenship: str


   
class IncomePredictionOutput(BaseModel):
    # Response body of the /classify endpoint (used as its response_model).
    income_prediction: str  # "Above Limit" or "Below Limit", as set by the endpoint
    prediction_probability: float  # largest value in the model's predict_proba row


# GET / : landing endpoint returning a static description of the service.
# (Documented with comments rather than a docstring, because FastAPI publishes
# an endpoint's docstring as its description in the generated OpenAPI schema.)
@app.get('/')
def home():
    landing_info = dict(
        message='Income Classification FastAPI',
        description='FastAPI service to classify individuals based on income level.',
        instruction='Click here (/docs) to access API documentation and test endpoints.',
    )
    return landing_info
   

# POST /classify : run the preprocessing pipeline and the model on one record.
# (Documented with comments rather than a docstring, because FastAPI publishes
# an endpoint's docstring as its description in the generated OpenAPI schema.)
#
# Parameters:
#   income: validated request body; its fields become the columns of a
#           single-row DataFrame.
# Returns:
#   dict with "income_prediction" ("Above Limit" when the predicted label
#   equals 1, otherwise "Below Limit") and "prediction_probability" (the
#   largest value in the model's predict_proba row for this record).
# Raises:
#   HTTPException (status 500) when any step of preprocessing or prediction
#   raises; the original exception is attached as the cause.
@app.post('/classify', response_model=IncomePredictionOutput)
def income_classification(income: IncomePredictionInput):
    try:
        # Convert input data to a one-row DataFrame
        input_df = pd.DataFrame([dict(income)])

        # Preprocess the input data through the pipeline
        input_df_transformed = pipeline.transform(input_df)

        # Make predictions
        prediction = model.predict(input_df_transformed)
        probability = model.predict_proba(input_df_transformed).max(axis=1)[0]

        prediction_result = "Above Limit" if prediction[0] == 1 else "Below Limit"
        return {
            "income_prediction": prediction_result,
            # Cast to a built-in float so the response does not depend on the
            # numeric scalar type the model happens to return.
            "prediction_probability": float(probability),
        }

    except Exception as e:
        # NOTE(review): the exception text is echoed to the client in `detail`;
        # confirm that exposing internal error messages is acceptable.
        error_detail = str(e)
        # `from e` keeps the original exception as the explicit cause, so the
        # real failure stays visible in tracebacks and logs.
        raise HTTPException(
            status_code=500,
            detail=f"Error during classification: {error_detail}",
        ) from e