Spaces:
Paused
Paused
File size: 3,886 Bytes
eb5ec73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import json
import os
import time
from typing import List, Optional

from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
import pandas as pd
import requests
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras.preprocessing.sequence import pad_sequences
import tensorflow as tf
import numpy as np
from tensorflow import keras
from keras import backend as K

from src.features.build_features import TextPreprocessor
from src.features.build_features import ImagePreprocessor
from src.tools import f1_m, load_model, check_and_download
# ... (omitted lines)
def _read_model_artifact(filename):
    """Make sure ``models/<filename>`` is available locally and return its text.

    ``check_and_download`` (from ``src.tools``) is called first; presumably it
    fetches the file when it is missing — confirm against its implementation.
    The file is always decoded as UTF-8 (the encoding JSON mandates) rather
    than the platform's locale encoding.
    """
    check_and_download("models", filename)
    with open(f"models/{filename}", "r", encoding="utf-8") as artifact_file:
        return artifact_file.read()


def initialisation():
    """Load (or reload) the tokenizer, both models, ensemble weights and mapper.

    Every artefact is first loaded into a local variable. The module-level
    globals are rebound only once all loads have succeeded. A failure part-way
    through a hot reload therefore leaves the previously loaded set in place,
    instead of serving a mix of old and new artefacts.

    Returns:
        dict: confirmation payload returned to the HTTP caller.
    """
    global tokenizer, rnn, vgg16, best_weights, mapper

    new_tokenizer = keras.preprocessing.text.tokenizer_from_json(
        _read_model_artifact("tokenizer_config.json")
    )
    new_rnn = load_model("models", "best_rnn_model.h5")
    new_vgg16 = load_model("models", "best_vgg16_model.h5")
    new_best_weights = json.loads(_read_model_artifact("best_weights.json"))
    new_mapper = json.loads(_read_model_artifact("mapper.json"))

    # Publish only after everything above succeeded.
    tokenizer, rnn, vgg16 = new_tokenizer, new_rnn, new_vgg16
    best_weights, mapper = new_best_weights, new_mapper
    return {"message": "Initialisation de predict effectuée avec succès"}


# Load all artefacts once at import time so the API can serve immediately.
initialisation()
# Endpoint used to reload the artefacts when a new model is promoted to production.
@app.get("/initialisation")
def hot_init():
    """Re-run ``initialisation`` on demand and hand back its confirmation payload."""
    reload_status = initialisation()
    return reload_status
def _authorise_user(token, timeout=10.0):
    """Validate ``token`` against the OAuth service and return the caller's name.

    Args:
        token: bearer token extracted by ``oauth2_scheme`` (may be None).
        timeout: seconds to wait for the OAuth service before giving up, so a
            stalled auth service cannot block a worker indefinitely.

    Returns:
        str: "<FirstName> <LastName>" of the authenticated user.

    Raises:
        HTTPException: 401 when no token was supplied; 503 when the OAuth
            service cannot be reached; the OAuth service's own status code when
            it rejects the token; 403 when the user's ``Authorization`` level
            is below 1.
    """
    if token is None:
        # Avoid sending the literal header "Bearer None" to the auth service.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Non autorisé à accéder à la prédiction",
        )
    try:
        auth_response = requests.get(
            "http://api_oauth:8001/secured",
            headers={"Authorization": f"Bearer {token}"},
            timeout=timeout,
        )
    except requests.RequestException as err:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Service d'authentification indisponible",
        ) from err
    if auth_response.status_code != 200:
        raise HTTPException(
            status_code=auth_response.status_code,
            detail="Non autorisé à accéder à la prédiction",
        )
    user_data = auth_response.json()
    user_info = user_data['FirstName'] + " " + user_data['LastName']
    if user_data['Authorization'] < 1:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"{user_info} n'est pas autorisé à effectuer une prédiction",
        )
    return user_info


def _save_predictions(predictions, prediction_path):
    """Write ``predictions.csv`` and ``new_classes.csv`` into ``prediction_path``."""
    predictions.to_csv(os.path.join(prediction_path, "predictions.csv"), index=False)
    classes = predictions.rename(columns={'cat_pred': 'cat_real'})
    # NOTE(review): cat_pred is copied from the first column of the frame
    # returned by Predict.predict() — confirm that column is the intended source.
    classes['cat_pred'] = classes.iloc[:, 0]
    classes[['cat_real', 'cat_pred']].to_csv(
        os.path.join(prediction_path, "new_classes.csv"), index=False
    )


# Endpoint for running a prediction.
@app.post("/prediction")
def prediction(input_data: PredictionInput, token: Optional[str] = Depends(oauth2_scheme)):
    """Run the ensemble prediction on the requested dataset and save the results.

    When ``input_data.api_secured`` is true, the caller's token is checked
    against the OAuth service first (see ``_authorise_user`` for the error
    statuses). Results are written as CSV files under
    ``input_data.prediction_path``.

    Returns:
        dict: success message naming the requester, plus the prediction
        duration in seconds.
    """
    # Kept as a module-level global as in the original code.
    # NOTE(review): concurrent requests overwrite this global; make it local
    # if nothing else reads it.
    global predictor

    if input_data.api_secured:
        user_info = _authorise_user(token)
    else:
        user_info = "un utilisateur inconnu"

    # perf_counter is monotonic, so the measured duration cannot go negative
    # if the wall clock is adjusted during the run.
    t_debut = time.perf_counter()
    predictor = Predict(
        tokenizer=tokenizer,
        rnn=rnn,
        vgg16=vgg16,
        best_weights=best_weights,
        mapper=mapper,
        filepath=input_data.dataset_path,
        imagepath=input_data.images_path,
    )
    predictions = predictor.predict()
    duration = time.perf_counter() - t_debut

    _save_predictions(predictions, input_data.prediction_path)
    print("Durée de la prédiction : {:.2f}".format(duration))
    return {
        "message": f"Prédiction effectuée avec succès, demandée par {user_info}",
        "duration": duration,
    }
|