File size: 3,952 Bytes
91f0a33
f9d87b8
 
55a19e2
fe10113
a5498e9
67af143
f9d87b8
91f0a33
 
f9d87b8
 
af97d0a
f9d87b8
 
 
 
fe10113
f9d87b8
 
 
b272ac7
f9d87b8
 
fe10113
f9d87b8
 
 
 
 
 
fe10113
f9d87b8
fe10113
f9d87b8
 
 
 
fe10113
f9d87b8
 
 
fe10113
f9d87b8
 
 
67af143
 
 
fe10113
f9d87b8
 
 
 
67af143
f9d87b8
67af143
f9d87b8
c3da7b0
426a949
67af143
426a949
67af143
 
 
 
426a949
c3da7b0
 
f9d87b8
426a949
c3da7b0
 
 
f9d87b8
 
 
c3da7b0
 
f9d87b8
 
 
 
 
 
 
c3da7b0
 
91f0a33
f9d87b8
 
 
 
 
fe10113
67af143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9d87b8
fe10113
 
 
 
 
91f0a33
 
fe10113
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import uvicorn
import pandas as pd
import numpy as np
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException, status, File, UploadFile
from transformers import pipeline
import mlflow
from dotenv import load_dotenv

description = """
# [Détection des fausses informations sur le réchauffement climatique]

## À propos
Les fausses informations et les contenus manipulateurs sur le climat se propagent rapidement,
nuisant à la lutte contre le réchauffement climatique.
Ce projet vise à automatiser la classification des articles en trois catégories : vrai, biaisé ou faux.

## Machine-Learning
Where you can:
* `/predict` : prediction for a single value

Check out documentation for more information on each endpoint.
"""

tags_metadata = [
    {
        "name": "Predictions",
        "description": "Endpoints that uses our Machine Learning model",
    },
]

load_dotenv()

# Variables MLflow : URI de tracking, nom du modèle et stage
MLFLOW_TRACKING_APP_URI = os.getenv("MLFLOW_TRACKING_APP_URI")
MODEL_NAME = os.getenv("MODEL_NAME")
STAGE = os.getenv("STAGE", "production")

# Variables AWS pour accéder au bucket S3 qui contient les artifacts de MLflow
os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("AWS_ACCESS_KEY_ID")
os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("AWS_SECRET_ACCESS_KEY")

# Variables globales pour stocker le modèle
mlflow.set_tracking_uri(MLFLOW_TRACKING_APP_URI)
model_uri = f"models:/{MODEL_NAME}@{STAGE}"
model = None
model_type = None  # "sklearn" ou "pytorch"
classifier = None  # Pour le pipeline Hugging Face

# Chargement conditionnel du modèle
try:
    # Essayer de charger un modèle scikit-learn
    model = mlflow.sklearn.load_model(model_uri)
    model_type = "sklearn"
    print("Modèle scikit-learn chargé avec succès.")

except mlflow.exceptions.MlflowException:
    try:
        model = mlflow.pytorch.load_model(model_uri)
        model_type = "pytorch"
        print("Modèle PyTorch chargé avec succès.")
        classifier = pipeline(task="text-classification", 
                              model=model, 
                              tokenizer="camembert-base")
        
    except mlflow.exceptions.MlflowException as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Erreur lors du chargement du modèle : {e}"
    )

app = FastAPI(
    title="Climate Fake News Detector API",
    description="API pour détecter les fake news sur le climat",
    version="1.0",
    openapi_tags=tags_metadata,
)

@app.get("/")
def index():
    """
    Renvoie un message de bienvenue sur l'API ainsi que le lien vers la documentation.
    """
    return "Hello world! Go to /docs to try the API."

class TextInput(BaseModel):
    text: str

@app.post("/predict", tags=["Predictions"])
def predict(features: TextInput):
    """
    Fait une prédiction sur un texte donné en utilisant le modèle chargé.
    """
    try:
        if model_type == "sklearn":
            # Cas scikit-learn : prédiction directe
            df = pd.DataFrame({"Text": [features.text]})
            prediction = model.predict(df)[0]
            return {"prediction": int(prediction)}

        elif model_type == "pytorch":
            # Cas PyTorch (transformers) : utiliser le pipeline
            result = classifier(features.text)
            return {
                "prediction": result[0]["label"],
                "score": result[0]["score"]
            }

        else:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Aucun modèle chargé ou type de modèle non reconnu."
            )

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Erreur lors de la prédiction : {e}"
        )

if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=8000)