Olivier-52 commited on
Commit
67af143
·
1 Parent(s): 426a949

Update app.py

Browse files

Change API to manage model and pipeline prediction

Files changed (1) hide show
  1. app.py +30 -11
app.py CHANGED
@@ -4,6 +4,7 @@ import pandas as pd
4
  import numpy as np
5
  from pydantic import BaseModel
6
  from fastapi import FastAPI, HTTPException, status, File, UploadFile
 
7
  import mlflow
8
  from dotenv import load_dotenv
9
 
@@ -43,16 +44,26 @@ os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("AWS_SECRET_ACCESS_KEY")
43
  # Variables globales pour stocker le modèle
44
  mlflow.set_tracking_uri(MLFLOW_TRACKING_APP_URI)
45
  model_uri = f"models:/{MODEL_NAME}@{STAGE}"
 
 
 
46
 
47
  # Chargement conditionnel du modèle
48
  try:
49
  # Essayer de charger un modèle scikit-learn
50
  model = mlflow.sklearn.load_model(model_uri)
 
51
  print("Modèle scikit-learn chargé avec succès.")
 
52
  except mlflow.exceptions.MlflowException:
53
  try:
54
  model = mlflow.pytorch.load_model(model_uri)
 
55
  print("Modèle PyTorch chargé avec succès.")
 
 
 
 
56
  except mlflow.exceptions.MlflowException as e:
57
  raise HTTPException(
58
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -80,19 +91,27 @@ class TextInput(BaseModel):
80
  def predict(features: TextInput):
81
  """
82
  Fait une prédiction sur un texte donné en utilisant le modèle chargé.
83
-
84
- Args:
85
- features (TextInput): Objet contenant le texte à prédire.
86
-
87
- Returns:
88
- dict: Dictionnaire contenant la prédiction.
89
  """
90
  try:
91
- df = pd.DataFrame({"Text": [features.text]})
92
-
93
- prediction = model.predict(df)[0]
94
-
95
- return {"prediction": int(prediction)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  except Exception as e:
98
  raise HTTPException(
 
4
  import numpy as np
5
  from pydantic import BaseModel
6
  from fastapi import FastAPI, HTTPException, status, File, UploadFile
7
+ from transformers import pipeline
8
  import mlflow
9
  from dotenv import load_dotenv
10
 
 
44
  # Variables globales pour stocker le modèle
45
  mlflow.set_tracking_uri(MLFLOW_TRACKING_APP_URI)
46
  model_uri = f"models:/{MODEL_NAME}@{STAGE}"
47
+ model = None
48
+ model_type = None # "sklearn" ou "pytorch"
49
+ classifier = None # Pour le pipeline Hugging Face
50
 
51
  # Chargement conditionnel du modèle
52
  try:
53
  # Essayer de charger un modèle scikit-learn
54
  model = mlflow.sklearn.load_model(model_uri)
55
+ model_type = "sklearn"
56
  print("Modèle scikit-learn chargé avec succès.")
57
+
58
  except mlflow.exceptions.MlflowException:
59
  try:
60
  model = mlflow.pytorch.load_model(model_uri)
61
+ model_type = "pytorch"
62
  print("Modèle PyTorch chargé avec succès.")
63
+ classifier = pipeline(task="text-classification",
64
+ model=model,
65
+ tokenizer="camembert-base")
66
+
67
  except mlflow.exceptions.MlflowException as e:
68
  raise HTTPException(
69
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
 
91
  def predict(features: TextInput):
92
  """
93
  Fait une prédiction sur un texte donné en utilisant le modèle chargé.
 
 
 
 
 
 
94
  """
95
  try:
96
+ if model_type == "sklearn":
97
+ # Cas scikit-learn : prédiction directe
98
+ df = pd.DataFrame({"Text": [features.text]})
99
+ prediction = model.predict(df)[0]
100
+ return {"prediction": int(prediction)}
101
+
102
+ elif model_type == "pytorch":
103
+ # Cas PyTorch (transformers) : utiliser le pipeline
104
+ result = classifier(features.text)
105
+ return {
106
+ "prediction": result[0]["label"],
107
+ "score": result[0]["score"]
108
+ }
109
+
110
+ else:
111
+ raise HTTPException(
112
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
113
+ detail="Aucun modèle chargé ou type de modèle non reconnu."
114
+ )
115
 
116
  except Exception as e:
117
  raise HTTPException(