Spaces:
Sleeping
Sleeping
Merge pull request #39 from marintosti12/feature/loader-model-env
Browse files- .github/workflows/ci.yaml +3 -0
- README.md +1 -1
- src/model_loader.py +20 -3
.github/workflows/ci.yaml
CHANGED
|
@@ -18,6 +18,7 @@ jobs:
|
|
| 18 |
POSTGRES_USER: ci
|
| 19 |
POSTGRES_PASSWORD: ci
|
| 20 |
POSTGRES_DB: futurisys_ci
|
|
|
|
| 21 |
ports:
|
| 22 |
- 5432:5432
|
| 23 |
options: >-
|
|
@@ -29,6 +30,8 @@ jobs:
|
|
| 29 |
env:
|
| 30 |
DATABASE_URL: postgresql+asyncpg://ci:ci@localhost:5432/futurisys_ci
|
| 31 |
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
|
|
|
|
|
|
| 32 |
|
| 33 |
steps:
|
| 34 |
- name: Checkout
|
|
|
|
| 18 |
POSTGRES_USER: ci
|
| 19 |
POSTGRES_PASSWORD: ci
|
| 20 |
POSTGRES_DB: futurisys_ci
|
| 21 |
+
|
| 22 |
ports:
|
| 23 |
- 5432:5432
|
| 24 |
options: >-
|
|
|
|
| 30 |
env:
|
| 31 |
DATABASE_URL: postgresql+asyncpg://ci:ci@localhost:5432/futurisys_ci
|
| 32 |
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
| 33 |
+
APP_ENV: test
|
| 34 |
+
|
| 35 |
|
| 36 |
steps:
|
| 37 |
- name: Checkout
|
README.md
CHANGED
|
@@ -58,7 +58,7 @@ HF_TOKEN= Token Hugging Face
|
|
| 58 |
### 4. Base de données (PostgreSQL)
|
| 59 |
|
| 60 |
~~~bash
|
| 61 |
-
docker compose up -d
|
| 62 |
~~~
|
| 63 |
|
| 64 |
|
|
|
|
| 58 |
### 4. Base de données (PostgreSQL)
|
| 59 |
|
| 60 |
~~~bash
|
| 61 |
+
sudo docker compose up -d
|
| 62 |
~~~
|
| 63 |
|
| 64 |
|
src/model_loader.py
CHANGED
|
@@ -1,18 +1,35 @@
|
|
| 1 |
import os
|
| 2 |
from functools import lru_cache
|
| 3 |
-
from
|
|
|
|
| 4 |
from huggingface_hub import hf_hub_download
|
| 5 |
import joblib
|
| 6 |
|
| 7 |
HF_REPO_ID = os.getenv("HF_REPO_ID", "Marintosti/attrition")
|
| 8 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
@lru_cache(maxsize=1)
|
| 11 |
def load_model(name) -> Any:
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
repo_id=HF_REPO_ID,
|
| 14 |
filename=f"{name}.joblib",
|
| 15 |
token=HF_TOKEN,
|
| 16 |
local_files_only=False,
|
| 17 |
)
|
| 18 |
-
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
from functools import lru_cache
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any, Literal
|
| 5 |
from huggingface_hub import hf_hub_download
|
| 6 |
import joblib
|
| 7 |
|
| 8 |
HF_REPO_ID = os.getenv("HF_REPO_ID", "Marintosti/attrition")
|
| 9 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 10 |
|
| 11 |
+
ENV: Literal["dev", "test", "prod"] = os.getenv("APP_ENV", "dev").lower()
|
| 12 |
+
ARTIFACTS_DIR = Path(os.getenv("ARTIFACTS_DIR", "artifacts"))
|
| 13 |
+
|
| 14 |
+
def _load_local(name: str) -> Any:
|
| 15 |
+
path = ARTIFACTS_DIR / f"{name}.joblib"
|
| 16 |
+
if not path.exists():
|
| 17 |
+
raise FileNotFoundError(
|
| 18 |
+
f"Modèle local introuvable: {path}. "
|
| 19 |
+
)
|
| 20 |
+
return joblib.load(path)
|
| 21 |
+
|
| 22 |
@lru_cache(maxsize=1)
|
| 23 |
def load_model(name) -> Any:
|
| 24 |
+
|
| 25 |
+
if ENV in ("dev"):
|
| 26 |
+
return _load_local(name)
|
| 27 |
+
|
| 28 |
+
hf_path = hf_hub_download(
|
| 29 |
repo_id=HF_REPO_ID,
|
| 30 |
filename=f"{name}.joblib",
|
| 31 |
token=HF_TOKEN,
|
| 32 |
local_files_only=False,
|
| 33 |
)
|
| 34 |
+
|
| 35 |
+
return joblib.load(hf_path)
|