marintosti12 commited on
Commit
d61744e
·
unverified ·
2 Parent(s): 8afccd5fbfb5a1

Merge pull request #39 from marintosti12/feature/loader-model-env

Browse files
Files changed (3) hide show
  1. .github/workflows/ci.yaml +3 -0
  2. README.md +1 -1
  3. src/model_loader.py +20 -3
.github/workflows/ci.yaml CHANGED
@@ -18,6 +18,7 @@ jobs:
18
  POSTGRES_USER: ci
19
  POSTGRES_PASSWORD: ci
20
  POSTGRES_DB: futurisys_ci
 
21
  ports:
22
  - 5432:5432
23
  options: >-
@@ -29,6 +30,8 @@ jobs:
29
  env:
30
  DATABASE_URL: postgresql+asyncpg://ci:ci@localhost:5432/futurisys_ci
31
  HF_TOKEN: ${{ secrets.HF_TOKEN }}
 
 
32
 
33
  steps:
34
  - name: Checkout
 
18
  POSTGRES_USER: ci
19
  POSTGRES_PASSWORD: ci
20
  POSTGRES_DB: futurisys_ci
21
+
22
  ports:
23
  - 5432:5432
24
  options: >-
 
30
  env:
31
  DATABASE_URL: postgresql+asyncpg://ci:ci@localhost:5432/futurisys_ci
32
  HF_TOKEN: ${{ secrets.HF_TOKEN }}
33
+ APP_ENV: test
34
+
35
 
36
  steps:
37
  - name: Checkout
README.md CHANGED
@@ -58,7 +58,7 @@ HF_TOKEN= Token Hugging Face
58
  ### 4. Base de données (PostgreSQL)
59
 
60
  ~~~bash
61
- docker compose up -d
62
  ~~~
63
 
64
 
 
58
  ### 4. Base de données (PostgreSQL)
59
 
60
  ~~~bash
61
+ sudo docker compose up -d
62
  ~~~
63
 
64
 
src/model_loader.py CHANGED
@@ -1,18 +1,35 @@
1
  import os
2
  from functools import lru_cache
3
- from typing import Any
 
4
  from huggingface_hub import hf_hub_download
5
  import joblib
6
 
7
  HF_REPO_ID = os.getenv("HF_REPO_ID", "Marintosti/attrition")
8
  HF_TOKEN = os.getenv("HF_TOKEN")
9
 
 
 
 
 
 
 
 
 
 
 
 
10
  @lru_cache(maxsize=1)
11
  def load_model(name) -> Any:
12
- local_path = hf_hub_download(
 
 
 
 
13
  repo_id=HF_REPO_ID,
14
  filename=f"{name}.joblib",
15
  token=HF_TOKEN,
16
  local_files_only=False,
17
  )
18
- return joblib.load(local_path)
 
 
1
  import os
2
  from functools import lru_cache
3
+ from pathlib import Path
4
+ from typing import Any, Literal
5
  from huggingface_hub import hf_hub_download
6
  import joblib
7
 
8
  HF_REPO_ID = os.getenv("HF_REPO_ID", "Marintosti/attrition")
9
  HF_TOKEN = os.getenv("HF_TOKEN")
10
 
11
+ ENV: Literal["dev", "test", "prod"] = os.getenv("APP_ENV", "dev").lower()
12
+ ARTIFACTS_DIR = Path(os.getenv("ARTIFACTS_DIR", "artifacts"))
13
+
14
+ def _load_local(name: str) -> Any:
15
+ path = ARTIFACTS_DIR / f"{name}.joblib"
16
+ if not path.exists():
17
+ raise FileNotFoundError(
18
+ f"Modèle local introuvable: {path}. "
19
+ )
20
+ return joblib.load(path)
21
+
22
  @lru_cache(maxsize=1)
23
  def load_model(name) -> Any:
24
+
25
+ if ENV in ("dev"):
26
+ return _load_local(name)
27
+
28
+ hf_path = hf_hub_download(
29
  repo_id=HF_REPO_ID,
30
  filename=f"{name}.joblib",
31
  token=HF_TOKEN,
32
  local_files_only=False,
33
  )
34
+
35
+ return joblib.load(hf_path)