learnlanguage / backend /app /services /cefr_predictor.py
Zinebhm's picture
Update backend/app/services/cefr_predictor.py
5db1d57 verified
raw
history blame contribute delete
885 Bytes
import os
import shutil
import tempfile
from pathlib import Path

import joblib
# Each setting falls back to its default when the environment variable is
# unset *or* set to an empty string (e.g. ``VAR=${VAR}`` in docker-compose with
# the host variable missing). Otherwise a blank CEFR_MODEL_PATH would become
# Path("") == Path(".") and a blank repo/filename would be sent to the Hub.

# Local file the pickled CEFR model is loaded from.
MODEL_PATH = Path(os.getenv("CEFR_MODEL_PATH") or "/app/ml/cefr_model.pkl")
# Hugging Face Hub model repository used when MODEL_PATH does not exist yet.
HF_REPO_ID = os.getenv("CEFR_MODEL_REPO") or "Zinebhm/cefr-model-learnlanguage"
# Name of the model file inside HF_REPO_ID.
HF_FILENAME = os.getenv("CEFR_MODEL_FILENAME") or "cefr_model.pkl"
class CEFRPredictor:
    """Predict a label for a piece of text with a pre-trained, pickled model.

    On construction the model is loaded from ``MODEL_PATH``. If that file does
    not exist yet, it is first downloaded from the Hugging Face Hub
    (``HF_REPO_ID`` / ``HF_FILENAME``) and copied to ``MODEL_PATH``.

    NOTE(security): ``joblib.load`` unpickles the file, which can execute
    arbitrary code. ``CEFR_MODEL_REPO`` / ``CEFR_MODEL_PATH`` must only ever
    point at artifacts you trust.
    """

    def __init__(self):
        """Make sure the model file exists locally, then load it into ``self.model``.

        Raises:
            Whatever ``hf_hub_download`` raises (network / missing repo or file),
            ``OSError`` if the file cannot be written to ``MODEL_PATH``, and
            whatever ``joblib.load`` raises for an unreadable model file.
        """
        if not MODEL_PATH.exists():
            MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
            # Imported lazily so the Hub client is only needed when the model
            # is not already on disk.
            from huggingface_hub import hf_hub_download

            downloaded = hf_hub_download(
                repo_id=HF_REPO_ID,
                filename=HF_FILENAME,
                repo_type="model",
            )
            # Copy (not move) out of the Hub cache; see _install_model_file.
            self._install_model_file(Path(downloaded), MODEL_PATH)
        self.model = joblib.load(MODEL_PATH)

    @staticmethod
    def _install_model_file(src: Path, dest: Path) -> None:
        """Copy *src* to *dest* atomically, leaving *src* untouched.

        The path returned by ``hf_hub_download`` may be a (relative) symlink
        into the Hub cache's blob store, and the cache may live on another
        filesystem. Moving it with ``os.replace`` could therefore produce a
        dangling symlink or fail with ``EXDEV``. Instead, the file content is
        copied (symlinks followed) into a temp file next to *dest*. That temp
        file is then renamed over *dest*, which is atomic within one
        directory, so readers never see a partially written model. The temp
        file is removed if anything fails.

        *dest.parent* must already exist.
        """
        fd, tmp_name = tempfile.mkstemp(
            dir=dest.parent, prefix=f".{dest.name}.", suffix=".tmp"
        )
        os.close(fd)
        try:
            # Copies data + permission bits and follows symlinks by default.
            shutil.copy(src, tmp_name)
            os.replace(tmp_name, dest)
        except BaseException:
            Path(tmp_name).unlink(missing_ok=True)
            raise

    def predict(self, text: str):
        """Return the model's prediction for a single *text*.

        The text is wrapped in a one-element list for ``self.model.predict``
        and the first (only) prediction is returned. Presumably a CEFR level
        label; the exact type depends on the trained model. TODO confirm.
        """
        return self.model.predict([text])[0]