celpri commited on
Commit
fe32675
·
1 Parent(s): c0c3478

fix: load features locally or from HF depending on environment

Browse files
Files changed (1) hide show
  1. src/api/main.py +15 -4
src/api/main.py CHANGED
@@ -7,17 +7,24 @@ import os
7
  from huggingface_hub import hf_hub_download
8
  from src.model.model import load_model
9
 
 
 
 
 
10
 
11
  class ClientID(BaseModel):
12
  sk_id_curr: int
13
 
14
 
 
 
15
  def get_features_by_id(sk_id_curr: int) -> pd.DataFrame:
16
- # CAS TEST / LOCAL : features injectées par le test
17
- if hasattr(app.state, "features") and app.state.features is not None:
18
- df = app.state.features
 
19
  else:
20
- # CAS PROD HF
21
  path = hf_hub_download(
22
  repo_id="PCelia/credit-scoring-model",
23
  filename="features_clients.csv",
@@ -32,6 +39,7 @@ def get_features_by_id(sk_id_curr: int) -> pd.DataFrame:
32
  return row.drop(columns=["SK_ID_CURR"])
33
 
34
 
 
35
  @asynccontextmanager
36
  async def lifespan(app: FastAPI):
37
  app.state.model = load_model()
@@ -40,6 +48,9 @@ async def lifespan(app: FastAPI):
40
 
41
  app = FastAPI(lifespan=lifespan)
42
 
 
 
 
43
 
44
  @app.post("/predict_by_id")
45
  def predict_by_id(payload: ClientID):
 
7
  from huggingface_hub import hf_hub_download
8
  from src.model.model import load_model
9
 
10
+ from fastapi.responses import RedirectResponse
11
+
12
+
13
+
14
 
15
  class ClientID(BaseModel):
16
  sk_id_curr: int
17
 
18
 
19
+ from pathlib import Path
20
+
21
  def get_features_by_id(sk_id_curr: int) -> pd.DataFrame:
22
+ # CAS LOCAL : CSV présent dans le projet
23
+ local_path = Path(__file__).resolve().parents[2] / "Data" / "features_clients.csv"
24
+ if local_path.exists():
25
+ df = pd.read_csv(local_path)
26
  else:
27
+ # CAS HF : téléchargement depuis le Hub
28
  path = hf_hub_download(
29
  repo_id="PCelia/credit-scoring-model",
30
  filename="features_clients.csv",
 
39
  return row.drop(columns=["SK_ID_CURR"])
40
 
41
 
42
+
43
  @asynccontextmanager
44
  async def lifespan(app: FastAPI):
45
  app.state.model = load_model()
 
48
 
49
  app = FastAPI(lifespan=lifespan)
50
 
51
+ @app.get("/")
52
+ def root():
53
+ return RedirectResponse(url="/docs")
54
 
55
  @app.post("/predict_by_id")
56
  def predict_by_id(payload: ClientID):