ethnmcl committed on
Commit b8b7be0 · verified · 1 Parent(s): 2ec6227

Update main.py

Files changed (1):
  1. main.py +58 -19

main.py CHANGED
@@ -8,6 +8,7 @@ from functools import lru_cache
 import pandas as pd
 from fastapi import FastAPI, File, UploadFile, HTTPException, Body
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse, RedirectResponse
 from pydantic import BaseModel, Field
 
 from huggingface_hub import login, snapshot_download
@@ -15,7 +16,7 @@ import joblib
 import xgboost as xgb
 import numpy as np
 import torch
-from transformers import pipeline
+from transformers import AutoTokenizer, pipeline
 
 # -------- Config --------
 HF_TOKEN = (
@@ -34,7 +35,7 @@ app = FastAPI(
         f"Models:\n- {XGB_REPO}\n- {GPT2_REPO}\n"
         "Use /docs for interactive testing."
     ),
-    version="1.0.1",
+    version="1.1.0",
 )
 
 # CORS (allow all; tighten for production)
@@ -46,7 +47,42 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# -------- Model loading --------
+# -------- Convenience root & health --------
+@app.get("/", include_in_schema=False)
+def root():
+    return JSONResponse(
+        {
+            "ok": True,
+            "message": "Entrepreneur Readiness API is running.",
+            "docs": "/docs",
+            "endpoints": ["/health", "/readiness", "/score", "/score_csv", "/summarize", "/score_and_summarize"],
+        }
+    )
+
+# Liveness-only (no model load)
+@app.get("/health", include_in_schema=False)
+def health():
+    return JSONResponse({"ok": True, "status": "live", "docs": "/docs"})
+
+# Readiness (loads models)
+@app.get("/readiness")
+def readiness():
+    try:
+        _load_models()
+        return {"ok": True, "status": "ready"}
+    except Exception as e:
+        return JSONResponse({"ok": False, "status": "not_ready", "error": str(e)}, status_code=503)
+
+# Optional warm-up to trigger downloads/caching
+@app.post("/warmup", include_in_schema=False)
+def warmup():
+    try:
+        _load_models()
+        return {"ok": True, "warmed": True}
+    except Exception as e:
+        return JSONResponse({"ok": False, "error": str(e)}, status_code=500)
+
+# -------- Model loading helpers --------
 def _find_file(dirpath: str, candidates: T.Sequence[str], fallback_exts: T.Sequence[str] = ()) -> str:
     for name in candidates:
         p = os.path.join(dirpath, name)
@@ -63,7 +99,7 @@ def _download_artifacts() -> T.Tuple[str, str]:
     try:
         login(token=HF_TOKEN, add_to_git_credential=True)
     except Exception:
-        # Public models still download
+        # Continue if public
         pass
     xgb_local = snapshot_download(repo_id=XGB_REPO, token=HF_TOKEN, revision=None)
     gpt_local = snapshot_download(repo_id=GPT2_REPO, token=HF_TOKEN, revision=None)
@@ -73,7 +109,7 @@ def _download_artifacts() -> T.Tuple[str, str]:
 def _load_models():
     xgb_dir, gpt_dir = _download_artifacts()
 
-    # Preprocessor
+    # ---- Preprocessor ----
     preproc_path = _find_file(
         xgb_dir,
         candidates=[
@@ -86,7 +122,7 @@ def _load_models():
     )
     preprocessor = joblib.load(preproc_path)
 
-    # Booster
+    # ---- XGB booster ----
     booster_path = _find_file(
         xgb_dir,
         candidates=[
@@ -102,15 +138,25 @@ def _load_models():
     booster = xgb.Booster()
     booster.load_model(booster_path)
 
-    # GPT-2 text generation
+    # ---- GPT-2 text generation: robust tokenizer selection ----
     device = 0 if torch.cuda.is_available() else -1
+    try:
+        tok = AutoTokenizer.from_pretrained(gpt_dir, use_fast=True, trust_remote_code=False)
+    except Exception:
+        # Fallback for "ModelWrapper" tokenizer.json parse errors
+        tok = AutoTokenizer.from_pretrained(gpt_dir, use_fast=False, trust_remote_code=False)
+    # Ensure a pad token (map to eos if absent) to avoid generation warnings/errors
+    if tok.pad_token is None and tok.eos_token is not None:
+        tok.pad_token = tok.eos_token
+
     text_gen = pipeline(
         "text-generation",
         model=gpt_dir,
-        tokenizer=gpt_dir,
+        tokenizer=tok,
         device=device,
         trust_remote_code=False,
     )
+
     return preprocessor, booster, text_gen, xgb_dir
 
 # -------- Utils --------
@@ -142,16 +188,17 @@ def _format_prompt(inputs: dict, score: float) -> str:
     )
 
 def _summarize(inputs: dict, score: float, text_gen) -> str:
-    out = text_gen(
+    generated = text_gen(
         _format_prompt(inputs, score),
         max_new_tokens=120,
         do_sample=True,
         temperature=0.7,
         top_p=0.9,
         num_return_sequences=1,
-        eos_token_id=None,
+        eos_token_id=text_gen.tokenizer.eos_token_id,
+        pad_token_id=text_gen.tokenizer.eos_token_id,
     )[0]["generated_text"]
-    return out.split("Summary:", 1)[-1].strip() if "Summary:" in out else out.strip()
+    return generated.split("Summary:", 1)[-1].strip() if "Summary:" in generated else generated.strip()
 
 # -------- Schemas (Pydantic v2) --------
 class ScoreRequest(BaseModel):
@@ -178,14 +225,6 @@ class ScoreAndSummarizeResponse(BaseModel):
     results: T.List[ScoreAndSummarizeItem]
 
 # -------- Endpoints --------
-@app.get("/health")
-def health():
-    try:
-        _load_models()
-        return {"ok": True}
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
-
 @app.post("/score", response_model=ScoreResponse)
 def score_json(req: ScoreRequest = Body(...)):
     preprocessor, booster, _, _ = _load_models()
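
The new root, liveness, readiness, and warm-up routes can be exercised with a few requests once the revision is deployed. A minimal smoke-test sketch follows; the base URL, port, and the `requests` dependency are assumptions and not part of this commit.

    import requests

    BASE = "http://localhost:7860"  # assumed deployment URL; adjust for your Space

    # Liveness: returns immediately, no model download is triggered
    print(requests.get(f"{BASE}/health").json())

    # Readiness: calls _load_models(); returns 503 with an error until artifacts are cached
    r = requests.get(f"{BASE}/readiness")
    print(r.status_code, r.json())

    # Optional warm-up to pre-download the XGBoost and GPT-2 artifacts
    print(requests.post(f"{BASE}/warmup").json())

Splitting liveness from readiness keeps container health probes cheap while still exposing a route that reports whether the models actually load.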
 
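
The tokenizer handling added in _load_models can also be checked in isolation before deploying. This is a sketch only: it assumes a GPT-2-style causal LM, and the `"gpt2"` repo id below is a placeholder for the snapshot downloaded from GPT2_REPO.

    from transformers import AutoTokenizer, pipeline

    repo = "gpt2"  # placeholder; the service uses the local snapshot of GPT2_REPO

    # Prefer the fast tokenizer, fall back to the slow one if tokenizer.json fails to parse
    try:
        tok = AutoTokenizer.from_pretrained(repo, use_fast=True)
    except Exception:
        tok = AutoTokenizer.from_pretrained(repo, use_fast=False)

    # GPT-2 ships without a pad token; mapping it to eos avoids padding warnings at generation time
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token

    gen = pipeline("text-generation", model=repo, tokenizer=tok)
    out = gen("Summary:", max_new_tokens=20, pad_token_id=tok.eos_token_id)
    print(out[0]["generated_text"])

Passing explicit eos_token_id and pad_token_id in _summarize serves the same purpose as the pad-token mapping above: generation stops cleanly at the model's end-of-sequence token instead of emitting warnings about a missing pad token.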