# NOTE(review): the three lines below this comment were non-code artifacts
# scraped from the model-hub web page ("ethnmcl's picture", "Update main.py",
# "b8b7be0 verified") and broke Python parsing; preserved here as a comment.
# main.py
import os
import io
import json
import typing as T
from functools import lru_cache
import pandas as pd
from fastapi import FastAPI, File, UploadFile, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import BaseModel, Field
from huggingface_hub import login, snapshot_download
import joblib
import xgboost as xgb
import numpy as np
import torch
from transformers import AutoTokenizer, pipeline
# -------- Config --------
# Hub token is optional (public repos still download); accept the three
# env-var spellings used across the Hugging Face ecosystem over time.
HF_TOKEN = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)
# Model repositories: XGBoost scorer artifacts and fine-tuned GPT-2 summarizer.
XGB_REPO = "ethnmcl/entrepreneur-readiness-xgb"
GPT2_REPO = "ethnmcl/gpt2-entrepreneur-agent"
app = FastAPI(
    title="Entrepreneur Readiness API",
    description=(
        "XGBoost readiness scoring + GPT-2 summarization.\n\n"
        f"Models:\n- {XGB_REPO}\n- {GPT2_REPO}\n"
        "Use /docs for interactive testing."
    ),
    version="1.1.0",
)
# CORS (allow all; tighten for production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# -------- Convenience root & health --------
@app.get("/", include_in_schema=False)
def root():
    """Landing payload: confirms the service is up and lists the public endpoints."""
    payload = {
        "ok": True,
        "message": "Entrepreneur Readiness API is running.",
        "docs": "/docs",
        "endpoints": ["/health", "/readiness", "/score", "/score_csv", "/summarize", "/score_and_summarize"],
    }
    return JSONResponse(payload)
# Liveness-only (no model load)
# Liveness-only (no model load)
@app.get("/health", include_in_schema=False)
def health():
    """Liveness probe — answers immediately without touching the models."""
    body = {"ok": True, "status": "live", "docs": "/docs"}
    return JSONResponse(body)
# Readiness (loads models)
# Readiness (loads models)
@app.get("/readiness")
def readiness():
    """Readiness probe — forces model load; returns 503 with the error when loading fails."""
    try:
        _load_models()
    except Exception as exc:
        return JSONResponse({"ok": False, "status": "not_ready", "error": str(exc)}, status_code=503)
    return {"ok": True, "status": "ready"}
# Optional warm-up to trigger downloads/caching
# Optional warm-up to trigger downloads/caching
@app.post("/warmup", include_in_schema=False)
def warmup():
    """Hidden endpoint that pre-downloads and caches both models."""
    try:
        _load_models()
    except Exception as exc:
        return JSONResponse({"ok": False, "error": str(exc)}, status_code=500)
    return {"ok": True, "warmed": True}
# -------- Model loading helpers --------
def _find_file(dirpath: str, candidates: T.Sequence[str], fallback_exts: T.Sequence[str] = ()) -> str:
for name in candidates:
p = os.path.join(dirpath, name)
if os.path.exists(p):
return p
for fname in os.listdir(dirpath):
if any(fname.endswith(ext) for ext in fallback_exts):
return os.path.join(dirpath, fname)
raise FileNotFoundError(f"Could not find any of {candidates} (or {fallback_exts}) in {dirpath}")
@lru_cache(maxsize=1)
def _download_artifacts() -> T.Tuple[str, str]:
    """Download (or reuse locally cached) snapshots of both model repos.

    Returns (xgb_dir, gpt_dir) local snapshot paths. Login is attempted only
    when a token is configured, and login failures are swallowed so that
    public repositories still download.
    """
    if HF_TOKEN:
        try:
            login(token=HF_TOKEN, add_to_git_credential=True)
        except Exception:
            pass  # best-effort: public repos need no auth
    xgb_local, gpt_local = (
        snapshot_download(repo_id=repo, token=HF_TOKEN, revision=None)
        for repo in (XGB_REPO, GPT2_REPO)
    )
    return xgb_local, gpt_local
@lru_cache(maxsize=1)
def _load_models():
    """Load and cache every inference artifact (runs once per process).

    Returns a 4-tuple (preprocessor, booster, text_gen, xgb_dir):
      - preprocessor: fitted transformer restored via joblib
      - booster: XGBoost Booster for readiness scoring
      - text_gen: transformers text-generation pipeline (GPT-2)
      - xgb_dir: local snapshot directory of the XGB repo
    """
    xgb_dir, gpt_dir = _download_artifacts()

    # ---- Preprocessor: known filenames first, then any joblib/pickle ----
    preprocessor = joblib.load(
        _find_file(
            xgb_dir,
            candidates=[
                "readiness_preprocessor.joblib",
                "preprocessor.joblib",
                "preprocessor.pkl",
                "readiness_preprocessor.pkl",
            ],
            fallback_exts=(".joblib", ".pkl"),
        )
    )

    # ---- XGB booster: same known-name-then-extension strategy ----
    booster = xgb.Booster()
    booster.load_model(
        _find_file(
            xgb_dir,
            candidates=[
                "xgb_readiness_model.json",
                "xgb_model.json",
                "model.json",
                "model.ubj",
                "model.bin",
                "readiness_xgb.json",
            ],
            fallback_exts=(".json", ".ubj", ".bin"),
        )
    )

    # ---- GPT-2 text generation ----
    device = 0 if torch.cuda.is_available() else -1
    # Prefer the fast tokenizer; fall back to the slow implementation when
    # tokenizer.json cannot be parsed (e.g. "ModelWrapper" parse errors).
    try:
        tokenizer = AutoTokenizer.from_pretrained(gpt_dir, use_fast=True, trust_remote_code=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(gpt_dir, use_fast=False, trust_remote_code=False)
    # GPT-2 ships without a pad token; map it to EOS to avoid generation warnings/errors.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    text_gen = pipeline(
        "text-generation",
        model=gpt_dir,
        tokenizer=tokenizer,
        device=device,
        trust_remote_code=False,
    )
    return preprocessor, booster, text_gen, xgb_dir
# -------- Utils --------
def _coerce_numeric(df: pd.DataFrame) -> pd.DataFrame:
out = df.copy()
for c in out.columns:
if out[c].dtype == object:
try:
out[c] = pd.to_numeric(out[c])
except Exception:
pass
return out
def _to_dmatrix(df: pd.DataFrame, preprocessor) -> xgb.DMatrix:
    """Apply the fitted preprocessor to *df* and wrap the result as an XGBoost DMatrix."""
    transformed = preprocessor.transform(df)
    return xgb.DMatrix(transformed)
def _predict_scores(df: pd.DataFrame, preprocessor, booster) -> np.ndarray:
    """Score every row of *df* with the booster; returns a flat 1-D float array."""
    raw = booster.predict(_to_dmatrix(df, preprocessor))
    return np.array(raw).reshape(-1)
def _format_prompt(inputs: dict, score: float) -> str:
kv = "; ".join(f"{k}: {v}" for k, v in inputs.items())
return (
"Summarize the entrepreneur readiness profile succinctly.\n"
f"Inputs -> {kv}; Score -> {score:.3f}\n"
"Summary:"
)
def _summarize(inputs: dict, score: float, text_gen) -> str:
    """Generate a short natural-language summary for one scored profile.

    Samples a single completion from the text-generation pipeline and returns
    only the text after the "Summary:" cue; if the cue is somehow absent from
    the generation, the whole stripped output is returned instead.
    """
    eos = text_gen.tokenizer.eos_token_id
    outputs = text_gen(
        _format_prompt(inputs, score),
        max_new_tokens=120,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        num_return_sequences=1,
        eos_token_id=eos,
        pad_token_id=eos,  # GPT-2 lacks a pad token; EOS doubles as padding
    )
    text = outputs[0]["generated_text"]
    marker = "Summary:"
    if marker in text:
        return text.split(marker, 1)[-1].strip()
    return text.strip()
# -------- Schemas (Pydantic v2) --------
class ScoreRequest(BaseModel):
    """Request body for /score: a batch of feature rows."""
    rows: T.List[dict] = Field(..., description="List of row objects (feature_name -> value).")
class ScoreResponse(BaseModel):
    """Response for /score and /score_csv: one score per input row, in order."""
    scores: T.List[float]
class SummarizeRequest(BaseModel):
    """Request body for /summarize: one feature dict plus its precomputed score."""
    inputs: dict = Field(..., description="Feature dict for one example.")
    score: float = Field(..., description="Readiness score used in the summary.")
class SummarizeResponse(BaseModel):
    """Response for /summarize: the generated text."""
    summary: str
class ScoreAndSummarizeRequest(BaseModel):
    """Request body for /score_and_summarize: rows to score and then summarize."""
    rows: T.List[dict] = Field(..., description="Rows to score and summarize.")
class ScoreAndSummarizeItem(BaseModel):
    """One result row: the readiness score and its generated summary."""
    score: float
    summary: str
class ScoreAndSummarizeResponse(BaseModel):
    """Response for /score_and_summarize: one item per input row, in order."""
    results: T.List[ScoreAndSummarizeItem]
# -------- Endpoints --------
@app.post("/score", response_model=ScoreResponse)
def score_json(req: ScoreRequest = Body(...)):
    """Score a batch of feature rows supplied as JSON objects.

    Returns one float score per row; 400 on empty input or scoring failure.
    """
    preprocessor, booster, _, _ = _load_models()
    if not req.rows:
        raise HTTPException(status_code=400, detail="rows must be non-empty")
    frame = _coerce_numeric(pd.DataFrame(req.rows))
    try:
        predictions = _predict_scores(frame, preprocessor, booster)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Scoring failed: {e}")
    return ScoreResponse(scores=[float(p) for p in predictions])
@app.post("/score_csv", response_model=ScoreResponse)
async def score_csv(file: UploadFile = File(...)):
    """Score every row of an uploaded CSV; 400 on parse or scoring failure."""
    preprocessor, booster, _, _ = _load_models()
    try:
        raw = await file.read()
        frame = _coerce_numeric(pd.read_csv(io.BytesIO(raw)))
        predictions = _predict_scores(frame, preprocessor, booster)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"CSV scoring failed: {e}")
    return ScoreResponse(scores=[float(p) for p in predictions])
@app.post("/summarize", response_model=SummarizeResponse)
def summarize(req: SummarizeRequest = Body(...)):
    """Generate a GPT-2 summary for one feature dict and its precomputed score."""
    _, _, text_gen, _ = _load_models()
    try:
        text = _summarize(req.inputs, req.score, text_gen)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Summarization failed: {e}")
    return SummarizeResponse(summary=text)
@app.post("/score_and_summarize", response_model=ScoreAndSummarizeResponse)
def score_and_summarize(req: ScoreAndSummarizeRequest = Body(...)):
    """Score every row, then summarize each one.

    Scoring failures abort with 400; a per-row summarization failure is
    degraded to a placeholder string so the remaining rows still succeed.
    """
    preprocessor, booster, text_gen, _ = _load_models()
    if not req.rows:
        raise HTTPException(status_code=400, detail="rows must be non-empty")
    frame = _coerce_numeric(pd.DataFrame(req.rows))
    try:
        predictions = _predict_scores(frame, preprocessor, booster)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Scoring failed: {e}")
    items = []
    for row, raw_score in zip(req.rows, predictions):
        row_score = float(raw_score)
        try:
            text = _summarize(row, row_score, text_gen)
        except Exception as e:
            text = f"(summary failed: {e})"
        items.append(ScoreAndSummarizeItem(score=row_score, summary=text))
    return ScoreAndSummarizeResponse(results=items)