ethnmcl committed on
Commit
c5ddf38
·
verified ·
1 Parent(s): 334fef3

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +15 -66
main.py CHANGED
@@ -6,8 +6,7 @@ import typing as T
6
  from functools import lru_cache
7
 
8
  import pandas as pd
9
- from fastapi import FastAPI, File, UploadFile, HTTPException
10
- from fastapi import Body
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from pydantic import BaseModel, Field
13
 
@@ -18,6 +17,7 @@ import numpy as np
18
  import torch
19
  from transformers import pipeline
20
 
 
21
  HF_TOKEN = (
22
  os.environ.get("HF_TOKEN")
23
  or os.environ.get("HUGGING_FACE_HUB_TOKEN")
@@ -31,26 +31,22 @@ app = FastAPI(
31
  title="Entrepreneur Readiness API",
32
  description=(
33
  "XGBoost readiness scoring + GPT-2 summarization.\n\n"
34
- "Models:\n"
35
- f"- {XGB_REPO}\n- {GPT2_REPO}\n"
36
  "Use /docs for interactive testing."
37
  ),
38
- version="1.0.0",
39
  )
40
 
41
- # CORS (relaxed so you can call from browsers / Framer, etc.)
42
  app.add_middleware(
43
  CORSMiddleware,
44
- allow_origins=["*"], # tighten if needed
45
  allow_credentials=True,
46
  allow_methods=["*"],
47
  allow_headers=["*"],
48
  )
49
 
50
-
51
- # -----------------------------
52
- # Model loading
53
- # -----------------------------
54
  def _find_file(dirpath: str, candidates: T.Sequence[str], fallback_exts: T.Sequence[str] = ()) -> str:
55
  for name in candidates:
56
  p = os.path.join(dirpath, name)
@@ -61,20 +57,18 @@ def _find_file(dirpath: str, candidates: T.Sequence[str], fallback_exts: T.Seque
61
  return os.path.join(dirpath, fname)
62
  raise FileNotFoundError(f"Could not find any of {candidates} (or {fallback_exts}) in {dirpath}")
63
 
64
-
65
  @lru_cache(maxsize=1)
66
  def _download_artifacts() -> T.Tuple[str, str]:
67
  if HF_TOKEN:
68
  try:
69
  login(token=HF_TOKEN, add_to_git_credential=True)
70
  except Exception:
71
- # If public, keep going
72
  pass
73
  xgb_local = snapshot_download(repo_id=XGB_REPO, token=HF_TOKEN, revision=None)
74
  gpt_local = snapshot_download(repo_id=GPT2_REPO, token=HF_TOKEN, revision=None)
75
  return xgb_local, gpt_local
76
 
77
-
78
  @lru_cache(maxsize=1)
79
  def _load_models():
80
  xgb_dir, gpt_dir = _download_artifacts()
@@ -108,7 +102,7 @@ def _load_models():
108
  booster = xgb.Booster()
109
  booster.load_model(booster_path)
110
 
111
- # GPT-2 pipeline
112
  device = 0 if torch.cuda.is_available() else -1
113
  text_gen = pipeline(
114
  "text-generation",
@@ -117,13 +111,9 @@ def _load_models():
117
  device=device,
118
  trust_remote_code=False,
119
  )
120
-
121
  return preprocessor, booster, text_gen, xgb_dir
122
 
123
-
124
- # -----------------------------
125
- # Utils
126
- # -----------------------------
127
  def _coerce_numeric(df: pd.DataFrame) -> pd.DataFrame:
128
  out = df.copy()
129
  for c in out.columns:
@@ -134,18 +124,14 @@ def _coerce_numeric(df: pd.DataFrame) -> pd.DataFrame:
134
  pass
135
  return out
136
 
137
-
138
  def _to_dmatrix(df: pd.DataFrame, preprocessor) -> xgb.DMatrix:
139
  X = preprocessor.transform(df)
140
  return xgb.DMatrix(X)
141
 
142
-
143
  def _predict_scores(df: pd.DataFrame, preprocessor, booster) -> np.ndarray:
144
  dmat = _to_dmatrix(df, preprocessor)
145
  scores = booster.predict(dmat)
146
- scores = np.array(scores).reshape(-1)
147
- return scores
148
-
149
 
150
  def _format_prompt(inputs: dict, score: float) -> str:
151
  kv = "; ".join(f"{k}: {v}" for k, v in inputs.items())
@@ -155,11 +141,9 @@ def _format_prompt(inputs: dict, score: float) -> str:
155
  "Summary:"
156
  )
157
 
158
-
159
  def _summarize(inputs: dict, score: float, text_gen) -> str:
160
- prompt = _format_prompt(inputs, score)
161
  out = text_gen(
162
- prompt,
163
  max_new_tokens=120,
164
  do_sample=True,
165
  temperature=0.7,
@@ -169,47 +153,31 @@ def _summarize(inputs: dict, score: float, text_gen) -> str:
169
  )[0]["generated_text"]
170
  return out.split("Summary:", 1)[-1].strip() if "Summary:" in out else out.strip()
171
 
172
-
173
- # -----------------------------
174
- # Schemas
175
- # -----------------------------
176
- class RowDict(BaseModel):
177
- __root__: dict
178
-
179
-
180
  class ScoreRequest(BaseModel):
181
  rows: T.List[dict] = Field(..., description="List of row objects (feature_name -> value).")
182
 
183
-
184
  class ScoreResponse(BaseModel):
185
  scores: T.List[float]
186
 
187
-
188
  class SummarizeRequest(BaseModel):
189
  inputs: dict = Field(..., description="Feature dict for one example.")
190
  score: float = Field(..., description="Readiness score used in the summary.")
191
 
192
-
193
  class SummarizeResponse(BaseModel):
194
  summary: str
195
 
196
-
197
  class ScoreAndSummarizeRequest(BaseModel):
198
  rows: T.List[dict] = Field(..., description="Rows to score and summarize.")
199
 
200
-
201
  class ScoreAndSummarizeItem(BaseModel):
202
  score: float
203
  summary: str
204
 
205
-
206
  class ScoreAndSummarizeResponse(BaseModel):
207
  results: T.List[ScoreAndSummarizeItem]
208
 
209
-
210
- # -----------------------------
211
- # Endpoints
212
- # -----------------------------
213
  @app.get("/health")
214
  def health():
215
  try:
@@ -218,31 +186,21 @@ def health():
218
  except Exception as e:
219
  raise HTTPException(status_code=500, detail=str(e))
220
 
221
-
222
  @app.post("/score", response_model=ScoreResponse)
223
  def score_json(req: ScoreRequest = Body(...)):
224
- """
225
- Score a JSON batch of rows.
226
- """
227
  preprocessor, booster, _, _ = _load_models()
228
  if not req.rows:
229
  raise HTTPException(status_code=400, detail="rows must be non-empty")
230
-
231
  df = pd.DataFrame(req.rows)
232
  df = _coerce_numeric(df)
233
  try:
234
  scores = _predict_scores(df, preprocessor, booster)
235
  except Exception as e:
236
  raise HTTPException(status_code=400, detail=f"Scoring failed: {e}")
237
-
238
  return ScoreResponse(scores=[float(s) for s in scores])
239
 
240
-
241
  @app.post("/score_csv", response_model=ScoreResponse)
242
  async def score_csv(file: UploadFile = File(...)):
243
- """
244
- Score a CSV upload. Returns the scores list in row order.
245
- """
246
  preprocessor, booster, _, _ = _load_models()
247
  try:
248
  content = await file.read()
@@ -253,12 +211,8 @@ async def score_csv(file: UploadFile = File(...)):
253
  raise HTTPException(status_code=400, detail=f"CSV scoring failed: {e}")
254
  return ScoreResponse(scores=[float(s) for s in scores])
255
 
256
-
257
  @app.post("/summarize", response_model=SummarizeResponse)
258
  def summarize(req: SummarizeRequest = Body(...)):
259
- """
260
- Summarize a single example given inputs + score.
261
- """
262
  _, _, text_gen, _ = _load_models()
263
  try:
264
  summary = _summarize(req.inputs, req.score, text_gen)
@@ -266,23 +220,17 @@ def summarize(req: SummarizeRequest = Body(...)):
266
  raise HTTPException(status_code=400, detail=f"Summarization failed: {e}")
267
  return SummarizeResponse(summary=summary)
268
 
269
-
270
  @app.post("/score_and_summarize", response_model=ScoreAndSummarizeResponse)
271
  def score_and_summarize(req: ScoreAndSummarizeRequest = Body(...)):
272
- """
273
- For each row: compute score, then generate a GPT-2 summary.
274
- """
275
  preprocessor, booster, text_gen, _ = _load_models()
276
  if not req.rows:
277
  raise HTTPException(status_code=400, detail="rows must be non-empty")
278
  df = pd.DataFrame(req.rows)
279
  df = _coerce_numeric(df)
280
-
281
  try:
282
  scores = _predict_scores(df, preprocessor, booster)
283
  except Exception as e:
284
  raise HTTPException(status_code=400, detail=f"Scoring failed: {e}")
285
-
286
  results = []
287
  for i, row in enumerate(req.rows):
288
  try:
@@ -291,3 +239,4 @@ def score_and_summarize(req: ScoreAndSummarizeRequest = Body(...)):
291
  summ = f"(summary failed: {e})"
292
  results.append(ScoreAndSummarizeItem(score=float(scores[i]), summary=summ))
293
  return ScoreAndSummarizeResponse(results=results)
 
 
6
  from functools import lru_cache
7
 
8
  import pandas as pd
9
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Body
 
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from pydantic import BaseModel, Field
12
 
 
17
  import torch
18
  from transformers import pipeline
19
 
20
+ # -------- Config --------
21
  HF_TOKEN = (
22
  os.environ.get("HF_TOKEN")
23
  or os.environ.get("HUGGING_FACE_HUB_TOKEN")
 
31
  title="Entrepreneur Readiness API",
32
  description=(
33
  "XGBoost readiness scoring + GPT-2 summarization.\n\n"
34
+ f"Models:\n- {XGB_REPO}\n- {GPT2_REPO}\n"
 
35
  "Use /docs for interactive testing."
36
  ),
37
+ version="1.0.1",
38
  )
39
 
40
# CORS (allow all; tighten for production)
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): per the CORS spec, browsers do not honor
    # Access-Control-Allow-Credentials together with a wildcard "*" origin —
    # confirm whether credentialed cross-origin requests are actually needed.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
48
 
49
+ # -------- Model loading --------
 
 
 
50
  def _find_file(dirpath: str, candidates: T.Sequence[str], fallback_exts: T.Sequence[str] = ()) -> str:
51
  for name in candidates:
52
  p = os.path.join(dirpath, name)
 
57
  return os.path.join(dirpath, fname)
58
  raise FileNotFoundError(f"Could not find any of {candidates} (or {fallback_exts}) in {dirpath}")
59
 
 
60
@lru_cache(maxsize=1)
def _download_artifacts() -> T.Tuple[str, str]:
    """Download (once per process) the model snapshots from the Hub.

    Returns:
        Tuple of local snapshot directories: ``(xgb_dir, gpt2_dir)``.
    """
    if HF_TOKEN:
        # Best-effort authentication: public repos still download if this fails.
        try:
            login(token=HF_TOKEN, add_to_git_credential=True)
        except Exception:
            pass
    local_dirs = [
        snapshot_download(repo_id=repo, token=HF_TOKEN, revision=None)
        for repo in (XGB_REPO, GPT2_REPO)
    ]
    return local_dirs[0], local_dirs[1]
71
 
 
72
  @lru_cache(maxsize=1)
73
  def _load_models():
74
  xgb_dir, gpt_dir = _download_artifacts()
 
102
  booster = xgb.Booster()
103
  booster.load_model(booster_path)
104
 
105
+ # GPT-2 text generation
106
  device = 0 if torch.cuda.is_available() else -1
107
  text_gen = pipeline(
108
  "text-generation",
 
111
  device=device,
112
  trust_remote_code=False,
113
  )
 
114
  return preprocessor, booster, text_gen, xgb_dir
115
 
116
+ # -------- Utils --------
 
 
 
117
  def _coerce_numeric(df: pd.DataFrame) -> pd.DataFrame:
118
  out = df.copy()
119
  for c in out.columns:
 
124
  pass
125
  return out
126
 
 
127
def _to_dmatrix(df: pd.DataFrame, preprocessor) -> xgb.DMatrix:
    """Run *df* through the fitted preprocessor and wrap the result in a DMatrix."""
    return xgb.DMatrix(preprocessor.transform(df))
130
 
 
131
def _predict_scores(df: pd.DataFrame, preprocessor, booster) -> np.ndarray:
    """Score every row of *df* with the booster; always returns a flat 1-D array."""
    raw_predictions = booster.predict(_to_dmatrix(df, preprocessor))
    # Normalize whatever the booster emits (list / nested array) to 1-D.
    return np.asarray(raw_predictions).reshape(-1)
 
 
135
 
136
  def _format_prompt(inputs: dict, score: float) -> str:
137
  kv = "; ".join(f"{k}: {v}" for k, v in inputs.items())
 
141
  "Summary:"
142
  )
143
 
 
144
  def _summarize(inputs: dict, score: float, text_gen) -> str:
 
145
  out = text_gen(
146
+ _format_prompt(inputs, score),
147
  max_new_tokens=120,
148
  do_sample=True,
149
  temperature=0.7,
 
153
  )[0]["generated_text"]
154
  return out.split("Summary:", 1)[-1].strip() if "Summary:" in out else out.strip()
155
 
156
+ # -------- Schemas (Pydantic v2) --------
 
 
 
 
 
 
 
157
class ScoreRequest(BaseModel):
    # Batch scoring input: one dict per row (feature_name -> value).
    rows: T.List[dict] = Field(..., description="List of row objects (feature_name -> value).")


class ScoreResponse(BaseModel):
    # One score per input row, in the same order.
    scores: T.List[float]


class SummarizeRequest(BaseModel):
    inputs: dict = Field(..., description="Feature dict for one example.")
    score: float = Field(..., description="Readiness score used in the summary.")


class SummarizeResponse(BaseModel):
    summary: str


class ScoreAndSummarizeRequest(BaseModel):
    rows: T.List[dict] = Field(..., description="Rows to score and summarize.")


class ScoreAndSummarizeItem(BaseModel):
    # A row's score paired with its generated natural-language summary.
    score: float
    summary: str


class ScoreAndSummarizeResponse(BaseModel):
    results: T.List[ScoreAndSummarizeItem]
179
 
180
+ # -------- Endpoints --------
 
 
 
181
  @app.get("/health")
182
  def health():
183
  try:
 
186
  except Exception as e:
187
  raise HTTPException(status_code=500, detail=str(e))
188
 
 
189
@app.post("/score", response_model=ScoreResponse)
def score_json(req: ScoreRequest = Body(...)):
    """Score a JSON batch of rows.

    Returns one float score per input row, in row order. Raises 400 when the
    batch is empty or the preprocessor/booster rejects the data.
    """
    preprocessor, booster, _, _ = _load_models()
    if not req.rows:
        raise HTTPException(status_code=400, detail="rows must be non-empty")
    frame = _coerce_numeric(pd.DataFrame(req.rows))
    try:
        predictions = _predict_scores(frame, preprocessor, booster)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Scoring failed: {e}")
    return ScoreResponse(scores=[float(p) for p in predictions])
201
 
 
202
  @app.post("/score_csv", response_model=ScoreResponse)
203
  async def score_csv(file: UploadFile = File(...)):
 
 
 
204
  preprocessor, booster, _, _ = _load_models()
205
  try:
206
  content = await file.read()
 
211
  raise HTTPException(status_code=400, detail=f"CSV scoring failed: {e}")
212
  return ScoreResponse(scores=[float(s) for s in scores])
213
 
 
214
  @app.post("/summarize", response_model=SummarizeResponse)
215
  def summarize(req: SummarizeRequest = Body(...)):
 
 
 
216
  _, _, text_gen, _ = _load_models()
217
  try:
218
  summary = _summarize(req.inputs, req.score, text_gen)
 
220
  raise HTTPException(status_code=400, detail=f"Summarization failed: {e}")
221
  return SummarizeResponse(summary=summary)
222
 
 
223
  @app.post("/score_and_summarize", response_model=ScoreAndSummarizeResponse)
224
  def score_and_summarize(req: ScoreAndSummarizeRequest = Body(...)):
 
 
 
225
  preprocessor, booster, text_gen, _ = _load_models()
226
  if not req.rows:
227
  raise HTTPException(status_code=400, detail="rows must be non-empty")
228
  df = pd.DataFrame(req.rows)
229
  df = _coerce_numeric(df)
 
230
  try:
231
  scores = _predict_scores(df, preprocessor, booster)
232
  except Exception as e:
233
  raise HTTPException(status_code=400, detail=f"Scoring failed: {e}")
 
234
  results = []
235
  for i, row in enumerate(req.rows):
236
  try:
 
239
  summ = f"(summary failed: {e})"
240
  results.append(ScoreAndSummarizeItem(score=float(scores[i]), summary=summ))
241
  return ScoreAndSummarizeResponse(results=results)
242
+