ethnmcl commited on
Commit
e3a7252
·
verified ·
1 Parent(s): f918bcb

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +293 -0
main.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+ import os
3
+ import io
4
+ import json
5
+ import typing as T
6
+ from functools import lru_cache
7
+
8
+ import pandas as pd
9
+ from fastapi import FastAPI, File, UploadFile, HTTPException
10
+ from fastapi import Body
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from pydantic import BaseModel, Field
13
+
14
+ from huggingface_hub import login, snapshot_download
15
+ import joblib
16
+ import xgboost as xgb
17
+ import numpy as np
18
+ import torch
19
+ from transformers import pipeline
20
+
21
# Hugging Face auth token: first matching env var wins; None when unset
# (public repos still download without a token).
HF_TOKEN = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)
26
+
27
# Hub repo ids for the two model artifacts this API serves.
XGB_REPO = "ethnmcl/entrepreneur-readiness-xgb"
GPT2_REPO = "ethnmcl/gpt2-entrepreneur-agent"
29
+
30
# FastAPI application; title/description/version are rendered on /docs.
app = FastAPI(
    title="Entrepreneur Readiness API",
    description=(
        "XGBoost readiness scoring + GPT-2 summarization.\n\n"
        "Models:\n"
        f"- {XGB_REPO}\n- {GPT2_REPO}\n"
        "Use /docs for interactive testing."
    ),
    version="1.0.0",
)
40
+
41
# CORS (relaxed so you can call from browsers / Framer, etc.)
# NOTE(review): per the CORS spec, browsers reject `allow_origins=["*"]`
# combined with `allow_credentials=True` for credentialed requests —
# list explicit origins if cookies/auth headers must cross origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten if needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
49
+
50
+
51
+ # -----------------------------
52
+ # Model loading
53
+ # -----------------------------
54
+ def _find_file(dirpath: str, candidates: T.Sequence[str], fallback_exts: T.Sequence[str] = ()) -> str:
55
+ for name in candidates:
56
+ p = os.path.join(dirpath, name)
57
+ if os.path.exists(p):
58
+ return p
59
+ for fname in os.listdir(dirpath):
60
+ if any(fname.endswith(ext) for ext in fallback_exts):
61
+ return os.path.join(dirpath, fname)
62
+ raise FileNotFoundError(f"Could not find any of {candidates} (or {fallback_exts}) in {dirpath}")
63
+
64
+
65
@lru_cache(maxsize=1)
def _download_artifacts() -> T.Tuple[str, str]:
    """Snapshot both Hub repos once per process.

    Returns:
        Tuple of local directories: (xgb_repo_dir, gpt2_repo_dir).
    """
    if HF_TOKEN:
        try:
            login(token=HF_TOKEN, add_to_git_credential=True)
        except Exception:
            # Login is best-effort: public repos download fine without it.
            pass
    local_dirs = [
        snapshot_download(repo_id=repo, token=HF_TOKEN, revision=None)
        for repo in (XGB_REPO, GPT2_REPO)
    ]
    return local_dirs[0], local_dirs[1]
76
+
77
+
78
@lru_cache(maxsize=1)
def _load_models():
    """Load all inference artifacts once and cache them.

    Returns:
        (preprocessor, booster, text_gen, xgb_dir) where *preprocessor* is the
        joblib-loaded transformer, *booster* the raw XGBoost Booster,
        *text_gen* the transformers text-generation pipeline, and *xgb_dir*
        the local snapshot directory of the XGBoost repo.
    """
    xgb_dir, gpt_dir = _download_artifacts()

    # Fitted preprocessor saved alongside the booster.
    preprocessor = joblib.load(
        _find_file(
            xgb_dir,
            candidates=[
                "readiness_preprocessor.joblib",
                "preprocessor.joblib",
                "preprocessor.pkl",
                "readiness_preprocessor.pkl",
            ],
            fallback_exts=(".joblib", ".pkl"),
        )
    )

    # Raw Booster (not the sklearn wrapper) loaded from its serialized form.
    booster = xgb.Booster()
    booster.load_model(
        _find_file(
            xgb_dir,
            candidates=[
                "xgb_readiness_model.json",
                "xgb_model.json",
                "model.json",
                "model.ubj",
                "model.bin",
                "readiness_xgb.json",
            ],
            fallback_exts=(".json", ".ubj", ".bin"),
        )
    )

    # GPT-2 generation pipeline; device 0 = first GPU, -1 = CPU.
    target_device = 0 if torch.cuda.is_available() else -1
    text_gen = pipeline(
        "text-generation",
        model=gpt_dir,
        tokenizer=gpt_dir,
        device=target_device,
        trust_remote_code=False,
    )

    return preprocessor, booster, text_gen, xgb_dir
122
+
123
+
124
+ # -----------------------------
125
+ # Utils
126
+ # -----------------------------
127
+ def _coerce_numeric(df: pd.DataFrame) -> pd.DataFrame:
128
+ out = df.copy()
129
+ for c in out.columns:
130
+ if out[c].dtype == object:
131
+ try:
132
+ out[c] = pd.to_numeric(out[c])
133
+ except Exception:
134
+ pass
135
+ return out
136
+
137
+
138
def _to_dmatrix(df: pd.DataFrame, preprocessor) -> xgb.DMatrix:
    """Run *df* through the fitted preprocessor and wrap the result as a DMatrix."""
    return xgb.DMatrix(preprocessor.transform(df))
141
+
142
+
143
def _predict_scores(df: pd.DataFrame, preprocessor, booster) -> np.ndarray:
    """Score every row of *df* with the booster; returns a flat 1-D float array."""
    raw = booster.predict(_to_dmatrix(df, preprocessor))
    # Flatten so callers always see one score per input row.
    return np.array(raw).reshape(-1)
148
+
149
+
150
+ def _format_prompt(inputs: dict, score: float) -> str:
151
+ kv = "; ".join(f"{k}: {v}" for k, v in inputs.items())
152
+ return (
153
+ "Summarize the entrepreneur readiness profile succinctly.\n"
154
+ f"Inputs -> {kv}; Score -> {score:.3f}\n"
155
+ "Summary:"
156
+ )
157
+
158
+
159
def _summarize(inputs: dict, score: float, text_gen) -> str:
    """Generate a short natural-language summary for one scored example.

    Runs the GPT-2 pipeline on the formatted prompt, then strips everything up
    to and including the "Summary:" marker when present; otherwise returns the
    whole generation, trimmed.
    """
    prompt = _format_prompt(inputs, score)
    generations = text_gen(
        prompt,
        max_new_tokens=120,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        num_return_sequences=1,
        eos_token_id=None,
    )
    text = generations[0]["generated_text"]
    marker = "Summary:"
    if marker in text:
        return text.split(marker, 1)[-1].strip()
    return text.strip()
171
+
172
+
173
+ # -----------------------------
174
+ # Schemas
175
+ # -----------------------------
176
class RowDict(BaseModel):
    # Pydantic v1 "custom root type": the model wraps a bare dict.
    # NOTE(review): `__root__` was removed in pydantic v2 (RootModel replaces
    # it), and this class appears unreferenced in the visible file — confirm
    # before keeping.
    __root__: dict
178
+
179
+
180
class ScoreRequest(BaseModel):
    # Batch scoring payload for /score: one dict per row.
    rows: T.List[dict] = Field(..., description="List of row objects (feature_name -> value).")
182
+
183
+
184
class ScoreResponse(BaseModel):
    # Predicted readiness scores, in the same order as the submitted rows.
    scores: T.List[float]
186
+
187
+
188
class SummarizeRequest(BaseModel):
    # Single-example payload for /summarize: the raw features plus a
    # previously computed readiness score to interpolate into the prompt.
    inputs: dict = Field(..., description="Feature dict for one example.")
    score: float = Field(..., description="Readiness score used in the summary.")
191
+
192
+
193
class SummarizeResponse(BaseModel):
    # GPT-2-generated summary text (marker-stripped).
    summary: str
195
+
196
+
197
class ScoreAndSummarizeRequest(BaseModel):
    # Batch payload for /score_and_summarize: one dict per row.
    rows: T.List[dict] = Field(..., description="Rows to score and summarize.")
199
+
200
+
201
class ScoreAndSummarizeItem(BaseModel):
    # One row's result: its score and the generated summary
    # (or a "(summary failed: ...)" placeholder on generation error).
    score: float
    summary: str
204
+
205
+
206
class ScoreAndSummarizeResponse(BaseModel):
    # Per-row results, in the same order as the submitted rows.
    results: T.List[ScoreAndSummarizeItem]
208
+
209
+
210
+ # -----------------------------
211
+ # Endpoints
212
+ # -----------------------------
213
@app.get("/health")
def health():
    """Liveness probe; also forces lazy model loading so failures surface here."""
    try:
        _load_models()
    except Exception as exc:
        # Any download/load failure becomes a 500 carrying the root message.
        raise HTTPException(status_code=500, detail=str(exc))
    return {"ok": True}
220
+
221
+
222
@app.post("/score", response_model=ScoreResponse)
def score_json(req: ScoreRequest = Body(...)):
    """
    Score a JSON batch of rows.
    """
    preprocessor, booster, _, _ = _load_models()
    if not req.rows:
        raise HTTPException(status_code=400, detail="rows must be non-empty")

    # Build the frame and coax string-typed numeric columns back to numbers.
    frame = _coerce_numeric(pd.DataFrame(req.rows))
    try:
        scores = _predict_scores(frame, preprocessor, booster)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Scoring failed: {e}")

    return ScoreResponse(scores=[float(s) for s in scores])
239
+
240
+
241
@app.post("/score_csv", response_model=ScoreResponse)
async def score_csv(file: UploadFile = File(...)):
    """
    Score a CSV upload. Returns the scores list in row order.
    """
    preprocessor, booster, _, _ = _load_models()
    try:
        raw = await file.read()
        # Parse, coerce numeric-looking columns, then score.
        frame = _coerce_numeric(pd.read_csv(io.BytesIO(raw)))
        scores = _predict_scores(frame, preprocessor, booster)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"CSV scoring failed: {e}")
    return ScoreResponse(scores=[float(s) for s in scores])
255
+
256
+
257
@app.post("/summarize", response_model=SummarizeResponse)
def summarize(req: SummarizeRequest = Body(...)):
    """
    Summarize a single example given inputs + score.
    """
    _, _, text_gen, _ = _load_models()
    try:
        text = _summarize(req.inputs, req.score, text_gen)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Summarization failed: {e}")
    return SummarizeResponse(summary=text)
268
+
269
+
270
@app.post("/score_and_summarize", response_model=ScoreAndSummarizeResponse)
def score_and_summarize(req: ScoreAndSummarizeRequest = Body(...)):
    """
    For each row: compute score, then generate a GPT-2 summary.
    """
    preprocessor, booster, text_gen, _ = _load_models()
    if not req.rows:
        raise HTTPException(status_code=400, detail="rows must be non-empty")

    frame = _coerce_numeric(pd.DataFrame(req.rows))
    try:
        scores = _predict_scores(frame, preprocessor, booster)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Scoring failed: {e}")

    results = []
    for idx, row in enumerate(req.rows):
        # Summaries are best-effort: a generation failure for one row must
        # not abort the whole batch.
        try:
            summary = _summarize(row, float(scores[idx]), text_gen)
        except Exception as e:
            summary = f"(summary failed: {e})"
        results.append(ScoreAndSummarizeItem(score=float(scores[idx]), summary=summary))
    return ScoreAndSummarizeResponse(results=results)