Aurélie GABU commited on
Commit
8344437
·
1 Parent(s): 9f8ea7b

docs: predict file updated

Browse files
Files changed (1) hide show
  1. App/predict.py +7 -8
App/predict.py CHANGED
@@ -7,10 +7,6 @@ from huggingface_hub import hf_hub_download
7
 
8
  MODEL_REPO = "Diaure/xgb_model"
9
 
10
- # Chemin des fichiers
11
- chemin_model = Path(hf_hub_download(repo_id=MODEL_REPO, filename="modele_final_xgb.joblib"))
12
- chemin_mapping = Path(hf_hub_download(repo_id=MODEL_REPO, filename="mapping_classes.json"))
13
-
14
  # Variables chargées
15
  model = None
16
  classes_mapping = None
@@ -21,14 +17,17 @@ Features = list(EmployeeFeatures.model_fields.keys())
21
  # Chargement des fichiers: fonction pour charger le modèle, le mapping afin de permettre à l'API de démarrer m^me si les éléments ne sont pas présents
22
  def files_load():
23
  global model, classes_mapping
 
24
  if model is None:
25
- if not chemin_model.exists():
26
- raise RuntimeError("Eléments du modèle introuvable.")
 
27
  model =joblib.load(chemin_model)
28
 
29
  if classes_mapping is None:
30
- if not chemin_mapping.exists():
31
- raise RuntimeError("Mapping des classes introuvable.")
 
32
  with open(chemin_mapping) as f:
33
  classes_mapping = json.load(f)
34
 
 
7
 
8
  MODEL_REPO = "Diaure/xgb_model"
9
 
 
 
 
 
10
  # Variables chargées
11
  model = None
12
  classes_mapping = None
 
17
  # Chargement des fichiers: fonction pour charger le modèle, le mapping afin de permettre à l'API de démarrer m^me si les éléments ne sont pas présents
18
  def files_load():
19
  global model, classes_mapping
20
+
21
  if model is None:
22
+ chemin_model = Path(hf_hub_download(repo_id=MODEL_REPO, filename="modele_final_xgb.joblib"))
23
+ # if not chemin_model.exists():
24
+ # raise RuntimeError("Eléments du modèle introuvable.")
25
  model =joblib.load(chemin_model)
26
 
27
  if classes_mapping is None:
28
+ chemin_mapping = Path(hf_hub_download(repo_id=MODEL_REPO, filename="mapping_classes.json"))
29
+ # if not chemin_mapping.exists():
30
+ # raise RuntimeError("Mapping des classes introuvable.")
31
  with open(chemin_mapping) as f:
32
  classes_mapping = json.load(f)
33