Spaces:
Sleeping
Sleeping
| import os | |
| import uvicorn | |
| import pandas as pd | |
| import numpy as np | |
| from pydantic import BaseModel | |
| from fastapi import FastAPI, HTTPException, status, File, UploadFile | |
| from transformers import pipeline | |
| import mlflow | |
| from dotenv import load_dotenv | |
| description = """ | |
| # [Détection des fausses informations sur le réchauffement climatique] | |
| ## À propos | |
| Les fausses informations et les contenus manipulateurs sur le climat se propagent rapidement, | |
| nuisant à la lutte contre le réchauffement climatique. | |
| Ce projet vise à automatiser la classification des articles en trois catégories : vrai, biaisé ou faux. | |
| ## Machine-Learning | |
| Where you can: | |
| * `/predict` : prediction for a single value | |
| Check out documentation for more information on each endpoint. | |
| """ | |
| tags_metadata = [ | |
| { | |
| "name": "Predictions", | |
| "description": "Endpoints that uses our Machine Learning model", | |
| }, | |
| ] | |
| load_dotenv() | |
| # Variables MLflow : URI de tracking, nom du modèle et stage | |
| MLFLOW_TRACKING_APP_URI = os.getenv("MLFLOW_TRACKING_APP_URI") | |
| MODEL_NAME = os.getenv("MODEL_NAME") | |
| STAGE = os.getenv("STAGE", "production") | |
| # Variables AWS pour accéder au bucket S3 qui contient les artifacts de MLflow | |
| os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("AWS_ACCESS_KEY_ID") | |
| os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("AWS_SECRET_ACCESS_KEY") | |
| # Variables globales pour stocker le modèle | |
| mlflow.set_tracking_uri(MLFLOW_TRACKING_APP_URI) | |
| model_uri = f"models:/{MODEL_NAME}@{STAGE}" | |
| model = None | |
| model_type = None # "sklearn" ou "pytorch" | |
| classifier = None # Pour le pipeline Hugging Face | |
| # Chargement conditionnel du modèle | |
| try: | |
| # Essayer de charger un modèle scikit-learn | |
| model = mlflow.sklearn.load_model(model_uri) | |
| model_type = "sklearn" | |
| print("Modèle scikit-learn chargé avec succès.") | |
| except mlflow.exceptions.MlflowException: | |
| try: | |
| model = mlflow.pytorch.load_model(model_uri) | |
| model_type = "pytorch" | |
| print("Modèle PyTorch chargé avec succès.") | |
| classifier = pipeline(task="text-classification", | |
| model=model, | |
| tokenizer="camembert-base") | |
| except mlflow.exceptions.MlflowException as e: | |
| raise HTTPException( | |
| status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | |
| detail=f"Erreur lors du chargement du modèle : {e}" | |
| ) | |
| app = FastAPI( | |
| title="Climate Fake News Detector API", | |
| description="API pour détecter les fake news sur le climat", | |
| version="1.0", | |
| openapi_tags=tags_metadata, | |
| ) | |
| def index(): | |
| """ | |
| Renvoie un message de bienvenue sur l'API ainsi que le lien vers la documentation. | |
| """ | |
| return "Hello world! Go to /docs to try the API." | |
| class TextInput(BaseModel): | |
| text: str | |
| def predict(features: TextInput): | |
| """ | |
| Fait une prédiction sur un texte donné en utilisant le modèle chargé. | |
| """ | |
| try: | |
| if model_type == "sklearn": | |
| # Cas scikit-learn : prédiction directe | |
| df = pd.DataFrame({"Text": [features.text]}) | |
| prediction = model.predict(df)[0] | |
| return {"prediction": int(prediction)} | |
| elif model_type == "pytorch": | |
| # Cas PyTorch (transformers) : utiliser le pipeline | |
| result = classifier(features.text) | |
| return { | |
| "prediction": result[0]["label"], | |
| "score": result[0]["score"] | |
| } | |
| else: | |
| raise HTTPException( | |
| status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | |
| detail="Aucun modèle chargé ou type de modèle non reconnu." | |
| ) | |
| except Exception as e: | |
| raise HTTPException( | |
| status_code=status.HTTP_400_BAD_REQUEST, | |
| detail=f"Erreur lors de la prédiction : {e}" | |
| ) | |
| if __name__ == "__main__": | |
| uvicorn.run(app, host="localhost", port=8000) |