File size: 3,886 Bytes
eb5ec73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import json
import time
from pathlib import Path
from typing import List, Optional

import numpy as np
import pandas as pd
import requests
import tensorflow as tf
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from keras import backend as K
from pydantic import BaseModel
from tensorflow import keras
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras.preprocessing.sequence import pad_sequences

from src.features.build_features import TextPreprocessor
from src.features.build_features import ImagePreprocessor
from src.tools import f1_m, load_model, check_and_download

# ... (omitted lines)

def _load_json_artifact(directory, filename):
    """Make sure ``directory/filename`` is available locally, then parse it as JSON.

    ``check_and_download`` is called first, as for every other artefact. The
    file is decoded explicitly as UTF-8 so the result does not depend on the
    platform's default locale encoding.
    """
    check_and_download(directory, filename)
    with open(f"{directory}/{filename}", "r", encoding="utf-8") as json_file:
        return json.load(json_file)


def initialisation():
    """Load (or reload) every artefact used by the prediction endpoint.

    All artefacts are first loaded into local variables. The module-level
    globals are rebound only once every one of them has loaded without
    raising. A failure part-way through (missing file, unreadable model,
    invalid JSON, ...) therefore leaves the previously loaded set intact
    instead of mixing old and new artefacts. This matters because
    ``hot_init`` calls this function on a running server. The trade-off is
    that both the old and the new set are held in memory during a reload.

    Globals rebound on success: ``tokenizer``, ``rnn``, ``vgg16``,
    ``best_weights`` and ``mapper``.

    Returns:
        dict: ``{"message": ...}`` confirming the initialisation.
    """
    global tokenizer, rnn, vgg16, best_weights, mapper

    models_dir = "models"

    # Tokenizer: the JSON text is handed as-is to Keras' deserializer.
    check_and_download(models_dir, "tokenizer_config.json")
    with open(f"{models_dir}/tokenizer_config.json", "r", encoding="utf-8") as json_file:
        tokenizer_config = json_file.read()
    new_tokenizer = keras.preprocessing.text.tokenizer_from_json(tokenizer_config)

    new_rnn = load_model(models_dir, "best_rnn_model.h5")
    new_vgg16 = load_model(models_dir, "best_vgg16_model.h5")

    new_best_weights = _load_json_artifact(models_dir, "best_weights.json")
    new_mapper = _load_json_artifact(models_dir, "mapper.json")

    # Publish the new artefacts only now that every one of them has loaded.
    tokenizer, rnn, vgg16, best_weights, mapper = (
        new_tokenizer,
        new_rnn,
        new_vgg16,
        new_best_weights,
        new_mapper,
    )

    return {"message": "Initialisation de predict effectuée avec succès"}

# Load every model artefact once, at import time, so the models are already
# in memory when the first request arrives.
initialisation()


# NOTE(review): this GET endpoint reloads all models (a state-changing,
# expensive operation) and performs no authentication check of its own;
# consider protecting it.
@app.get("/initialisation")
def hot_init():
    """Reload the model artefacts on demand.

    Meant to be called when a new model has been put into production.

    Returns:
        The status dict produced by ``initialisation``.
    """
    reload_status = initialisation()
    return reload_status

# Authentication service queried when a prediction request is secured.
_AUTH_SERVICE_URL = "http://api_oauth:8001/secured"
# Seconds to wait for the authentication service before giving up, so a hung
# auth service cannot block a worker indefinitely.
_AUTH_TIMEOUT_SECONDS = 10


def _authenticate_user(token):
    """Validate *token* against the auth service and return the caller's display name.

    Raises:
        HTTPException: 401 if no token was supplied; 503 if the auth service
            cannot be reached; the auth service's own status code if it
            rejects the token; 403 if the user's ``Authorization`` level is
            below 1 (a missing level counts as 0).
    """
    if not token:
        # Without this check the literal string "Bearer None" would be sent.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Non autorisé à accéder à la prédiction")

    try:
        auth_response = requests.get(
            _AUTH_SERVICE_URL,
            headers={"Authorization": f"Bearer {token}"},
            timeout=_AUTH_TIMEOUT_SECONDS,
        )
    except requests.RequestException as err:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Service d'authentification indisponible",
        ) from err

    if auth_response.status_code != 200:
        raise HTTPException(status_code=auth_response.status_code, detail="Non autorisé à accéder à la prédiction")

    user_data = auth_response.json()
    user_info = f"{user_data.get('FirstName', '')} {user_data.get('LastName', '')}"
    if user_data.get('Authorization', 0) < 1:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"{user_info} n'est pas autorisé à effectuer une prédiction")
    return user_info


def _save_predictions(predictions, prediction_path):
    """Write ``predictions.csv`` and ``new_classes.csv`` into *prediction_path*."""
    output_dir = Path(prediction_path)
    predictions.to_csv(output_dir / "predictions.csv", index=False)

    # NOTE(review): after renaming 'cat_pred' -> 'cat_real', a new 'cat_pred'
    # column is filled from the FIRST column of the frame. If 'cat_pred' was
    # itself the first column, both exported columns are identical. Behavior
    # kept as-is; confirm it is intended.
    renamed = predictions.rename(columns={'cat_pred': 'cat_real'})
    renamed['cat_pred'] = renamed.iloc[:, 0]
    renamed[['cat_real', 'cat_pred']].to_csv(output_dir / "new_classes.csv", index=False)


# Prediction endpoint
@app.post("/prediction")
def prediction(input_data: PredictionInput, token: Optional[str] = Depends(oauth2_scheme)):
    """Run a prediction on ``input_data.dataset_path`` / ``images_path``.

    When ``input_data.api_secured`` is true, the bearer token is checked
    against the auth service first (see ``_authenticate_user``). Results are
    written to ``input_data.prediction_path``.

    Returns:
        dict: a success message naming the requester, and the prediction
        duration in seconds under ``"duration"``.
    """
    global predictor

    if input_data.api_secured:
        user_info = _authenticate_user(token)
    else:
        user_info = "un utilisateur inconnu"

    t_debut = time.perf_counter()
    # Use a request-local predictor. FastAPI runs plain `def` endpoints in a
    # threadpool, so a shared global could be rebound by a concurrent request
    # between construction and `.predict()`, and this request would then use
    # the other request's file paths.
    request_predictor = Predict(
        tokenizer=tokenizer,
        rnn=rnn,
        vgg16=vgg16,
        best_weights=best_weights,
        mapper=mapper,
        filepath=input_data.dataset_path,
        imagepath=input_data.images_path
    )
    # Kept for backward compatibility: the module-level `predictor` still
    # refers to the most recently built instance, but it is not used here.
    predictor = request_predictor
    predictions = request_predictor.predict()
    duration = time.perf_counter() - t_debut

    _save_predictions(predictions, input_data.prediction_path)

    print("Durée de la prédiction : {:.2f}".format(duration))

    return {"message": f"Prédiction effectuée avec succès, demandée par {user_info}", "duration": duration}