rakuten / src /predict.py
Demosthene-OR's picture
Configure LFS for images and update code
eb5ec73
from features.build_features import TextPreprocessor
from features.build_features import ImagePreprocessor
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
import json
from tensorflow import keras
import pandas as pd
import argparse
from keras import backend as K
from tools import f1_m, load_model, check_and_download
import time
# ... (omitted lines)
def _read_model_asset(filename: str, models_dir: str = "models") -> str:
    """Return the text content of ``<models_dir>/<filename>``.

    ``check_and_download`` is called first; presumably it makes sure the
    file exists locally before it is opened (see ``tools`` -- TODO confirm).

    The file is always decoded as UTF-8 so that the result does not depend
    on the platform's default locale encoding.
    """
    check_and_download(models_dir, filename)
    asset_path = "{}/{}".format(models_dir, filename)
    with open(asset_path, "r", encoding="utf-8") as asset_file:
        return asset_file.read()


def main() -> None:
    """Command-line entry point: load the model assets, predict, save a CSV.

    Command-line options:
        --dataset_path     input CSV file
                           (default ``data/predict/X_test_update.csv``)
        --images_path      base directory of the images
                           (default ``data/predict/image_test``)
        --prediction_path  destination CSV for the predictions
                           (default ``data/predict/predictions.csv``)

    The elapsed time printed at the end covers ``Predict.predict()`` and the
    CSV write; it does not include loading the tokenizer and the models.
    """
    parser = argparse.ArgumentParser(description="Input data")
    parser.add_argument(
        "--dataset_path",
        default="data/predict/X_test_update.csv",
        type=str,
        help="File path for the input CSV file.",
    )
    parser.add_argument(
        "--images_path",
        default="data/predict/image_test",
        type=str,
        help="Base path for the images.",
    )
    parser.add_argument(
        "--prediction_path",
        default="data/predict/predictions.csv",
        type=str,
        help="Path for the prediction results.",
    )
    args = parser.parse_args()

    # Load configurations and models.
    # tokenizer_from_json expects the raw JSON string, not a parsed object.
    tokenizer = keras.preprocessing.text.tokenizer_from_json(
        _read_model_asset("tokenizer_config.json")
    )

    rnn = load_model("models", "best_rnn_model.h5")
    vgg16 = load_model("models", "best_vgg16_model.h5")

    # These two files were previously opened with the platform default
    # encoding; they are now read as UTF-8 like the tokenizer configuration.
    best_weights = json.loads(_read_model_asset("best_weights.json"))
    mapper = json.loads(_read_model_asset("mapper.json"))

    # Create the Predict instance.
    predictor = Predict(
        tokenizer=tokenizer,
        rnn=rnn,
        vgg16=vgg16,
        best_weights=best_weights,
        mapper=mapper,
        filepath=args.dataset_path,
        imagepath=args.images_path,
    )

    # Run the prediction and save the result; both steps are timed together.
    t_debut = time.time()
    predictions = predictor.predict()
    # predictions must expose ``to_csv`` (presumably a pandas DataFrame).
    predictions.to_csv(args.prediction_path, index=False)
    t_fin = time.time()

    print("Durée de la prédiction : {:.2f}".format(t_fin - t_debut))
# Script entry point: run the prediction pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()