Spaces:
Paused
Paused
| from fastapi import FastAPI, Depends, HTTPException, status | |
| from fastapi.security import OAuth2PasswordBearer | |
| from pydantic import BaseModel | |
| from typing import List, Optional | |
| import pandas as pd | |
| import requests | |
| from src.features.build_features import TextPreprocessor | |
| from src.features.build_features import ImagePreprocessor | |
| from tensorflow.keras.applications.vgg16 import preprocess_input | |
| from tensorflow.keras.preprocessing.image import img_to_array, load_img | |
| from tensorflow.keras.preprocessing.sequence import pad_sequences | |
| import tensorflow as tf | |
| import numpy as np | |
| import json | |
| from tensorflow import keras | |
| from keras import backend as K | |
| from src.tools import f1_m, load_model, check_and_download | |
| # ... (omitted lines) | |
def initialisation():
    """Load the tokenizer, the two Keras models, the weights JSON and the mapper JSON.

    The results are stored in the module-level globals ``tokenizer``, ``rnn``,
    ``vgg16``, ``best_weights`` and ``mapper``, which the prediction handler
    reads. Each JSON artefact is passed through ``check_and_download`` before
    being read from ``models/``; the ``.h5`` models are loaded via ``load_model``.

    Returns:
        dict: a ``message`` confirming that initialisation succeeded.
    """
    global tokenizer, rnn, vgg16, best_weights, mapper
    models_dir = "models"

    def read_artifact(filename):
        # Make sure the file is available locally (check_and_download presumably
        # fetches it when missing — confirm), then return its full text.
        check_and_download(models_dir, filename)
        # JSON is UTF-8 by spec; the platform default encoding may not be.
        with open(f"{models_dir}/{filename}", "r", encoding="utf-8") as json_file:
            return json_file.read()

    # tokenizer_from_json is given the raw JSON text, so it is not parsed here.
    tokenizer = keras.preprocessing.text.tokenizer_from_json(read_artifact("tokenizer_config.json"))
    rnn = load_model(models_dir, "best_rnn_model.h5")
    vgg16 = load_model(models_dir, "best_vgg16_model.h5")
    best_weights = json.loads(read_artifact("best_weights.json"))
    mapper = json.loads(read_artifact("mapper.json"))
    return {"message": "Initialisation de predict effectuée avec succès"}


# Load every artefact once, when the module is imported.
initialisation()
# Endpoint used to re-run initialisation when a new model has to be put into production.
def hot_init():
    """Reload every model artefact by re-running ``initialisation``.

    Returns:
        dict: the status message produced by ``initialisation``.
    """
    reload_status = initialisation()
    return reload_status
# Endpoint that runs the prediction.
def prediction(input_data: PredictionInput, token: Optional[str] = Depends(oauth2_scheme)):
    """Run the classifier on the dataset described by ``input_data`` and save the results.

    When ``input_data.api_secured`` is true, the bearer ``token`` is validated
    against ``http://api_oauth:8001/secured``. The JSON it returns must carry
    an ``Authorization`` level of at least 1.

    Side effects:
        Rebinds the module-level ``predictor`` and writes ``predictions.csv``
        and ``new_classes.csv`` into ``input_data.prediction_path``.

    Returns:
        dict: ``message`` naming the requesting user, and ``duration``, the
        prediction time in seconds.

    Raises:
        HTTPException: 401 when security is on and no token was supplied; the
            auth service's own status code when it rejects the token; 503 when
            the auth service cannot be reached in time; 403 when the user's
            ``Authorization`` level is below 1 or missing.
    """
    # Local imports: neither module appears in the visible top-of-file imports.
    import os
    import time

    # Only ``predictor`` is rebound here; the other model globals are only read.
    global predictor

    auth_timeout_s = 10  # seconds to wait for the auth service before giving up

    # If api_secured is True, check the caller's credentials.
    if input_data.api_secured:
        if not token:
            # Avoid sending the literal header "Bearer None" to the auth service.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Non autorisé à accéder à la prédiction",
                headers={"WWW-Authenticate": "Bearer"},
            )
        try:
            # Without a timeout, requests waits indefinitely on an unresponsive server.
            auth_response = requests.get(
                "http://api_oauth:8001/secured",
                headers={"Authorization": f"Bearer {token}"},
                timeout=auth_timeout_s,
            )
        except requests.RequestException as exc:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Service d'authentification indisponible",
            ) from exc
        if auth_response.status_code != 200:
            raise HTTPException(status_code=auth_response.status_code, detail="Non autorisé à accéder à la prédiction")
        user_data = auth_response.json()
        user_info = f"{user_data['FirstName']} {user_data['LastName']}"
        # A missing permission level is treated as "not authorised" instead of a KeyError/500.
        if user_data.get('Authorization', 0) < 1:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"{user_info} n'est pas autorisé à effectuer une prédiction")
    else:
        user_info = "un utilisateur inconnu"

    # Run the prediction; perf_counter is monotonic, unlike time.time().
    t_debut = time.perf_counter()
    predictor = Predict(
        tokenizer=tokenizer,
        rnn=rnn,
        vgg16=vgg16,
        best_weights=best_weights,
        mapper=mapper,
        filepath=input_data.dataset_path,
        imagepath=input_data.images_path,
    )
    predictions = predictor.predict()
    duration = time.perf_counter() - t_debut

    # Save the predictions.
    predictions.to_csv(os.path.join(input_data.prediction_path, "predictions.csv"), index=False)
    # NOTE(review): after renaming 'cat_pred' to 'cat_real', column 0 is copied into
    # a new 'cat_pred'. If column 0 is the renamed column, both output columns are
    # identical — confirm that this is intended.
    new_classes = predictions.rename(columns={'cat_pred': 'cat_real'})
    new_classes['cat_pred'] = new_classes.iloc[:, 0]
    new_classes[['cat_real', 'cat_pred']].to_csv(
        os.path.join(input_data.prediction_path, "new_classes.csv"), index=False
    )
    print("Durée de la prédiction : {:.2f}".format(duration))
    return {"message": f"Prédiction effectuée avec succès, demandée par {user_info}", "duration": duration}