WissamH commited on
Commit
d78d73c
·
1 Parent(s): e5f297d

Update app.py

Browse files

fix transformers import bug

Files changed (1) hide show
  1. app.py +35 -16
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import os
2
  import uvicorn
3
  import pandas as pd
 
4
  from pydantic import BaseModel
5
  from fastapi import FastAPI, HTTPException, status, File, UploadFile
 
6
  import mlflow
7
  from dotenv import load_dotenv
8
 
@@ -42,22 +44,31 @@ os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("AWS_SECRET_ACCESS_KEY")
42
  # Variables globales pour stocker le modèle
43
  mlflow.set_tracking_uri(MLFLOW_TRACKING_APP_URI)
44
  model_uri = f"models:/{MODEL_NAME}@{STAGE}"
 
 
 
45
 
46
  # Chargement conditionnel du modèle
47
  try:
48
  # Essayer de charger un modèle scikit-learn
49
  model = mlflow.sklearn.load_model(model_uri)
 
50
  print("Modèle scikit-learn chargé avec succès.")
 
51
  except mlflow.exceptions.MlflowException:
52
  try:
53
- # Si échec, essayer de charger un modèle Transformers
54
- model = mlflow.transformers.load_model(model_uri)
55
- print("Modèle Transformers chargé avec succès.")
56
- except Exception as e:
 
 
 
 
57
  raise HTTPException(
58
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
59
  detail=f"Erreur lors du chargement du modèle : {e}"
60
- )
61
 
62
  app = FastAPI(
63
  title="Climate Fake News Detector API",
@@ -80,19 +91,27 @@ class TextInput(BaseModel):
80
  def predict(features: TextInput):
81
  """
82
  Fait une prédiction sur un texte donné en utilisant le modèle chargé.
83
-
84
- Args:
85
- features (TextInput): Objet contenant le texte à prédire.
86
-
87
- Returns:
88
- dict: Dictionnaire contenant la prédiction.
89
  """
90
  try:
91
- df = pd.DataFrame({"Text": [features.text]})
92
-
93
- prediction = model.predict(df)[0]
94
-
95
- return {"prediction": int(prediction)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  except Exception as e:
98
  raise HTTPException(
 
1
  import os
2
  import uvicorn
3
  import pandas as pd
4
+ import numpy as np
5
  from pydantic import BaseModel
6
  from fastapi import FastAPI, HTTPException, status, File, UploadFile
7
+ from transformers import pipeline
8
  import mlflow
9
  from dotenv import load_dotenv
10
 
 
44
  # Variables globales pour stocker le modèle
45
  mlflow.set_tracking_uri(MLFLOW_TRACKING_APP_URI)
46
  model_uri = f"models:/{MODEL_NAME}@{STAGE}"
47
+ model = None
48
+ model_type = None # "sklearn" ou "pytorch"
49
+ classifier = None # Pour le pipeline Hugging Face
50
 
51
  # Chargement conditionnel du modèle
52
  try:
53
  # Essayer de charger un modèle scikit-learn
54
  model = mlflow.sklearn.load_model(model_uri)
55
+ model_type = "sklearn"
56
  print("Modèle scikit-learn chargé avec succès.")
57
+
58
  except mlflow.exceptions.MlflowException:
59
  try:
60
+ model = mlflow.pytorch.load_model(model_uri)
61
+ model_type = "pytorch"
62
+ print("Modèle PyTorch chargé avec succès.")
63
+ classifier = pipeline(task="text-classification",
64
+ model=model,
65
+ tokenizer="camembert-base")
66
+
67
+ except mlflow.exceptions.MlflowException as e:
68
  raise HTTPException(
69
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
70
  detail=f"Erreur lors du chargement du modèle : {e}"
71
+ )
72
 
73
  app = FastAPI(
74
  title="Climate Fake News Detector API",
 
91
  def predict(features: TextInput):
92
  """
93
  Fait une prédiction sur un texte donné en utilisant le modèle chargé.
 
 
 
 
 
 
94
  """
95
  try:
96
+ if model_type == "sklearn":
97
+ # Cas scikit-learn : prédiction directe
98
+ df = pd.DataFrame({"Text": [features.text]})
99
+ prediction = model.predict(df)[0]
100
+ return {"prediction": int(prediction)}
101
+
102
+ elif model_type == "pytorch":
103
+ # Cas PyTorch (transformers) : utiliser le pipeline
104
+ result = classifier(features.text)
105
+ return {
106
+ "prediction": result[0]["label"],
107
+ "score": result[0]["score"]
108
+ }
109
+
110
+ else:
111
+ raise HTTPException(
112
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
113
+ detail="Aucun modèle chargé ou type de modèle non reconnu."
114
+ )
115
 
116
  except Exception as e:
117
  raise HTTPException(