rakuten / src /predict_API.py
Demosthene-OR's picture
Configure LFS for images and update code
eb5ec73
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
from typing import List, Optional
import pandas as pd
import requests
from src.features.build_features import TextPreprocessor
from src.features.build_features import ImagePreprocessor
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras.preprocessing.sequence import pad_sequences
import tensorflow as tf
import numpy as np
import json
from tensorflow import keras
from keras import backend as K
from src.tools import f1_m, load_model, check_and_download
# ... (omitted lines)
def initialisation():
    """Load every artefact used by the prediction endpoint into module globals.

    For each file, ``check_and_download("models", <name>)`` is called first
    (presumably it fetches the file when it is missing locally — confirm in
    ``src.tools``), then the file is read from the ``models/`` directory:

    * ``tokenizer_config.json`` -> ``tokenizer`` (Keras tokenizer rebuilt
      with ``tokenizer_from_json``)
    * ``best_rnn_model.h5`` / ``best_vgg16_model.h5`` -> ``rnn`` / ``vgg16``
      (loaded through ``src.tools.load_model``)
    * ``best_weights.json`` -> ``best_weights``
    * ``mapper.json`` -> ``mapper``

    Returns:
        dict: a status message, which the ``/initialisation`` endpoint
        returns as-is.
    """
    global tokenizer, rnn, vgg16, best_weights, mapper

    check_and_download("models", "tokenizer_config.json")
    with open("models/tokenizer_config.json", "r", encoding="utf-8") as json_file:
        # tokenizer_from_json expects the raw JSON text, not a parsed dict.
        tokenizer_config = json_file.read()
    tokenizer = keras.preprocessing.text.tokenizer_from_json(tokenizer_config)

    rnn = load_model("models", "best_rnn_model.h5")
    vgg16 = load_model("models", "best_vgg16_model.h5")

    # JSON is UTF-8 by specification; open it explicitly as such so that
    # non-ASCII labels decode the same way on every platform/locale.
    check_and_download("models", "best_weights.json")
    with open("models/best_weights.json", "r", encoding="utf-8") as json_file:
        best_weights = json.load(json_file)

    check_and_download("models", "mapper.json")
    with open("models/mapper.json", "r", encoding="utf-8") as json_file:
        mapper = json.load(json_file)

    return {"message": "Initialisation de predict effectuée avec succès"}
# Load the tokenizer, models, weights and label mapper once at import time,
# so the globals read by the /prediction endpoint exist before any request.
initialisation()
# Endpoint de réinitialisation, à appeler lorsqu'un nouveau modèle doit être mis en production
@app.get("/initialisation")
def hot_init():
    """Reload all model artefacts on demand.

    Re-runs ``initialisation()`` so that freshly deployed model files are
    picked up without restarting the service, and returns its status
    message unchanged.
    """
    reload_status = initialisation()
    return reload_status
# Endpoint pour la prédiction
@app.post("/prediction")
def prediction(input_data: PredictionInput, token: Optional[str] = Depends(oauth2_scheme)):
    """Classify a dataset and save the predictions.

    When ``input_data.api_secured`` is true, the bearer token is first
    checked against the ``api_oauth`` service, and the caller must have an
    ``Authorization`` level of at least 1.

    A ``Predict`` instance is built from the globals loaded by
    ``initialisation()`` and run on ``input_data.dataset_path`` /
    ``input_data.images_path``. Two CSV files are then written into
    ``input_data.prediction_path``:

    * ``predictions.csv``: the frame returned by ``Predict.predict()``
    * ``new_classes.csv``: columns ``cat_real`` and ``cat_pred``

    Raises:
        HTTPException: with the auth service's status code when it rejects
            the token; 403 when the user lacks prediction rights; 503 when
            the auth service cannot be reached or does not answer in time.

    Returns:
        dict: a success message naming the requester, plus the prediction
        duration in seconds under ``"duration"``.
    """
    global predictor, tokenizer, rnn, vgg16, best_weights, mapper
    import os
    import time

    # Seconds to wait for the auth service. Without a timeout, a stalled
    # api_oauth service would block this worker indefinitely.
    auth_timeout = 10

    # When api_secured is set, verify the caller's credentials first.
    if input_data.api_secured:
        try:
            auth_response = requests.get(
                "http://api_oauth:8001/secured",
                headers={"Authorization": f"Bearer {token}"},
                timeout=auth_timeout,
            )
        except requests.exceptions.RequestException as err:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Service d'authentification indisponible",
            ) from err
        if auth_response.status_code != 200:
            raise HTTPException(status_code=auth_response.status_code, detail="Non autorisé à accéder à la prédiction")
        user_data = auth_response.json()
        user_info = user_data['FirstName'] + " " + user_data['LastName']
        if user_data['Authorization'] < 1:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"{user_info} n'est pas autorisé à effectuer une prédiction")
    else:
        user_info = "un utilisateur inconnu"

    # Run the prediction; perf_counter is monotonic, so it is the right
    # clock for measuring an elapsed duration.
    t_debut = time.perf_counter()
    predictor = Predict(
        tokenizer=tokenizer,
        rnn=rnn,
        vgg16=vgg16,
        best_weights=best_weights,
        mapper=mapper,
        filepath=input_data.dataset_path,
        imagepath=input_data.images_path,
    )
    predictions = predictor.predict()
    t_fin = time.perf_counter()
    duration = t_fin - t_debut

    # Save the raw predictions.
    predictions.to_csv(os.path.join(input_data.prediction_path, "predictions.csv"), index=False)

    # Build new_classes.csv: the predicted category is renamed 'cat_real',
    # and 'cat_pred' is filled from the frame's first column. Presumably that
    # first column is the renamed one, so both columns are equal — confirm
    # against Predict.predict()'s output layout.
    predictions = predictions.rename(columns={'cat_pred': 'cat_real'})
    predictions['cat_pred'] = predictions.iloc[:, 0]
    predictions[['cat_real', 'cat_pred']].to_csv(
        os.path.join(input_data.prediction_path, "new_classes.csv"), index=False
    )

    print("Durée de la prédiction : {:.2f}".format(duration))
    prediction_response = {"message": f"Prédiction effectuée avec succès, demandée par {user_info}", "duration": duration}
    return prediction_response