# Uploaded by sumitsinha2603 — commit ab4160e
# "Add deployment files (Dockerfile, app, requirements)"
# app.py
import os
import io
import pandas as pd
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Any
from huggingface_hub import hf_hub_download, login
import joblib
import uvicorn
from contextlib import asynccontextmanager # Added this import
# -------- Configuration --------
MODEL_REPO_ID = "sumitsinha2603/TourismPackagePredictionAnalysisModel"
MODEL_FILENAME = "TourismPackagePredictionAnalysisModel_v1.joblib"
# Read the Hub token from the environment; `userdata.get(...)` is a
# google.colab-only API and raised NameError outside a notebook.
# May be None, which is fine for public repositories.
HF_TOKEN = os.environ.get("HF_TOKEN")
# NOTE(review): the original `api = HfApi(token=HF_TOKEN)` was removed —
# HfApi was never imported (NameError at import time) and `api` is not
# referenced anywhere in this module.
DOWNLOAD_DIR = "/tmp/hf_model"
# -------- Module-level state --------
# NOTE(review): the duplicate `app = FastAPI(...)` instantiation that used to
# live here was removed. The application is created once further down with the
# lifespan handler attached; that later assignment shadowed this one anyway,
# and every route decorator registers on the later instance.
model = None          # loaded estimator; populated at startup by the lifespan handler
label_encoders = {}   # reserved for per-column encoders — currently unused; TODO confirm before relying on it
def ensure_logged_in():
    """Log in to the Hugging Face Hub if a token is configured.

    A missing token is acceptable for public repositories, so the login is
    simply skipped in that case (the original dead `else: pass` was removed).
    """
    if HF_TOKEN:
        login(token=HF_TOKEN)
def load_model_from_hf():
    """Download the model artifact from the HF Hub and load it with joblib.

    Side effects: creates DOWNLOAD_DIR, may log in to the Hub, and sets the
    module-level ``model`` global.

    Returns:
        The loaded estimator object.

    Raises:
        RuntimeError: if the download from the Hub fails.
    """
    global model
    os.makedirs(DOWNLOAD_DIR, exist_ok=True)
    ensure_logged_in()
    try:
        # hf_hub_download caches the file locally and returns its full path.
        local_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=MODEL_FILENAME,
            repo_type="model",
            token=HF_TOKEN,
        )
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(f"Failed to download model from HF Hub: {e}") from e
    model = joblib.load(local_path)
    return model
# Load model on startup / clean up on shutdown via FastAPI's lifespan protocol.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model once at startup and expose it on ``app.state``."""
    global model
    print("Loading model...")
    # Prefer a model file bundled with the image; fall back to downloading it
    # from the HF Hub. The original always loaded the hard-coded local path and
    # crashed with FileNotFoundError when the artifact was not baked in.
    if os.path.exists(MODEL_FILENAME):
        model = joblib.load(MODEL_FILENAME)
    else:
        model = load_model_from_hf()
    app.state.model = model
    yield
    print("Shutting down...")
# Build the FastAPI application, wiring in the lifespan handler so the model
# is loaded exactly once at startup. This is the instance every route below
# registers against.
app = FastAPI(
    title="Tourism Prediction Model Serving",
    description=(
        "Load model from HF Hub, accept inputs, return predictions "
        "and save inputs to a DataFrame"
    ),
    version="0.1",
    lifespan=lifespan,
)
class PredictRequest(BaseModel):
    """JSON request body for /predict and /save_inputs.

    ``records`` is a list of rows, each a mapping of feature name -> value.
    """
    records: List[dict]
# -------- Helper to coerce inputs into DataFrame --------
def inputs_to_dataframe_from_file(file: UploadFile) -> pd.DataFrame:
    """Parse an uploaded CSV file into a DataFrame; raise HTTP 400 on failure."""
    raw = file.file.read()
    try:
        frame = pd.read_csv(io.BytesIO(raw))
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Failed to parse CSV: {exc}")
    return frame
def inputs_to_dataframe_from_json(records: List[dict]) -> pd.DataFrame:
    """Convert a list of JSON records into a DataFrame; raise HTTP 400 on failure."""
    try:
        frame = pd.DataFrame(records)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Invalid JSON records: {exc}")
    return frame
# -------- Endpoint: predict --------
@app.post("/predict")
async def predict(payload: Optional[PredictRequest] = None, file: Optional[UploadFile] = File(None)):
    """
    Provide either:
     - JSON body: {"records": [{...}, {...}]}
     - or upload CSV file as form data
    Returns predictions and the input dataframe saved as CSV inside container.

    Raises HTTP 400 for missing/malformed input and HTTP 500 when the model
    cannot be loaded or prediction fails.
    """
    if payload is None and file is None:
        raise HTTPException(status_code=400, detail="No input provided. Send JSON 'records' or upload a CSV file.")
    # Convert input to dataframe; an uploaded file takes precedence.
    if file is not None:
        df_in = inputs_to_dataframe_from_file(file)
    else:
        df_in = inputs_to_dataframe_from_json(payload.records)
    # getattr guards against the lifespan handler never having run, in which
    # case `app.state` has no `model` attribute at all.
    current_model = getattr(app.state, "model", None)
    if current_model is None:
        # Fallback for when startup loading failed or was skipped.
        try:
            current_model = load_model_from_hf()
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Model not loaded: {e}")
        # Cache the loaded model so later requests don't re-download it
        # (the original never updated app.state here).
        app.state.model = current_model
    try:
        preds = current_model.predict(df_in)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")
    # Persist the raw inputs for later inspection (best-effort, non-fatal).
    save_path = os.path.join("/app", "inputs.csv")
    try:
        if os.path.exists(save_path):
            existing = pd.read_csv(save_path)
            pd.concat([existing, df_in], ignore_index=True).to_csv(save_path, index=False)
        else:
            df_in.to_csv(save_path, index=False)
    except Exception as e:
        # Non-fatal; prediction already succeeded, so just warn and continue.
        print("Warning: failed to save inputs:", e)
    return {
        "predictions": preds.tolist(),
        "n_records": len(df_in),
        "saved_to": save_path
    }
# -------- Endpoint: save raw inputs only (optional) --------
@app.post("/save_inputs")
async def save_inputs(payload: PredictRequest):
    """Append the posted records to /app/inputs.csv without running predictions."""
    incoming = inputs_to_dataframe_from_json(payload.records)
    save_path = os.path.join("/app", "inputs.csv")
    if os.path.exists(save_path):
        # Merge with previously saved rows so the CSV accumulates over time.
        combined = pd.concat([pd.read_csv(save_path), incoming], ignore_index=True)
        combined.to_csv(save_path, index=False)
    else:
        incoming.to_csv(save_path, index=False)
    return {"saved_to": save_path, "n_records": len(incoming)}
# -------- Health check --------
@app.get("/health")
def health():
    """Liveness probe; also reports whether the model is attached to app.state."""
    model_is_loaded = app.state.model is not None
    return {"status": "ok", "model_loaded": model_is_loaded}