Olivier-52 commited on
Commit
c3530e3
·
1 Parent(s): 2c009f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -30,16 +30,16 @@ tags_metadata = [
30
 
31
  load_dotenv()
32
 
33
- # Variables MLflow : URI de tracking, nom du modèle et stage
34
  MLFLOW_TRACKING_APP_URI = os.getenv("MLFLOW_TRACKING_APP_URI")
35
  MODEL_NAME = os.getenv("MODEL_NAME", "fraud_detection_dtc")
36
  STAGE = os.getenv("STAGE", "production")
37
 
38
- # Variables AWS pour accéder au bucket S3 qui contient les artifacts de MLflow
39
  os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("AWS_ACCESS_KEY_ID")
40
  os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("AWS_SECRET_ACCESS_KEY")
41
 
42
- # Variables globales pour stocker le modèle
43
  mlflow.set_tracking_uri(MLFLOW_TRACKING_APP_URI)
44
  model_uri = f"models:/{MODEL_NAME}@{STAGE}"
45
  model = mlflow.sklearn.load_model(model_uri)
@@ -95,7 +95,7 @@ def predict(features: PredictionFeatures):
95
  Returns:
96
  dict: A dictionary containing the prediction of whether the payment is fraud or not
97
  """
98
- df = pd.DataFrame({
99
  "category" : [features.category],
100
  "amt" : [features.amt],
101
  "merch_fraud_level" : [features.merch_fraud_level],
@@ -107,8 +107,7 @@ def predict(features: PredictionFeatures):
107
  })
108
 
109
  try:
110
- df = pd.DataFrame({...})
111
- return model.predict(df)[0]
112
 
113
  except Exception as e:
114
  return {"error": str(e)}
 
30
 
31
  load_dotenv()
32
 
33
+ # Mlflow variables
34
  MLFLOW_TRACKING_APP_URI = os.getenv("MLFLOW_TRACKING_APP_URI")
35
  MODEL_NAME = os.getenv("MODEL_NAME", "fraud_detection_dtc")
36
  STAGE = os.getenv("STAGE", "production")
37
 
38
+ # AWS variables
39
  os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("AWS_ACCESS_KEY_ID")
40
  os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("AWS_SECRET_ACCESS_KEY")
41
 
42
+ # Load model
43
  mlflow.set_tracking_uri(MLFLOW_TRACKING_APP_URI)
44
  model_uri = f"models:/{MODEL_NAME}@{STAGE}"
45
  model = mlflow.sklearn.load_model(model_uri)
 
95
  Returns:
96
  dict: A dictionary containing the prediction of whether the payment is fraud or not
97
  """
98
+ data = pd.DataFrame({
99
  "category" : [features.category],
100
  "amt" : [features.amt],
101
  "merch_fraud_level" : [features.merch_fraud_level],
 
107
  })
108
 
109
  try:
110
+ return model.predict(data)[0]
 
111
 
112
  except Exception as e:
113
  return {"error": str(e)}