Spaces:
Runtime error
Runtime error
all files
Browse files- encoder.pkl +3 -0
- main.py +91 -0
- rfc_pipeline.pkl +3 -0
encoder.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:98baac762b1becb7d5b699f21576a27ebe7555a83826444f23d540dcfc7d01d1
|
| 3 |
+
size 270
|
main.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI,HTTPException
|
| 2 |
+
from typing import Literal,List
|
| 3 |
+
import uvicorn
|
| 4 |
+
from pydantic import BaseModel
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import os
|
| 7 |
+
import pickle
|
| 8 |
+
|
| 9 |
+
# setup
|
| 10 |
+
SRC = os.path.abspath('./SRC')
|
| 11 |
+
|
| 12 |
+
# Load the pipeline using pickle
|
| 13 |
+
pipeline_path = os.path.join(SRC, 'rfc_pipeline.pkl')
|
| 14 |
+
with open(pipeline_path, 'rb') as file:
|
| 15 |
+
rfc_pipeline = pickle.load(file)
|
| 16 |
+
|
| 17 |
+
# Load the encoder using pickle
|
| 18 |
+
encoder_path = os.path.join(SRC, 'encoder.pkl')
|
| 19 |
+
with open(encoder_path, 'rb') as file:
|
| 20 |
+
encoder = pickle.load(file)
|
| 21 |
+
|
| 22 |
+
app = FastAPI(
|
| 23 |
+
title= 'Income Classification FastAPI',
|
| 24 |
+
description='A FastAPI service to classify individuals based on income level using a trained machine learning model.',
|
| 25 |
+
version= '1.0.0'
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
class IncomePredictionInput(BaseModel):
|
| 29 |
+
age: int
|
| 30 |
+
gender: str
|
| 31 |
+
education: str
|
| 32 |
+
worker_class: str
|
| 33 |
+
marital_status: str
|
| 34 |
+
race: str
|
| 35 |
+
is_hispanic: str
|
| 36 |
+
employment_commitment: str
|
| 37 |
+
employment_stat: int
|
| 38 |
+
wage_per_hour: int
|
| 39 |
+
working_week_per_year: int
|
| 40 |
+
industry_code: int
|
| 41 |
+
industry_code_main: str
|
| 42 |
+
occupation_code: int
|
| 43 |
+
occupation_code_main: str
|
| 44 |
+
total_employed: int
|
| 45 |
+
household_summary: str
|
| 46 |
+
vet_benefit: int
|
| 47 |
+
tax_status: str
|
| 48 |
+
gains: int
|
| 49 |
+
losses: int
|
| 50 |
+
stocks_status: int
|
| 51 |
+
citizenship: str
|
| 52 |
+
importance_of_record: float
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class IncomePredictionOutput(BaseModel):
|
| 56 |
+
income_prediction: str
|
| 57 |
+
prediction_probability: float
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# get
|
| 61 |
+
@app.get('/')
|
| 62 |
+
def home():
|
| 63 |
+
return {
|
| 64 |
+
'message': 'Income Classification FastAPI',
|
| 65 |
+
'description': 'FastAPI service to classify individuals based on income level.',
|
| 66 |
+
'instruction': 'Click here (/docs) to access API documentation and test endpoints.'
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# post
|
| 71 |
+
@app.post('/classify', response_model=IncomePredictionOutput)
|
| 72 |
+
def income_classification(income: IncomePredictionInput):
|
| 73 |
+
try:
|
| 74 |
+
df = pd.DataFrame([income.model_dump()])
|
| 75 |
+
|
| 76 |
+
# Make predictions
|
| 77 |
+
prediction = rfc_pipeline.predict(df)
|
| 78 |
+
output = rfc_pipeline.predict_proba(df)
|
| 79 |
+
|
| 80 |
+
prediction_result = "Income over $50K" if prediction[0] == 1 else "Income under $50K"
|
| 81 |
+
return {"income_prediction": prediction_result, "prediction_probability": output[0][1]}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
except Exception as e:
|
| 85 |
+
# Return error message and details if an exception occurs
|
| 86 |
+
error_detail = str(e)
|
| 87 |
+
raise HTTPException(status_code=500, detail=f"Error during classification: {error_detail}")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
if __name__ == '__main__':
|
| 91 |
+
uvicorn.run('main:app', reload=True)
|
rfc_pipeline.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:de6be7713061b3b93f04b7e4c3239216dc6f51b4b18ead0b60fce19c0cbe594e
|
| 3 |
+
size 223631035
|