# app.py
import io
import os
from contextlib import asynccontextmanager
from typing import Any, List, Optional

import joblib
import pandas as pd
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from huggingface_hub import HfApi, hf_hub_download, login
from pydantic import BaseModel
# -------- Configuration --------
MODEL_REPO_ID = "sumitsinha2603/TourismPackagePredictionAnalysisModel"
MODEL_FILENAME = "TourismPackagePredictionAnalysisModel_v1.joblib"
# Read the token from the environment: the original `userdata.get('hf_token')`
# is Colab-specific and raises NameError outside a notebook. May be None for
# anonymous access to public repos.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Hub API client (token optional).
api = HfApi(token=HF_TOKEN)
DOWNLOAD_DIR = "/tmp/hf_model"
# -------- Initialize FastAPI --------
# NOTE(review): this instance is superseded by the second `app = FastAPI(...)`
# rebinding further down (which attaches the lifespan handler); every route
# decorator in this file binds to that later instance, not this one.
app = FastAPI(
title="Tourism Prediction Model Serving",
description="Load model from HF Hub, accept inputs, return predictions and save inputs to a DataFrame",
version="0.1"
)
# Module-level holder for the loaded estimator (populated at startup or lazily).
model = None
# Declared but never written to in this file — presumably reserved for
# categorical encoders; verify against the training pipeline.
label_encoders = {}
def ensure_logged_in():
    """Authenticate against the Hugging Face Hub when a token is configured.

    No-op when ``HF_TOKEN`` is unset/empty; public repos can still be
    downloaded anonymously. (Removed the original's dead ``else: pass``.)
    """
    if HF_TOKEN:
        login(token=HF_TOKEN)
def load_model_from_hf():
    """Download the model artifact from the HF Hub and load it with joblib.

    Stores the loaded estimator in the module-level ``model`` global and
    returns it.

    Raises:
        RuntimeError: if the download from the Hub fails (original cause
            chained via ``from e``).
    """
    global model
    os.makedirs(DOWNLOAD_DIR, exist_ok=True)
    ensure_logged_in()
    try:
        # hf_hub_download caches the file locally and returns its full path.
        local_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=MODEL_FILENAME,
            repo_type="model",
            token=HF_TOKEN,
        )
    except Exception as e:
        # Chain so the root cause (network/auth/404) stays in the traceback.
        raise RuntimeError(f"Failed to download model from HF Hub: {e}") from e
    model = joblib.load(local_path)
    return model
# Load model on startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model once at startup and expose it via ``app.state.model``."""
    print("Loading model...")
    global model
    try:
        # Prefer a model file shipped alongside the app (fast, offline).
        # Uses the MODEL_FILENAME constant instead of re-hardcoding the name.
        model = joblib.load(MODEL_FILENAME)
    except FileNotFoundError:
        # Local artifact missing — fall back to downloading from the HF Hub
        # (the original crashed on startup in this case).
        model = load_model_from_hf()
    app.state.model = model
    yield
    print("Shutting down...")
# Re-initialize FastAPI to include lifespan, ensuring it's only defined once
# NOTE(review): this rebinds `app`, discarding the instance created earlier in
# the file; all route decorators below attach to THIS instance, so the earlier
# one is effectively dead.
app = FastAPI(
title="Tourism Prediction Model Serving",
description="Load model from HF Hub, accept inputs, return predictions and save inputs to a DataFrame",
version="0.1",
lifespan=lifespan  # Pass the lifespan context manager here
)
class PredictRequest(BaseModel):
    """JSON request body: a list of feature records, one dict per input row."""
    records: List[dict]
# -------- Helper to coerce inputs into DataFrame --------
def inputs_to_dataframe_from_file(file: UploadFile) -> pd.DataFrame:
    """Parse an uploaded CSV file into a DataFrame.

    Args:
        file: the uploaded file; its entire contents are read into memory.

    Raises:
        HTTPException: 400 when the payload is not parseable as CSV
            (original cause chained via ``from e``).
    """
    contents = file.file.read()
    try:
        return pd.read_csv(io.BytesIO(contents))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to parse CSV: {e}") from e
def inputs_to_dataframe_from_json(records: List[dict]) -> pd.DataFrame:
    """Build a DataFrame from a list of JSON records (one dict per row).

    Raises:
        HTTPException: 400 when pandas cannot construct a frame from the
            records (original cause chained via ``from e``).
    """
    try:
        return pd.DataFrame(records)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON records: {e}") from e
# -------- Endpoint: predict --------
@app.post("/predict")
async def predict(payload: Optional[PredictRequest] = None, file: Optional[UploadFile] = File(None)):
    """
    Provide either:
      - JSON body: {"records": [{...}, {...}]}
      - or upload CSV file as form data
    Returns predictions and the input dataframe saved as CSV inside container.
    """
    if payload is None and file is None:
        raise HTTPException(status_code=400, detail="No input provided. Send JSON 'records' or upload a CSV file.")
    # Convert input to dataframe (a file upload takes precedence when both given).
    if file is not None:
        df_in = inputs_to_dataframe_from_file(file)
    else:
        df_in = inputs_to_dataframe_from_json(payload.records)
    current_model = app.state.model
    if current_model is None:
        # Lazy fallback: startup load may have failed — fetch from the Hub.
        try:
            # Use the return value directly (the original re-read the global,
            # which could be stale) and cache it for subsequent requests.
            current_model = load_model_from_hf()
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Model not loaded: {e}") from e
        app.state.model = current_model
    try:
        preds = current_model.predict(df_in)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e
    # Persist the inputs for later inspection; failure here is non-fatal.
    save_path = os.path.join("/app", "inputs.csv")
    try:
        # Append to the existing CSV if one is already present.
        if os.path.exists(save_path):
            existing = pd.read_csv(save_path)
            pd.concat([existing, df_in], ignore_index=True).to_csv(save_path, index=False)
        else:
            df_in.to_csv(save_path, index=False)
    except Exception as e:
        # Best-effort only; prediction already succeeded.
        print("Warning: failed to save inputs:", e)
    return {
        "predictions": preds.tolist(),
        "n_records": len(df_in),
        "saved_to": save_path,
    }
# -------- Endpoint: save raw inputs only (optional) --------
@app.post("/save_inputs")
async def save_inputs(payload: PredictRequest):
    """Append the posted records to /app/inputs.csv without running the model."""
    df_in = inputs_to_dataframe_from_json(payload.records)
    save_path = os.path.join("/app", "inputs.csv")
    if os.path.exists(save_path):
        # Merge new rows onto the previously saved ones, then rewrite the file.
        combined = pd.concat([pd.read_csv(save_path), df_in], ignore_index=True)
        combined.to_csv(save_path, index=False)
    else:
        # First write: nothing to merge with.
        df_in.to_csv(save_path, index=False)
    return {"saved_to": save_path, "n_records": len(df_in)}
# -------- Health check --------
@app.get("/health")
def health():
    """Liveness probe: report service status and whether a model is loaded."""
    loaded = app.state.model is not None
    return {"status": "ok", "model_loaded": loaded}