# main.py
"""Entrepreneur Readiness API.

Serves an XGBoost readiness-scoring model plus a GPT-2 text-generation
model (both pulled from the Hugging Face Hub) behind a FastAPI app.
Models are downloaded and loaded lazily on first use and cached for the
process lifetime via ``lru_cache``.
"""

import io
import json
import os
import typing as T
from functools import lru_cache

import joblib
import numpy as np
import pandas as pd
import torch
import xgboost as xgb
from fastapi import Body, FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from huggingface_hub import login, snapshot_download
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, pipeline

# -------- Config --------
# First matching env var wins; all three spellings are seen in the wild.
HF_TOKEN = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)
XGB_REPO = "ethnmcl/entrepreneur-readiness-xgb"
GPT2_REPO = "ethnmcl/gpt2-entrepreneur-agent"

app = FastAPI(
    title="Entrepreneur Readiness API",
    description=(
        "XGBoost readiness scoring + GPT-2 summarization.\n\n"
        f"Models:\n- {XGB_REPO}\n- {GPT2_REPO}\n"
        "Use /docs for interactive testing."
    ),
    version="1.1.0",
)

# CORS (allow all; tighten for production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# -------- Convenience root & health --------
@app.get("/", include_in_schema=False)
def root():
    """Landing payload listing the available endpoints."""
    return JSONResponse(
        {
            "ok": True,
            "message": "Entrepreneur Readiness API is running.",
            "docs": "/docs",
            "endpoints": ["/health", "/readiness", "/score", "/score_csv", "/summarize", "/score_and_summarize"],
        }
    )


# Liveness-only (no model load)
@app.get("/health", include_in_schema=False)
def health():
    """Liveness probe: answers without touching the models."""
    return JSONResponse({"ok": True, "status": "live", "docs": "/docs"})


# Readiness (loads models)
@app.get("/readiness")
def readiness():
    """Readiness probe: succeeds only once the models load cleanly."""
    try:
        _load_models()
        return {"ok": True, "status": "ready"}
    except Exception as e:
        # Boundary handler: report any load failure as 503, don't crash.
        return JSONResponse({"ok": False, "status": "not_ready", "error": str(e)}, status_code=503)


# Optional warm-up to trigger downloads/caching
@app.post("/warmup", include_in_schema=False)
def warmup():
    """Force model download/load so later requests don't pay the cost."""
    try:
        _load_models()
        return {"ok": True, "warmed": True}
    except Exception as e:
        return JSONResponse({"ok": False, "error": str(e)}, status_code=500)


# -------- Model loading helpers --------
def _find_file(dirpath: str, candidates: T.Sequence[str], fallback_exts: T.Sequence[str] = ()) -> str:
    """Return the first existing file in *dirpath* from *candidates*.

    Falls back to the first (alphabetically) file whose name ends with one
    of *fallback_exts*. Raises FileNotFoundError when nothing matches.
    """
    for name in candidates:
        p = os.path.join(dirpath, name)
        if os.path.exists(p):
            return p
    # sorted() makes the fallback deterministic across filesystems
    # (os.listdir order is otherwise arbitrary).
    for fname in sorted(os.listdir(dirpath)):
        if any(fname.endswith(ext) for ext in fallback_exts):
            return os.path.join(dirpath, fname)
    raise FileNotFoundError(f"Could not find any of {candidates} (or {fallback_exts}) in {dirpath}")


@lru_cache(maxsize=1)
def _download_artifacts() -> T.Tuple[str, str]:
    """Snapshot both model repos locally; return (xgb_dir, gpt2_dir)."""
    if HF_TOKEN:
        try:
            login(token=HF_TOKEN, add_to_git_credential=True)
        except Exception:
            # Continue if public: login failure is non-fatal when the
            # repos can be downloaded anonymously.
            pass
    xgb_local = snapshot_download(repo_id=XGB_REPO, token=HF_TOKEN, revision=None)
    gpt_local = snapshot_download(repo_id=GPT2_REPO, token=HF_TOKEN, revision=None)
    return xgb_local, gpt_local


@lru_cache(maxsize=1)
def _load_models():
    """Load and cache (preprocessor, booster, text_gen pipeline, xgb_dir)."""
    xgb_dir, gpt_dir = _download_artifacts()

    # ---- Preprocessor ----
    preproc_path = _find_file(
        xgb_dir,
        candidates=[
            "readiness_preprocessor.joblib",
            "preprocessor.joblib",
            "preprocessor.pkl",
            "readiness_preprocessor.pkl",
        ],
        fallback_exts=(".joblib", ".pkl"),
    )
    preprocessor = joblib.load(preproc_path)

    # ---- XGB booster ----
    booster_path = _find_file(
        xgb_dir,
        candidates=[
            "xgb_readiness_model.json",
            "xgb_model.json",
            "model.json",
            "model.ubj",
            "model.bin",
            "readiness_xgb.json",
        ],
        fallback_exts=(".json", ".ubj", ".bin"),
    )
    booster = xgb.Booster()
    booster.load_model(booster_path)

    # ---- GPT-2 text generation: robust tokenizer selection ----
    device = 0 if torch.cuda.is_available() else -1
    try:
        tok = AutoTokenizer.from_pretrained(gpt_dir, use_fast=True, trust_remote_code=False)
    except Exception:
        # Fallback for "ModelWrapper" tokenizer.json parse errors
        tok = AutoTokenizer.from_pretrained(gpt_dir, use_fast=False, trust_remote_code=False)

    # Ensure a pad token (map to eos if absent) to avoid generation warnings/errors
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token

    text_gen = pipeline(
        "text-generation",
        model=gpt_dir,
        tokenizer=tok,
        device=device,
        trust_remote_code=False,
    )
    return preprocessor, booster, text_gen, xgb_dir


# -------- Utils --------
def _coerce_numeric(df: pd.DataFrame) -> pd.DataFrame:
    """Best-effort conversion of object columns to numeric dtypes."""
    out = df.copy()
    for c in out.columns:
        if out[c].dtype == object:
            try:
                out[c] = pd.to_numeric(out[c])
            except (ValueError, TypeError):
                # Non-numeric column: leave it as-is for the preprocessor.
                pass
    return out


def _to_dmatrix(df: pd.DataFrame, preprocessor) -> xgb.DMatrix:
    """Apply the fitted preprocessor and wrap the result in a DMatrix."""
    X = preprocessor.transform(df)
    return xgb.DMatrix(X)


def _predict_scores(df: pd.DataFrame, preprocessor, booster) -> np.ndarray:
    """Score *df* with the booster; returns a flat 1-D array of floats."""
    dmat = _to_dmatrix(df, preprocessor)
    scores = booster.predict(dmat)
    return np.array(scores).reshape(-1)


def _format_prompt(inputs: dict, score: float) -> str:
    """Build the generation prompt from a feature dict and its score."""
    kv = "; ".join(f"{k}: {v}" for k, v in inputs.items())
    return (
        "Summarize the entrepreneur readiness profile succinctly.\n"
        f"Inputs -> {kv}; Score -> {score:.3f}\n"
        "Summary:"
    )


def _summarize(inputs: dict, score: float, text_gen) -> str:
    """Generate a short natural-language summary for one scored example."""
    generated = text_gen(
        _format_prompt(inputs, score),
        max_new_tokens=120,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        num_return_sequences=1,
        eos_token_id=text_gen.tokenizer.eos_token_id,
        pad_token_id=text_gen.tokenizer.eos_token_id,
    )[0]["generated_text"]
    # Keep only the text after the "Summary:" marker when present.
    return generated.split("Summary:", 1)[-1].strip() if "Summary:" in generated else generated.strip()


# -------- Schemas (Pydantic v2) --------
class ScoreRequest(BaseModel):
    rows: T.List[dict] = Field(..., description="List of row objects (feature_name -> value).")


class ScoreResponse(BaseModel):
    scores: T.List[float]


class SummarizeRequest(BaseModel):
    inputs: dict = Field(..., description="Feature dict for one example.")
    score: float = Field(..., description="Readiness score used in the summary.")


class SummarizeResponse(BaseModel):
    summary: str


class ScoreAndSummarizeRequest(BaseModel):
    rows: T.List[dict] = Field(..., description="Rows to score and summarize.")


class ScoreAndSummarizeItem(BaseModel):
    score: float
    summary: str


class ScoreAndSummarizeResponse(BaseModel):
    results: T.List[ScoreAndSummarizeItem]


# -------- Endpoints --------
@app.post("/score", response_model=ScoreResponse)
def score_json(req: ScoreRequest = Body(...)):
    """Score a JSON list of feature rows; returns one score per row."""
    preprocessor, booster, _, _ = _load_models()
    if not req.rows:
        raise HTTPException(status_code=400, detail="rows must be non-empty")
    df = pd.DataFrame(req.rows)
    df = _coerce_numeric(df)
    try:
        scores = _predict_scores(df, preprocessor, booster)
    except Exception as e:
        # Chain the cause so server logs retain the real traceback.
        raise HTTPException(status_code=400, detail=f"Scoring failed: {e}") from e
    return ScoreResponse(scores=[float(s) for s in scores])


@app.post("/score_csv", response_model=ScoreResponse)
async def score_csv(file: UploadFile = File(...)):
    """Score an uploaded CSV file; returns one score per data row."""
    preprocessor, booster, _, _ = _load_models()
    try:
        content = await file.read()
        df = pd.read_csv(io.BytesIO(content))
        df = _coerce_numeric(df)
        scores = _predict_scores(df, preprocessor, booster)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"CSV scoring failed: {e}") from e
    return ScoreResponse(scores=[float(s) for s in scores])


@app.post("/summarize", response_model=SummarizeResponse)
def summarize(req: SummarizeRequest = Body(...)):
    """Generate a GPT-2 summary for one feature dict + precomputed score."""
    _, _, text_gen, _ = _load_models()
    try:
        summary = _summarize(req.inputs, req.score, text_gen)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Summarization failed: {e}") from e
    return SummarizeResponse(summary=summary)


@app.post("/score_and_summarize", response_model=ScoreAndSummarizeResponse)
def score_and_summarize(req: ScoreAndSummarizeRequest = Body(...)):
    """Score each row, then summarize it; summary failures are per-row."""
    preprocessor, booster, text_gen, _ = _load_models()
    if not req.rows:
        raise HTTPException(status_code=400, detail="rows must be non-empty")
    df = pd.DataFrame(req.rows)
    df = _coerce_numeric(df)
    try:
        scores = _predict_scores(df, preprocessor, booster)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Scoring failed: {e}") from e
    results = []
    for i, row in enumerate(req.rows):
        try:
            summ = _summarize(row, float(scores[i]), text_gen)
        except Exception as e:
            # Best-effort per row: a summary failure must not drop the score.
            summ = f"(summary failed: {e})"
        results.append(ScoreAndSummarizeItem(score=float(scores[i]), summary=summ))
    return ScoreAndSummarizeResponse(results=results)