Olivier-52 commited on
Commit
f9d87b8
·
1 Parent(s): c3da7b0

FastApi_v2

Browse files

Manage both scikit-learn and transformers models

Files changed (3) hide show
  1. Dockerfile +7 -3
  2. app.py +72 -100
  3. requirements.txt +9 -7
Dockerfile CHANGED
@@ -2,15 +2,19 @@ FROM python:3.10
2
 
3
  WORKDIR /home/app
4
 
5
- RUN apt-get update -y
6
- RUN apt-get install nano unzip -y
7
- RUN apt install curl -y
8
 
9
  RUN curl -fsSL https://get.deta.dev/cli.sh | sh
10
 
 
 
11
  COPY requirements.txt /dependencies/requirements.txt
12
  RUN pip install -r /dependencies/requirements.txt
13
 
14
  COPY . /home/app
15
 
 
 
16
  CMD gunicorn app:app --bind 0.0.0.0:$PORT --worker-class uvicorn.workers.UvicornWorker
 
2
 
3
  WORKDIR /home/app
4
 
5
+ RUN apt-get update -y && \
6
+ apt-get install -y nano unzip libgl1 curl && \
7
+ rm -rf /var/lib/apt/lists/*
8
 
9
  RUN curl -fsSL https://get.deta.dev/cli.sh | sh
10
 
11
+ RUN useradd -m appuser
12
+
13
  COPY requirements.txt /dependencies/requirements.txt
14
  RUN pip install -r /dependencies/requirements.txt
15
 
16
  COPY . /home/app
17
 
18
+ USER appuser
19
+
20
  CMD gunicorn app:app --bind 0.0.0.0:$PORT --worker-class uvicorn.workers.UvicornWorker
app.py CHANGED
@@ -1,130 +1,103 @@
1
  import os
2
- import mlflow
3
- import pickle
4
- from fastapi import FastAPI, HTTPException, status
5
  from pydantic import BaseModel
 
 
6
  from dotenv import load_dotenv
7
- from typing import Optional
8
- import asyncio
9
- from contextlib import asynccontextmanager
10
 
11
- # Charge les variables d'environnement
12
- load_dotenv()
13
 
14
- # Configuration des variables d'environnement
15
- MLFLOW_TRACKING_APP_URI = os.getenv("MLFLOW_TRACKING_APP_URI")
16
- MODEL_NAME = os.getenv("MODEL_NAME")
17
- STAGE = os.getenv("STAGE")
18
 
19
- # Configure les identifiants AWS pour accéder au bucket S3
20
- os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("AWS_ACCESS_KEY_ID")
21
- os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("AWS_SECRET_ACCESS_KEY")
22
 
23
- # Variables globales pour stocker le modèle et le vectorizer
24
- model = None
25
- vectorizer = None
26
 
27
- # Fonction pour charger le modèle depuis MLflow
28
- def load_model():
29
- global model
30
- try:
31
- # Configure l'URI de tracking MLflow
32
- mlflow.set_tracking_uri(MLFLOW_TRACKING_APP_URI)
33
 
34
- # Charge le modèle depuis MLflow
35
- model_uri = f"models:/{MODEL_NAME}@{STAGE}"
36
- model = mlflow.sklearn.load_model(model_uri)
37
- print("Modèle chargé avec succès depuis MLflow.")
38
- except Exception as e:
39
- print(f"Erreur lors du chargement du modèle depuis MLflow : {e}")
40
- raise HTTPException(
41
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
42
- detail=f"Impossible de charger le modèle depuis MLflow : {e}"
43
- )
44
-
45
- # Fonction pour charger le vectorizer depuis MLflow
46
- def load_vectorizer():
47
- try:
48
- # Initialise le client MLflow
49
- client = mlflow.MlflowClient(MLFLOW_TRACKING_APP_URI)
50
-
51
- # Récupère les informations sur le modèle
52
- model_info = client.get_model_version_by_alias(MODEL_NAME, STAGE)
53
- run_id = model_info.run_id
54
 
55
- # Télécharge le fichier vectorizer.pkl depuis MLflow
56
- local_path = mlflow.artifacts.download_artifacts(
57
- artifact_path="vectorizer.pkl",
58
- run_id=run_id
59
- )
60
 
61
- # Charge le vectorizer depuis le fichier
62
- with open(local_path, "rb") as f:
63
- vectorizer = pickle.load(f)
64
 
65
- return vectorizer
66
- except Exception as e:
67
- print(f"Erreur lors du chargement du vectorizer : {e}")
68
- raise HTTPException(
69
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
70
- detail=f"Impossible de charger le vectorizer : {e}"
71
- )
72
 
73
- # Fonction asynchrone pour charger le modèle et le vectorizer
74
- async def load_model_and_vectorizer():
 
 
 
 
75
  try:
76
- loop = asyncio.get_event_loop()
77
- await loop.run_in_executor(None, load_model)
78
- global vectorizer
79
- vectorizer = await loop.run_in_executor(None, load_vectorizer)
80
- print("Modèle et vectorizer chargés avec succès.")
81
  except Exception as e:
82
- print(f"Erreur lors du chargement : {e}")
83
  raise HTTPException(
84
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
85
- detail=f"Impossible de charger le modèle ou le vectorizer : {e}"
86
  )
87
 
88
- # Charge le modèle et le vectorizer au démarrage
89
- @asynccontextmanager
90
- async def lifespan(app: FastAPI):
91
- # Code à exécuter au démarrage
92
- await load_model_and_vectorizer()
93
- yield
94
-
95
- # Initialise FastAPI
96
  app = FastAPI(
97
  title="Climate Fake News Detector API",
98
- description="API pour détecter les fake news sur le climat avec un modèle XGBoost.",
99
- version="1.0.0",
100
- lifespan=lifespan
101
  )
102
 
103
- # Modèle pour les données d'entrée
 
 
 
 
 
 
104
  class TextInput(BaseModel):
105
  text: str
106
 
107
- @app.get("/")
108
- async def read_root():
109
- return {
110
- "message": "Bienvenue sur l'API Climate Fake News Detector !",
111
- "documentation": "Consultez la documentation de l'API à l'adresse /docs."
112
- }
113
-
114
- @app.post("/predict")
115
- async def predict(input_data: TextInput):
116
- global model, vectorizer
117
- if model is None or vectorizer is None:
118
- raise HTTPException(
119
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
120
- detail="Le modèle ou le vectorizer n'est pas chargé."
121
- )
122
 
 
 
 
 
 
 
 
 
 
123
  try:
124
- X_vectorized = vectorizer.transform([input_data.text]).toarray()
125
- prediction = model.predict(X_vectorized)
126
- return {"prediction": int(prediction[0])}
127
-
 
 
 
128
  except Exception as e:
129
  raise HTTPException(
130
  status_code=status.HTTP_400_BAD_REQUEST,
@@ -132,5 +105,4 @@ async def predict(input_data: TextInput):
132
  )
133
 
134
  if __name__ == "__main__":
135
- import uvicorn
136
  uvicorn.run(app, host="localhost", port=8000)
 
1
  import os
2
+ import uvicorn
3
+ import pandas as pd
 
4
  from pydantic import BaseModel
5
+ from fastapi import FastAPI, HTTPException, status
6
+ import mlflow
7
  from dotenv import load_dotenv
 
 
 
8
 
9
+ description = """
10
+ # [Détection des fausses informations sur le réchauffement climatique]
11
 
12
+ ## À propos
13
+ Les fausses informations et les contenus manipulateurs sur le climat se propagent rapidement,
14
+ nuisant à la lutte contre le réchauffement climatique.
15
+ Ce projet vise à automatiser la classification des articles en trois catégories : vrai, biaisé ou faux.
16
 
17
+ ## Machine-Learning
18
+ Where you can:
19
+ * `/predict` : prediction for a single value
20
 
21
+ Check out documentation for more information on each endpoint.
22
+ """
 
23
 
24
+ tags_metadata = [
25
+ {
26
+ "name": "Predictions",
27
+ "description": "Endpoints that uses our Machine Learning model",
28
+ },
29
+ ]
30
 
31
+ load_dotenv()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # Variables MLflow : URI de tracking, nom du modèle et stage
34
+ MLFLOW_TRACKING_APP_URI = os.getenv("MLFLOW_TRACKING_APP_URI")
35
+ MODEL_NAME = os.getenv("MODEL_NAME")
36
+ STAGE = os.getenv("STAGE", "production")
 
37
 
38
+ # Variables AWS pour accéder au bucket S3 qui contient les artifacts de MLflow
39
+ os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("AWS_ACCESS_KEY_ID")
40
+ os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("AWS_SECRET_ACCESS_KEY")
41
 
42
+ # Variables globales pour stocker le modèle
43
+ mlflow.set_tracking_uri(MLFLOW_TRACKING_APP_URI)
44
+ model_uri = f"models:/{MODEL_NAME}@{STAGE}"
 
 
 
 
45
 
46
+ # Chargement conditionnel du modèle
47
+ try:
48
+ # Essayer de charger un modèle scikit-learn
49
+ model = mlflow.sklearn.load_model(model_uri)
50
+ print("Modèle scikit-learn chargé avec succès.")
51
+ except mlflow.exceptions.MlflowException:
52
  try:
53
+ # Si échec, essayer de charger un modèle Transformers
54
+ model = mlflow.transformers.load_model(model_uri)
55
+ print("Modèle Transformers chargé avec succès.")
 
 
56
  except Exception as e:
 
57
  raise HTTPException(
58
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
59
+ detail=f"Erreur lors du chargement du modèle : {e}"
60
  )
61
 
 
 
 
 
 
 
 
 
62
  app = FastAPI(
63
  title="Climate Fake News Detector API",
64
+ description="API pour détecter les fake news sur le climat",
65
+ version="1.0",
66
+ openapi_tags=tags_metadata,
67
  )
68
 
69
+ @app.get("/")
70
+ def index():
71
+ """
72
+ Renvoie un message de bienvenue sur l'API ainsi que le lien vers la documentation.
73
+ """
74
+ return "Hello world! Go to /docs to try the API."
75
+
76
  class TextInput(BaseModel):
77
  text: str
78
 
79
+ @app.post("/predict", tags=["Predictions"])
80
+ def predict(features: TextInput):
81
+ """
82
+ Fait une prédiction sur un texte donné en utilisant le modèle chargé.
 
 
 
 
 
 
 
 
 
 
 
83
 
84
+ Args:
85
+ input_data (TextInput): Objet contenant le texte à prédire.
86
+
87
+ Returns:
88
+ dict: Dictionnaire contenant la prédiction (0 les articles avec un biais, 1 pour les articles faux, et 2 pour les articles fiable).
89
+
90
+ Raises:
91
+ HTTPException: Si une erreur survient lors de la prédiction.
92
+ """
93
  try:
94
+ # Préparation des données pour la prédiction
95
+ df = pd.DataFrame({"text": [features.text]})
96
+
97
+ # Prédiction
98
+ prediction = model.predict(df["text"].tolist())[0]
99
+ return {"prediction": int(prediction)}
100
+
101
  except Exception as e:
102
  raise HTTPException(
103
  status_code=status.HTTP_400_BAD_REQUEST,
 
105
  )
106
 
107
  if __name__ == "__main__":
 
108
  uvicorn.run(app, host="localhost", port=8000)
requirements.txt CHANGED
@@ -1,14 +1,16 @@
1
  mlflow==2.21.3
2
  scikit-learn==1.4.2
 
 
 
3
  requests>=2.31.0,<3
4
- fastapi
5
  uvicorn[standard]
6
- pydantic
7
- typing
8
- pandas
9
- gunicorn
10
- openpyxl
11
  boto3
12
  python-multipart
13
- dotenv
14
  xgboost
 
1
  mlflow==2.21.3
2
  scikit-learn==1.4.2
3
+ transformers>=4.40.0
4
+ torch>=2.0.0
5
+ tokenizers>=0.15.0
6
  requests>=2.31.0,<3
7
+ fastapi
8
  uvicorn[standard]
9
+ pydantic
10
+ pandas
11
+ gunicorn
12
+ openpyxl
 
13
  boto3
14
  python-multipart
15
+ python-dotenv
16
  xgboost