marintosti12 commited on
Commit
84394db
·
1 Parent(s): 8afccd5

feat (model loader) : update loader for model in local

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. src/model_loader.py +20 -3
README.md CHANGED
@@ -58,7 +58,7 @@ HF_TOKEN= Token Hugging Face
58
  ### 4. Base de données (PostgreSQL)
59
 
60
  ~~~bash
61
- docker compose up -d
62
  ~~~
63
 
64
 
 
58
  ### 4. Base de données (PostgreSQL)
59
 
60
  ~~~bash
61
+ sudo docker compose up -d
62
  ~~~
63
 
64
 
src/model_loader.py CHANGED
@@ -1,18 +1,35 @@
1
  import os
2
  from functools import lru_cache
3
- from typing import Any
 
4
  from huggingface_hub import hf_hub_download
5
  import joblib
6
 
7
  HF_REPO_ID = os.getenv("HF_REPO_ID", "Marintosti/attrition")
8
  HF_TOKEN = os.getenv("HF_TOKEN")
9
 
 
 
 
 
 
 
 
 
 
 
 
10
  @lru_cache(maxsize=1)
11
  def load_model(name) -> Any:
12
- local_path = hf_hub_download(
 
 
 
 
13
  repo_id=HF_REPO_ID,
14
  filename=f"{name}.joblib",
15
  token=HF_TOKEN,
16
  local_files_only=False,
17
  )
18
- return joblib.load(local_path)
 
 
1
  import os
2
  from functools import lru_cache
3
+ from pathlib import Path
4
+ from typing import Any, Literal
5
  from huggingface_hub import hf_hub_download
6
  import joblib
7
 
8
  HF_REPO_ID = os.getenv("HF_REPO_ID", "Marintosti/attrition")
9
  HF_TOKEN = os.getenv("HF_TOKEN")
10
 
11
+ ENV: Literal["dev", "test", "prod"] = os.getenv("APP_ENV", "dev").lower()
12
+ ARTIFACTS_DIR = Path(os.getenv("ARTIFACTS_DIR", "artifacts"))
13
+
14
+ def _load_local(name: str) -> Any:
15
+ path = ARTIFACTS_DIR / f"{name}.joblib"
16
+ if not path.exists():
17
+ raise FileNotFoundError(
18
+ f"Modèle local introuvable: {path}. "
19
+ )
20
+ return joblib.load(path)
21
+
22
  @lru_cache(maxsize=1)
23
  def load_model(name) -> Any:
24
+
25
+ if ENV in ("dev", "test"):
26
+ return _load_local(name)
27
+
28
+ hf_path = hf_hub_download(
29
  repo_id=HF_REPO_ID,
30
  filename=f"{name}.joblib",
31
  token=HF_TOKEN,
32
  local_files_only=False,
33
  )
34
+
35
+ return joblib.load(hf_path)