# Source: bank-model-api / app.py (Hugging Face repository by aephidayatuloh)
# Last commit: 2f86ae7 "Update app.py"
# app.py
import joblib
import pandas as pd
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException
from huggingface_hub import hf_hub_download
# --- HUGGING FACE HUB CONFIGURATION ---
# Repository and file name of the joblib-serialized model pipeline.
HF_REPO_ID = "aephidayatuloh/bank-model"
HF_MODEL_FILENAME = "random_forest_bank_marketing_pipeline.joblib"

# Loaded model pipeline; populated by the startup handler. Initialised to None
# so that /predict answers with its explicit "model failed to load" error
# (instead of an unhandled NameError) if the startup handler never ran.
MODEL_PIPELINE = None
# --- FASTAPI APPLICATION ---
# This only creates the app object; the model itself is loaded once by the
# "startup" event handler registered on this app.
_APP_TITLE = "Bank Deposit Prediction (Docker)"
app = FastAPI(title=_APP_TITLE)
@app.on_event("startup")
def load_model():
    """Download the model pipeline from the Hugging Face Hub and load it.

    Runs once when the application starts. On success, the module-level
    ``MODEL_PIPELINE`` holds the object returned by ``joblib.load``. On any
    failure, the error message is printed and ``MODEL_PIPELINE`` is set to
    ``None``; ``/predict`` checks for ``None`` and reports the failure instead
    of the whole app failing to start.
    """
    global MODEL_PIPELINE
    try:
        # hf_hub_download returns the local filesystem path of the downloaded file.
        downloaded_model_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=HF_MODEL_FILENAME,
        )
        # NOTE(review): joblib.load unpickles arbitrary Python objects; only
        # load artifacts from a repository you trust.
        MODEL_PIPELINE = joblib.load(downloaded_model_path)
    except Exception as e:
        # Deliberate best-effort: keep the API up and let /predict return an
        # explicit error rather than crashing during startup.
        print(f"❌ Gagal memuat model: {e}")
        MODEL_PIPELINE = None
    else:
        print("✅ Model berhasil dimuat dari Hugging Face Hub.")
# --- HEALTH CHECK ENDPOINT ---
@app.get("/")
def home():
    """Report that the API process is up and serving requests.

    This does not check whether the model loaded successfully.
    """
    status_payload = {
        "status": "ok",
        "message": "FastAPI is running inside Docker on HF Spaces.",
    }
    return status_payload
# --- 1. Feature schema (raw, untransformed data) ---
# Defines the structure of the object nested under the "features" key of the
# request body (see PredictionPayload).
class FeaturesSchema(BaseModel):
    """Pydantic schema for the raw feature values of a single record.

    Pydantic validates each field against its annotated type. The prediction
    endpoint turns an instance into a one-row DataFrame via its dict form, so
    the DataFrame's column names and order come from these declarations.
    NOTE(review): the names presumably must match the columns the saved
    pipeline was trained on (a bank-marketing dataset, judging by the model
    file name). Confirm against the training code before renaming anything.
    """

    age: int
    job: str
    marital: str
    education: str
    default: str
    balance: int
    housing: str
    loan: str
    contact: str
    day: int
    month: str
    # duration: int
    # NOTE(review): `duration` is deliberately left out of the schema; the
    # reason is not visible here. Confirm against the training pipeline.
    campaign: int
    pdays: int
    previous: int
    poutcome: str
    # All 15 model features must be declared here, in this order.
# --- 2. Payload schema (wrapper) ---
# Defines the overall request body, which carries the feature values under
# the "features" key.
class PredictionPayload(BaseModel):
    """Pydantic schema for the request body accepted by ``/predict``.

    Expected JSON shape: ``{"features": {<FeaturesSchema fields>}}``.
    """
    features: FeaturesSchema  # nested raw feature values, validated by FeaturesSchema
# --- 3. PREDICTION ENDPOINT ---
@app.post("/predict")
def predict(payload_data: PredictionPayload):
    """Run the loaded pipeline on one record's features.

    Args:
        payload_data: Request body of the form ``{"features": {...}}``.

    Returns:
        A dict with ``prediction_class`` (the pipeline's predicted label cast
        to ``int``) and ``probability`` (the ``predict_proba`` row for the
        single input, converted to a plain list).

    Raises:
        HTTPException: status 500 if the model failed to load at startup
            (``MODEL_PIPELINE is None``), or if building the input frame,
            running the pipeline, or converting the result raised.
    """
    if MODEL_PIPELINE is None:
        raise HTTPException(status_code=500, detail="Model gagal dimuat.")

    features = payload_data.features
    try:
        # Pydantic v2 renamed .dict() to .model_dump() and deprecated the old
        # name; prefer the new API and fall back for Pydantic v1.
        to_dict = getattr(features, "model_dump", None) or features.dict
        input_df = pd.DataFrame([to_dict()])

        prediction = MODEL_PIPELINE.predict(input_df)[0]
        prediction_proba = MODEL_PIPELINE.predict_proba(input_df)[0].tolist()
        # int() stays inside the try, as in the original, so a non-numeric
        # label still surfaces as a "Prediction error" response.
        prediction_class = int(prediction)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {e}") from e

    return {
        "prediction_class": prediction_class,
        "probability": prediction_proba,
    }