| | import os |
| | import uvicorn |
| | import pandas as pd |
| | import numpy as np |
| | from pydantic import BaseModel |
| | from fastapi import FastAPI, HTTPException, status, File, UploadFile |
| | from transformers import pipeline |
| | import mlflow |
| | from dotenv import load_dotenv |
| |
|
# Markdown shown in the OpenAPI/Swagger page of the API.
description = """
# [Détection des fausses informations sur le réchauffement climatique]

## À propos
Les fausses informations et les contenus manipulateurs sur le climat se propagent rapidement,
nuisant à la lutte contre le réchauffement climatique.
Ce projet vise à automatiser la classification des articles en trois catégories : vrai, biaisé ou faux.

## Machine-Learning
Where you can:
* `/predict` : prediction for a single value

Check out documentation for more information on each endpoint.
"""

# OpenAPI tag definitions; endpoints reference these by their "name" value.
tags_metadata = [
    dict(
        name="Predictions",
        description="Endpoints that uses our Machine Learning model",
    ),
]
| |
|
# Load variables from a local .env file (if present) into os.environ.
load_dotenv()

# MLflow configuration read from the environment; STAGE defaults to "production".
MLFLOW_TRACKING_APP_URI = os.getenv("MLFLOW_TRACKING_APP_URI")
MODEL_NAME = os.getenv("MODEL_NAME")
STAGE = os.getenv("STAGE", "production")

# AWS credentials (presumably consumed by MLflow when fetching artifacts from
# S3 -- confirm). load_dotenv() has already exported them into os.environ, so
# they need no copying. Copying os.getenv(k) back into os.environ would raise
# TypeError whenever the variable is unset. Warn about missing values instead.
_missing_aws_vars = [
    name
    for name in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
    if not os.getenv(name)
]
if _missing_aws_vars:
    print(
        "Attention : variables d'environnement AWS non définies : "
        + ", ".join(_missing_aws_vars)
    )
| |
|
| | |
# Point the MLflow client at the tracking server, then build the registry URI
# "models:/<MODEL_NAME>@<STAGE>" (the "@" form addresses a registered-model
# alias in MLflow).
mlflow.set_tracking_uri(MLFLOW_TRACKING_APP_URI)
model_uri = "models:/{}@{}".format(MODEL_NAME, STAGE)

# Placeholders populated by the model-loading code that follows.
model = model_type = classifier = None
| |
|
| | |
# Try the scikit-learn flavour first. If MLflow rejects it with an
# MlflowException (e.g. the model has no sklearn flavour), fall back to PyTorch.
try:
    model = mlflow.sklearn.load_model(model_uri)
    model_type = "sklearn"
    print("Modèle scikit-learn chargé avec succès.")
except mlflow.exceptions.MlflowException:
    try:
        model = mlflow.pytorch.load_model(model_uri)
    except mlflow.exceptions.MlflowException as e:
        # This code runs at import time, outside any request. An HTTPException
        # would never reach a client, so abort startup with a plain error that
        # keeps the MLflow cause chained.
        raise RuntimeError(f"Erreur lors du chargement du modèle : {e}") from e
    model_type = "pytorch"
    print("Modèle PyTorch chargé avec succès.")
    # Wrap the PyTorch model in a transformers text-classification pipeline
    # that tokenizes inputs with the "camembert-base" tokenizer.
    classifier = pipeline(
        task="text-classification",
        model=model,
        tokenizer="camembert-base",
    )
| |
|
# FastAPI application. `description` is the Markdown constant defined at the
# top of this module; it was previously unused while a shorter literal was
# passed here, leaving the documented endpoint overview out of /docs.
app = FastAPI(
    title="Climate Fake News Detector API",
    description=description,
    version="1.0",
    openapi_tags=tags_metadata,
)
| |
|
@app.get("/")
def index():
    """Root endpoint: return a greeting that points users to the /docs page."""
    greeting = "Hello world! Go to /docs to try the API."
    return greeting
| |
|
class TextInput(BaseModel):
    """Request body accepted by ``/predict``."""

    # Text to classify; a required string field validated by pydantic.
    text: str
| |
|
def _to_json_label(label):
    """Convert a raw scikit-learn prediction into a JSON-friendly label.

    String labels (NumPy string scalars subclass ``str``) are returned as a
    plain ``str``. Any other value is coerced with ``int``, as before.
    """
    if isinstance(label, str):
        return str(label)
    return int(label)


@app.post("/predict", tags=["Predictions"])
def predict(features: TextInput):
    """Classify a single text with the model loaded at startup.

    Args:
        features: Request body; ``features.text`` is the text to classify.

    Returns:
        For a scikit-learn model: ``{"prediction": label}``.
        For a PyTorch model served through the transformers pipeline:
        ``{"prediction": label, "score": score}`` taken from the first result.

    Raises:
        HTTPException: 500 if no model is loaded or its type is unrecognised;
            400 if the model fails while predicting on the given text.
    """
    # Checked before the try block: the generic handler below would otherwise
    # catch this 500 and re-report it as a 400.
    if model_type not in ("sklearn", "pytorch"):
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Aucun modèle chargé ou type de modèle non reconnu.",
        )

    try:
        if model_type == "sklearn":
            # The sklearn model receives a one-row DataFrame with a "Text" column.
            df = pd.DataFrame({"Text": [features.text]})
            prediction = model.predict(df)[0]
            return {"prediction": _to_json_label(prediction)}

        result = classifier(features.text)
        return {
            "prediction": result[0]["label"],
            "score": result[0]["score"],
        }
    except Exception as e:
        # NOTE(review): failures here are usually server-side; 400 is kept
        # for compatibility with existing clients.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Erreur lors de la prédiction : {e}",
        ) from e
| |
|
if __name__ == "__main__":
    # Local development server bound to the loopback interface.
    uvicorn.run(
        app,
        port=8000,
        host="localhost",
    )