File size: 3,952 Bytes
91f0a33 f9d87b8 55a19e2 fe10113 a5498e9 67af143 f9d87b8 91f0a33 f9d87b8 af97d0a f9d87b8 fe10113 f9d87b8 b272ac7 f9d87b8 fe10113 f9d87b8 fe10113 f9d87b8 fe10113 f9d87b8 fe10113 f9d87b8 fe10113 f9d87b8 67af143 fe10113 f9d87b8 67af143 f9d87b8 67af143 f9d87b8 c3da7b0 426a949 67af143 426a949 67af143 426a949 c3da7b0 f9d87b8 426a949 c3da7b0 f9d87b8 c3da7b0 f9d87b8 c3da7b0 91f0a33 f9d87b8 fe10113 67af143 f9d87b8 fe10113 91f0a33 fe10113 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 | import os
import uvicorn
import pandas as pd
import numpy as np
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException, status, File, UploadFile
from transformers import pipeline
import mlflow
from dotenv import load_dotenv
description = """
# [Détection des fausses informations sur le réchauffement climatique]
## À propos
Les fausses informations et les contenus manipulateurs sur le climat se propagent rapidement,
nuisant à la lutte contre le réchauffement climatique.
Ce projet vise à automatiser la classification des articles en trois catégories : vrai, biaisé ou faux.
## Machine-Learning
Where you can:
* `/predict` : prediction for a single value
Check out documentation for more information on each endpoint.
"""
tags_metadata = [
{
"name": "Predictions",
"description": "Endpoints that uses our Machine Learning model",
},
]
load_dotenv()
# Variables MLflow : URI de tracking, nom du modèle et stage
MLFLOW_TRACKING_APP_URI = os.getenv("MLFLOW_TRACKING_APP_URI")
MODEL_NAME = os.getenv("MODEL_NAME")
STAGE = os.getenv("STAGE", "production")
# Variables AWS pour accéder au bucket S3 qui contient les artifacts de MLflow
os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("AWS_ACCESS_KEY_ID")
os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("AWS_SECRET_ACCESS_KEY")
# Variables globales pour stocker le modèle
mlflow.set_tracking_uri(MLFLOW_TRACKING_APP_URI)
model_uri = f"models:/{MODEL_NAME}@{STAGE}"
model = None
model_type = None # "sklearn" ou "pytorch"
classifier = None # Pour le pipeline Hugging Face
# Chargement conditionnel du modèle
try:
# Essayer de charger un modèle scikit-learn
model = mlflow.sklearn.load_model(model_uri)
model_type = "sklearn"
print("Modèle scikit-learn chargé avec succès.")
except mlflow.exceptions.MlflowException:
try:
model = mlflow.pytorch.load_model(model_uri)
model_type = "pytorch"
print("Modèle PyTorch chargé avec succès.")
classifier = pipeline(task="text-classification",
model=model,
tokenizer="camembert-base")
except mlflow.exceptions.MlflowException as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Erreur lors du chargement du modèle : {e}"
)
app = FastAPI(
title="Climate Fake News Detector API",
description="API pour détecter les fake news sur le climat",
version="1.0",
openapi_tags=tags_metadata,
)
@app.get("/")
def index():
"""
Renvoie un message de bienvenue sur l'API ainsi que le lien vers la documentation.
"""
return "Hello world! Go to /docs to try the API."
class TextInput(BaseModel):
text: str
@app.post("/predict", tags=["Predictions"])
def predict(features: TextInput):
"""
Fait une prédiction sur un texte donné en utilisant le modèle chargé.
"""
try:
if model_type == "sklearn":
# Cas scikit-learn : prédiction directe
df = pd.DataFrame({"Text": [features.text]})
prediction = model.predict(df)[0]
return {"prediction": int(prediction)}
elif model_type == "pytorch":
# Cas PyTorch (transformers) : utiliser le pipeline
result = classifier(features.text)
return {
"prediction": result[0]["label"],
"score": result[0]["score"]
}
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Aucun modèle chargé ou type de modèle non reconnu."
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Erreur lors de la prédiction : {e}"
)
if __name__ == "__main__":
uvicorn.run(app, host="localhost", port=8000) |