Olivier-52
RB model URI
49ce3ff
raw
history blame
3.75 kB
import os
import mlflow
import pickle
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel
from dotenv import load_dotenv
from typing import Optional
# Charge les variables d'environnement
load_dotenv()
# Configuration des variables d'environnement
MLFLOW_TRACKING_APP_URI = os.getenv("MLFLOW_TRACKING_APP_URI", "https://olivier-52-ml-flow.hf.space")
MODEL_NAME = os.getenv("MODEL_NAME", "climate-fake-news-detector-model-XGBoost-v1")
STAGE = os.getenv("STAGE", "production")
# Configure les identifiants AWS pour accéder au bucket S3
os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("AWS_ACCESS_KEY_ID")
os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("AWS_SECRET_ACCESS_KEY")
# Initialise FastAPI
app = FastAPI(
title="Climate Fake News Detector API",
description="API pour détecter les fake news sur le climat avec un modèle XGBoost.",
version="1.0.0"
)
# Modèle pour les données d'entrée
class TextInput(BaseModel):
text: str
# Variables globales pour stocker le modèle et le vectorizer
model = None
vectorizer = None
# Fonction pour charger le modèle depuis MLflow
def load_model():
global model
try:
# Configure l'URI de tracking MLflow
mlflow.set_tracking_uri(MLFLOW_TRACKING_APP_URI)
# Charge le modèle depuis MLflow
model_uri = f"models:/{MODEL_NAME}@{STAGE}"
model = mlflow.sklearn.load_model(model_uri)
print("Modèle chargé avec succès depuis MLflow.")
except Exception as e:
print(f"Erreur lors du chargement du modèle depuis MLflow : {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Impossible de charger le modèle depuis MLflow : {e}"
)
# Fonction pour charger le vectorizer depuis MLflow
def load_vectorizer():
try:
# Initialise le client MLflow
client = mlflow.MlflowClient(MLFLOW_TRACKING_APP_URI)
# Récupère les informations sur le modèle
model_info = client.get_model_version_by_alias(MODEL_NAME, STAGE)
run_id = model_info.run_id
# Télécharge le fichier vectorizer.pkl depuis MLflow
local_path = mlflow.artifacts.download_artifacts(
artifact_path="vectorizer.pkl",
run_id=run_id
)
# Charge le vectorizer depuis le fichier
with open(local_path, "rb") as f:
vectorizer = pickle.load(f)
return vectorizer
except Exception as e:
print(f"Erreur lors du chargement du vectorizer : {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Impossible de charger le vectorizer : {e}"
)
load_model()
vectorizer = load_vectorizer()
@app.get("/")
async def read_root():
return {
"message": "Bienvenue sur l'API Climate Fake News Detector !",
"documentation": "Consultez la documentation de l'API à l'adresse /docs."
}
@app.post("/predict")
async def predict(input_data: TextInput):
global model, vectorizer
if model is None or vectorizer is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Le modèle ou le vectorizer n'est pas chargé."
)
try:
X_vectorized = vectorizer.transform([input_data.text]).toarray()
prediction = model.predict(X_vectorized)
return {"prediction": int(prediction[0])}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Erreur lors de la prédiction : {e}"
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="localhost", port=8000)