Olivier-52 commited on
Commit ·
c3da7b0
1
Parent(s): 49ce3ff
Update app.py
Browse filesAdd async model and vectorizer loader
app.py
CHANGED
|
@@ -5,30 +5,21 @@ from fastapi import FastAPI, HTTPException, status
|
|
| 5 |
from pydantic import BaseModel
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
from typing import Optional
|
|
|
|
|
|
|
| 8 |
|
| 9 |
# Charge les variables d'environnement
|
| 10 |
load_dotenv()
|
| 11 |
|
| 12 |
# Configuration des variables d'environnement
|
| 13 |
-
MLFLOW_TRACKING_APP_URI = os.getenv("MLFLOW_TRACKING_APP_URI"
|
| 14 |
-
MODEL_NAME = os.getenv("MODEL_NAME"
|
| 15 |
-
STAGE = os.getenv("STAGE"
|
| 16 |
|
| 17 |
# Configure les identifiants AWS pour accéder au bucket S3
|
| 18 |
os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("AWS_ACCESS_KEY_ID")
|
| 19 |
os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("AWS_SECRET_ACCESS_KEY")
|
| 20 |
|
| 21 |
-
# Initialise FastAPI
|
| 22 |
-
app = FastAPI(
|
| 23 |
-
title="Climate Fake News Detector API",
|
| 24 |
-
description="API pour détecter les fake news sur le climat avec un modèle XGBoost.",
|
| 25 |
-
version="1.0.0"
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
# Modèle pour les données d'entrée
|
| 29 |
-
class TextInput(BaseModel):
|
| 30 |
-
text: str
|
| 31 |
-
|
| 32 |
# Variables globales pour stocker le modèle et le vectorizer
|
| 33 |
model = None
|
| 34 |
vectorizer = None
|
|
@@ -79,8 +70,39 @@ def load_vectorizer():
|
|
| 79 |
detail=f"Impossible de charger le vectorizer : {e}"
|
| 80 |
)
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
@app.get("/")
|
| 86 |
async def read_root():
|
|
|
|
| 5 |
from pydantic import BaseModel
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
from typing import Optional
|
| 8 |
+
import asyncio
|
| 9 |
+
from contextlib import asynccontextmanager
|
| 10 |
|
| 11 |
# Charge les variables d'environnement
|
| 12 |
load_dotenv()
|
| 13 |
|
| 14 |
# Configuration des variables d'environnement
|
| 15 |
+
MLFLOW_TRACKING_APP_URI = os.getenv("MLFLOW_TRACKING_APP_URI")
|
| 16 |
+
MODEL_NAME = os.getenv("MODEL_NAME")
|
| 17 |
+
STAGE = os.getenv("STAGE")
|
| 18 |
|
| 19 |
# Configure les identifiants AWS pour accéder au bucket S3
|
| 20 |
os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("AWS_ACCESS_KEY_ID")
|
| 21 |
os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("AWS_SECRET_ACCESS_KEY")
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
# Variables globales pour stocker le modèle et le vectorizer
|
| 24 |
model = None
|
| 25 |
vectorizer = None
|
|
|
|
| 70 |
detail=f"Impossible de charger le vectorizer : {e}"
|
| 71 |
)
|
| 72 |
|
| 73 |
+
# Fonction asynchrone pour charger le modèle et le vectorizer
|
| 74 |
+
async def load_model_and_vectorizer():
|
| 75 |
+
try:
|
| 76 |
+
loop = asyncio.get_event_loop()
|
| 77 |
+
await loop.run_in_executor(None, load_model)
|
| 78 |
+
global vectorizer
|
| 79 |
+
vectorizer = await loop.run_in_executor(None, load_vectorizer)
|
| 80 |
+
print("Modèle et vectorizer chargés avec succès.")
|
| 81 |
+
except Exception as e:
|
| 82 |
+
print(f"Erreur lors du chargement : {e}")
|
| 83 |
+
raise HTTPException(
|
| 84 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 85 |
+
detail=f"Impossible de charger le modèle ou le vectorizer : {e}"
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# Charge le modèle et le vectorizer au démarrage
|
| 89 |
+
@asynccontextmanager
|
| 90 |
+
async def lifespan(app: FastAPI):
|
| 91 |
+
# Code à exécuter au démarrage
|
| 92 |
+
await load_model_and_vectorizer()
|
| 93 |
+
yield
|
| 94 |
+
|
| 95 |
+
# Initialise FastAPI
|
| 96 |
+
app = FastAPI(
|
| 97 |
+
title="Climate Fake News Detector API",
|
| 98 |
+
description="API pour détecter les fake news sur le climat avec un modèle XGBoost.",
|
| 99 |
+
version="1.0.0",
|
| 100 |
+
lifespan=lifespan
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# Modèle pour les données d'entrée
|
| 104 |
+
class TextInput(BaseModel):
|
| 105 |
+
text: str
|
| 106 |
|
| 107 |
@app.get("/")
|
| 108 |
async def read_root():
|