from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
from typing import List, Optional
import pandas as pd
import requests
from src.features.build_features import TextPreprocessor
from src.features.build_features import ImagePreprocessor
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras.preprocessing.sequence import pad_sequences
import tensorflow as tf
import numpy as np
import json
import os
import time
from tensorflow import keras
from keras import backend as K
from src.tools import f1_m, load_model, check_and_download

# ... (omitted lines)

# URL of the OAuth service queried to validate bearer tokens.
AUTH_SERVICE_URL = "http://api_oauth:8001/secured"
# Upper bound (seconds) on the auth call so a hung OAuth service cannot block
# a worker indefinitely (requests has no timeout by default).
AUTH_TIMEOUT_SECONDS = 10


def initialisation():
    """Load (or reload) the tokenizer, models, weights and class mapper.

    The results are stored in the module-level globals ``tokenizer``, ``rnn``,
    ``vgg16``, ``best_weights`` and ``mapper``, which the ``/prediction``
    endpoint reads. JSON artefacts are first passed to the project helper
    ``check_and_download`` (presumably fetches the file into ``models/`` when
    it is missing — see ``src.tools``); the two ``.h5`` models are obtained
    through the project helper ``load_model``.

    Returns:
        dict: a confirmation ``message`` for the caller.
    """
    global predictor, tokenizer, rnn, vgg16, best_weights, mapper

    # Tokenizer: stored as a JSON string that Keras rebuilds directly.
    check_and_download("models", "tokenizer_config.json")
    with open("models/tokenizer_config.json", "r", encoding="utf-8") as json_file:
        tokenizer_config = json_file.read()
    tokenizer = keras.preprocessing.text.tokenizer_from_json(tokenizer_config)

    rnn = load_model("models", "best_rnn_model.h5")
    vgg16 = load_model("models", "best_vgg16_model.h5")

    # Explicit UTF-8 (the JSON standard encoding): the default text encoding of
    # open() is platform-dependent.
    check_and_download("models", "best_weights.json")
    with open("models/best_weights.json", "r", encoding="utf-8") as json_file:
        best_weights = json.load(json_file)

    check_and_download("models", "mapper.json")
    with open("models/mapper.json", "r", encoding="utf-8") as json_file:
        mapper = json.load(json_file)

    return {"message": "Initialisation de predict effectuée avec succès"}


# Load everything once at import time so the service is ready to predict.
initialisation()


# Endpoint used to reload the artefacts when a new model version is put into
# production, without restarting the service.
@app.get("/initialisation")
def hot_init():
    """Reload all model artefacts and return the confirmation message."""
    # NOTE(review): this GET mutates server state and performs no auth check;
    # consider protecting it the same way as /prediction.
    return initialisation()


def _check_prediction_rights(token):
    """Validate *token* with the OAuth service and return the caller's name.

    Args:
        token: bearer token extracted by ``oauth2_scheme`` (may be ``None``).

    Returns:
        str: ``"FirstName LastName"`` of the authenticated user.

    Raises:
        HTTPException: 401 when no token was supplied; the auth service's own
            status code when it rejects the token; 403 when the user's
            ``Authorization`` level is below 1; 502 when the auth service
            answers with an unusable payload; 503 when it cannot be reached.
    """
    if not token:
        # Avoid sending the literal "Bearer None" to the auth service.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Non autorisé à accéder à la prédiction",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        auth_response = requests.get(
            AUTH_SERVICE_URL,
            headers={"Authorization": f"Bearer {token}"},
            timeout=AUTH_TIMEOUT_SECONDS,
        )
    except requests.RequestException as err:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Service d'authentification indisponible",
        ) from err

    if auth_response.status_code != 200:
        raise HTTPException(
            status_code=auth_response.status_code,
            detail="Non autorisé à accéder à la prédiction",
        )

    # A malformed payload is an upstream failure, not a crash of this service.
    try:
        user_data = auth_response.json()
        user_info = f"{user_data['FirstName']} {user_data['LastName']}"
        denied = user_data['Authorization'] < 1
    except (ValueError, KeyError, TypeError) as err:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="Réponse invalide du service d'authentification",
        ) from err

    if denied:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"{user_info} n'est pas autorisé à effectuer une prédiction",
        )
    return user_info


# Prediction endpoint
@app.post("/prediction")
def prediction(input_data: PredictionInput, token: Optional[str] = Depends(oauth2_scheme)):
    """Run ``Predict`` on the requested dataset and write the results to CSV.

    When ``input_data.api_secured`` is true, the bearer token is checked
    against the OAuth service first (see ``_check_prediction_rights``).

    Side effects:
        Writes ``predictions.csv`` and ``new_classes.csv`` into
        ``input_data.prediction_path`` and stores the ``Predict`` instance in
        the module-level ``predictor`` global.

    Returns:
        dict: ``message`` naming the requester, and ``duration`` — seconds spent
        building the ``Predict`` object and running ``predict()``.

    Raises:
        HTTPException: propagated from the authorisation check.
    """
    # NOTE(review): ``predictor`` is a shared global; concurrent requests served
    # from FastAPI's threadpool will overwrite each other's instance.
    global predictor, tokenizer, rnn, vgg16, best_weights, mapper

    if input_data.api_secured:
        user_info = _check_prediction_rights(token)
    else:
        user_info = "un utilisateur inconnu"

    # Run the prediction; perf_counter is monotonic, suited to durations.
    t_debut = time.perf_counter()
    predictor = Predict(
        tokenizer=tokenizer,
        rnn=rnn,
        vgg16=vgg16,
        best_weights=best_weights,
        mapper=mapper,
        filepath=input_data.dataset_path,
        imagepath=input_data.images_path,
    )
    predictions = predictor.predict()
    t_fin = time.perf_counter()

    # Save the raw predictions (presumably a pandas DataFrame — confirm in Predict).
    predictions.to_csv(
        os.path.join(input_data.prediction_path, "predictions.csv"), index=False
    )

    # Build the "new classes" file: the predicted column becomes cat_real and
    # cat_pred is filled from the first column.
    # NOTE(review): confirm which column is first in Predict.predict()'s output;
    # if it is cat_pred, both columns end up identical (possibly intended).
    new_classes = predictions.rename(columns={'cat_pred': 'cat_real'})
    new_classes['cat_pred'] = new_classes.iloc[:, 0]
    new_classes[['cat_real', 'cat_pred']].to_csv(
        os.path.join(input_data.prediction_path, "new_classes.csv"), index=False
    )

    duration = t_fin - t_debut
    print("Durée de la prédiction : {:.2f}".format(duration))
    return {
        "message": f"Prédiction effectuée avec succès, demandée par {user_info}",
        "duration": duration,
    }