| import os |
| import mlflow |
| import pickle |
| from fastapi import FastAPI, HTTPException, status |
| from pydantic import BaseModel |
| from dotenv import load_dotenv |
| from typing import Optional |
|
|
| |
# Load variables from a local .env file into os.environ (no-op if the file is absent).
load_dotenv()
|
|
| |
# --- Configuration, overridable via environment variables / .env ---
# MLflow tracking server URL, registered model name, and the registry alias to load.
MLFLOW_TRACKING_APP_URI = os.getenv("MLFLOW_TRACKING_APP_URI", "https://olivier-52-ml-flow.hf.space")
MODEL_NAME = os.getenv("MODEL_NAME", "climate-fake-news-detector-model-XGBoost-v1")
STAGE = os.getenv("STAGE", "production")


# Forward S3 credentials to MLflow's artifact store, but only when they are
# actually set: os.environ values must be strings, so assigning the None that
# os.getenv() returns for a missing variable would raise TypeError at import.
for _aws_key in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"):
    _aws_value = os.getenv(_aws_key)
    if _aws_value is not None:
        os.environ[_aws_key] = _aws_value
|
|
| |
# ASGI application exposing the fake-news classifier over HTTP.
app = FastAPI(
    title="Climate Fake News Detector API",
    version="1.0.0",
    description="API pour détecter les fake news sur le climat avec un modèle XGBoost.",
)
|
|
| |
class TextInput(BaseModel):
    """Request body for /predict: the raw text to classify."""

    text: str
|
|
| |
# Module-level singletons populated at import time (see calls near the bottom
# of this file): the classifier loaded from MLflow and its pickled vectorizer.
model = None
vectorizer = None
|
|
| |
def load_model():
    """Load the registered model from the MLflow Model Registry.

    Resolves the alias-based URI ``models:/{MODEL_NAME}@{STAGE}``, stores the
    result in the module-level ``model`` global, and also returns it (for
    symmetry with ``load_vectorizer``).

    Returns:
        The loaded model object.

    Raises:
        HTTPException: 500 when the model cannot be fetched from MLflow.
            NOTE(review): this function is called at import time, so FastAPI
            never turns this into an HTTP response — it aborts startup.
    """
    global model
    try:
        # Point the MLflow client at the tracking server before resolving URIs.
        mlflow.set_tracking_uri(MLFLOW_TRACKING_APP_URI)

        # "@{STAGE}" selects a registry *alias* (e.g. "production").
        model_uri = f"models:/{MODEL_NAME}@{STAGE}"
        model = mlflow.sklearn.load_model(model_uri)
        print("Modèle chargé avec succès depuis MLflow.")
        return model
    except Exception as e:
        print(f"Erreur lors du chargement du modèle depuis MLflow : {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Impossible de charger le modèle depuis MLflow : {e}"
        )
|
|
| |
def load_vectorizer():
    """Download and unpickle the ``vectorizer.pkl`` artifact for the served model.

    Looks up the run that produced the model version aliased ``STAGE`` and
    fetches its ``vectorizer.pkl`` artifact from the artifact store.

    Returns:
        The deserialized vectorizer object.

    Raises:
        HTTPException: 500 when the artifact cannot be fetched or unpickled.
    """
    try:
        # Also set the global tracking URI: download_artifacts() below relies
        # on it, and this function must not depend on load_model() having run
        # first to configure it.
        mlflow.set_tracking_uri(MLFLOW_TRACKING_APP_URI)
        client = mlflow.MlflowClient(MLFLOW_TRACKING_APP_URI)

        # Resolve the run that registered the aliased model version.
        model_info = client.get_model_version_by_alias(MODEL_NAME, STAGE)
        run_id = model_info.run_id

        local_path = mlflow.artifacts.download_artifacts(
            artifact_path="vectorizer.pkl",
            run_id=run_id
        )

        # SECURITY: pickle.load executes arbitrary code from the file; only
        # load artifacts from a trusted MLflow server.
        with open(local_path, "rb") as f:
            vectorizer = pickle.load(f)

        return vectorizer
    except Exception as e:
        print(f"Erreur lors du chargement du vectorizer : {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Impossible de charger le vectorizer : {e}"
        )
|
|
# Eagerly load the model and vectorizer at import time so the service fails
# fast at startup if MLflow is unreachable.
load_model()
vectorizer = load_vectorizer()
|
|
@app.get("/")
async def read_root():
    """Landing endpoint: greets the caller and points at the API docs."""
    welcome = "Bienvenue sur l'API Climate Fake News Detector !"
    docs_hint = "Consultez la documentation de l'API à l'adresse /docs."
    return {"message": welcome, "documentation": docs_hint}
|
|
@app.post("/predict")
async def predict(input_data: TextInput):
    """Classify a text: returns ``{"prediction": 0 or 1}``.

    Responds 500 when the model/vectorizer failed to load at startup, and 400
    when vectorization or prediction fails on the given input.
    """
    # Guard clause: refuse to serve if startup loading did not succeed.
    if model is None or vectorizer is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Le modèle ou le vectorizer n'est pas chargé."
        )

    try:
        features = vectorizer.transform([input_data.text]).toarray()
        prediction = model.predict(features)
        label = int(prediction[0])
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Erreur lors de la prédiction : {e}"
        )
    return {"prediction": label}
|
|
# Run a local development server when this file is executed directly.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="localhost", port=8000)