Olivier-52 commited on
Commit
c3da7b0
·
1 Parent(s): 49ce3ff

Update app.py

Browse files

Add async model and vectorizer loader

Files changed (1) hide show
  1. app.py +38 -16
app.py CHANGED
@@ -5,30 +5,21 @@ from fastapi import FastAPI, HTTPException, status
5
  from pydantic import BaseModel
6
  from dotenv import load_dotenv
7
  from typing import Optional
 
 
8
 
9
  # Charge les variables d'environnement
10
  load_dotenv()
11
 
12
  # Configuration des variables d'environnement
13
- MLFLOW_TRACKING_APP_URI = os.getenv("MLFLOW_TRACKING_APP_URI", "https://olivier-52-ml-flow.hf.space")
14
- MODEL_NAME = os.getenv("MODEL_NAME", "climate-fake-news-detector-model-XGBoost-v1")
15
- STAGE = os.getenv("STAGE", "production")
16
 
17
  # Configure les identifiants AWS pour accéder au bucket S3
18
  os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("AWS_ACCESS_KEY_ID")
19
  os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("AWS_SECRET_ACCESS_KEY")
20
 
21
- # Initialise FastAPI
22
- app = FastAPI(
23
- title="Climate Fake News Detector API",
24
- description="API pour détecter les fake news sur le climat avec un modèle XGBoost.",
25
- version="1.0.0"
26
- )
27
-
28
- # Modèle pour les données d'entrée
29
- class TextInput(BaseModel):
30
- text: str
31
-
32
  # Variables globales pour stocker le modèle et le vectorizer
33
  model = None
34
  vectorizer = None
@@ -79,8 +70,39 @@ def load_vectorizer():
79
  detail=f"Impossible de charger le vectorizer : {e}"
80
  )
81
 
82
- load_model()
83
- vectorizer = load_vectorizer()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  @app.get("/")
86
  async def read_root():
 
5
  from pydantic import BaseModel
6
  from dotenv import load_dotenv
7
  from typing import Optional
8
+ import asyncio
9
+ from contextlib import asynccontextmanager
10
 
11
  # Charge les variables d'environnement
12
  load_dotenv()
13
 
14
  # Configuration des variables d'environnement
15
+ MLFLOW_TRACKING_APP_URI = os.getenv("MLFLOW_TRACKING_APP_URI")
16
+ MODEL_NAME = os.getenv("MODEL_NAME")
17
+ STAGE = os.getenv("STAGE")
18
 
19
  # Configure les identifiants AWS pour accéder au bucket S3
20
  os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("AWS_ACCESS_KEY_ID")
21
  os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("AWS_SECRET_ACCESS_KEY")
22
 
 
 
 
 
 
 
 
 
 
 
 
23
  # Variables globales pour stocker le modèle et le vectorizer
24
  model = None
25
  vectorizer = None
 
70
  detail=f"Impossible de charger le vectorizer : {e}"
71
  )
72
 
73
+ # Fonction asynchrone pour charger le modèle et le vectorizer
74
+ async def load_model_and_vectorizer():
75
+ try:
76
+ loop = asyncio.get_event_loop()
77
+ await loop.run_in_executor(None, load_model)
78
+ global vectorizer
79
+ vectorizer = await loop.run_in_executor(None, load_vectorizer)
80
+ print("Modèle et vectorizer chargés avec succès.")
81
+ except Exception as e:
82
+ print(f"Erreur lors du chargement : {e}")
83
+ raise HTTPException(
84
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
85
+ detail=f"Impossible de charger le modèle ou le vectorizer : {e}"
86
+ )
87
+
88
+ # Charge le modèle et le vectorizer au démarrage
89
+ @asynccontextmanager
90
+ async def lifespan(app: FastAPI):
91
+ # Code à exécuter au démarrage
92
+ await load_model_and_vectorizer()
93
+ yield
94
+
95
+ # Initialise FastAPI
96
+ app = FastAPI(
97
+ title="Climate Fake News Detector API",
98
+ description="API pour détecter les fake news sur le climat avec un modèle XGBoost.",
99
+ version="1.0.0",
100
+ lifespan=lifespan
101
+ )
102
+
103
+ # Modèle pour les données d'entrée
104
+ class TextInput(BaseModel):
105
+ text: str
106
 
107
  @app.get("/")
108
  async def read_root():