Mekam commited on
Commit
bb6807b
·
1 Parent(s): 0d743d9

fix(l3_classifier): select only the model feature

Browse files
Files changed (1) hide show
  1. src/agents/l3_classifier.py +32 -12
src/agents/l3_classifier.py CHANGED
@@ -13,17 +13,37 @@ class Classifier:
13
  except Exception as e:
14
  raise HTTPException(status_code=500, detail=f"Erreur lors du chargement du modèle: {e}")
15
 
16
- def predict(self, data):
17
- try:
18
- # Préparer les features
19
- X = data[self.features]
20
 
21
- # Standardisation si scaler présent
22
- if self.scaler is not None:
23
- X = self.scaler.transform(X)
24
 
25
- # Prédictions
26
- preds = self.model.predict(X)
27
- return preds.tolist()
28
- except Exception as e:
29
- raise HTTPException(status_code=500, detail=f"Erreur lors de la prédiction: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  except Exception as e:
14
  raise HTTPException(status_code=500, detail=f"Erreur lors du chargement du modèle: {e}")
15
 
16
+ def predict(self, data):
17
+ try:
18
+ if self.features is None:
19
+ raise HTTPException(status_code=500, detail="Liste des features du modèle introuvable")
20
 
21
+ # Sélectionner uniquement les features utilisées lors de l'entraînement
22
+ X = data[self.features]
 
23
 
24
+ # Standardisation si scaler présent
25
+ if self.scaler is not None:
26
+ X = self.scaler.transform(X)
27
+
28
+ # Prédictions
29
+ preds = self.model.predict(X)
30
+ return preds.tolist()
31
+ except KeyError as e:
32
+ raise HTTPException(status_code=500, detail=f"Colonne manquante dans les données: {e}")
33
+ except Exception as e:
34
+ raise HTTPException(status_code=500, detail=f"Erreur lors de la prédiction: {e}")
35
+
36
+ # def predict(self, data):
37
+ # try:
38
+ # # Préparer les features
39
+ # X = data[self.features]
40
+
41
+ # # Standardisation si scaler présent
42
+ # if self.scaler is not None:
43
+ # X = self.scaler.transform(X)
44
+
45
+ # # Prédictions
46
+ # preds = self.model.predict(X)
47
+ # return preds.tolist()
48
+ # except Exception as e:
49
+ # raise HTTPException(status_code=500, detail=f"Erreur lors de la prédiction: {e}")