Spaces:
Paused
Paused
| from features.build_features import TextPreprocessor | |
| from features.build_features import ImagePreprocessor | |
| import tensorflow as tf | |
| from tensorflow.keras.applications.vgg16 import preprocess_input | |
| from tensorflow.keras.preprocessing.image import img_to_array, load_img | |
| from tensorflow.keras.preprocessing.sequence import pad_sequences | |
| import numpy as np | |
| import json | |
| from tensorflow import keras | |
| import pandas as pd | |
| import argparse | |
| from keras import backend as K | |
| from tools import f1_m, load_model, check_and_download | |
| import time | |
| # ... (omitted lines) | |
def _load_json_asset(folder, filename, parse=True):
    """Read a JSON asset from ``folder/filename`` as UTF-8.

    ``check_and_download`` is called first; presumably it fetches the file
    when it is not available locally (confirm in ``tools``).

    Args:
        folder: Directory that holds the asset (e.g. ``"models"``).
        filename: Name of the JSON file inside ``folder``.
        parse: When true, return the decoded JSON object; when false, return
            the raw JSON text (needed by ``tokenizer_from_json``).

    Returns:
        The parsed JSON value, or the file content as ``str`` if
        ``parse`` is false.
    """
    check_and_download(folder, filename)
    # JSON is UTF-8 by specification; pass the encoding explicitly so the
    # result does not depend on the platform's locale default.
    with open(f"{folder}/{filename}", "r", encoding="utf-8") as json_file:
        return json.load(json_file) if parse else json_file.read()


def main():
    """Run the prediction pipeline from the command line.

    Parses ``--dataset_path``, ``--images_path`` and ``--prediction_path``,
    loads the tokenizer, both models, the weights and the mapper from the
    ``models`` directory, runs ``Predict.predict()`` and writes the result
    to ``--prediction_path`` as CSV. The elapsed time (prediction plus CSV
    export) is printed at the end.
    """
    parser = argparse.ArgumentParser(description="Input data")
    parser.add_argument(
        "--dataset_path",
        default="data/predict/X_test_update.csv",
        type=str,
        help="File path for the input CSV file.",
    )
    parser.add_argument(
        "--images_path",
        default="data/predict/image_test",
        type=str,
        help="Base path for the images.",
    )
    parser.add_argument(
        "--prediction_path",
        default="data/predict/predictions.csv",
        type=str,
        help="Path for the prediction results.",
    )
    args = parser.parse_args()

    # Load configurations and models.
    # tokenizer_from_json expects the JSON text itself, not a parsed dict.
    tokenizer_config = _load_json_asset(
        "models", "tokenizer_config.json", parse=False
    )
    tokenizer = keras.preprocessing.text.tokenizer_from_json(tokenizer_config)

    rnn = load_model("models", "best_rnn_model.h5")
    vgg16 = load_model("models", "best_vgg16_model.h5")

    best_weights = _load_json_asset("models", "best_weights.json")
    mapper = _load_json_asset("models", "mapper.json")

    # Build the Predict instance.
    predictor = Predict(
        tokenizer=tokenizer,
        rnn=rnn,
        vgg16=vgg16,
        best_weights=best_weights,
        mapper=mapper,
        filepath=args.dataset_path,
        imagepath=args.images_path,
    )

    # Run the prediction; the timed span deliberately keeps covering the
    # CSV export as well, as it always has.
    t_debut = time.time()
    predictions = predictor.predict()

    # Save the predictions (predict() is expected to return an object with a
    # pandas-style ``to_csv`` method -- presumably a DataFrame; confirm).
    predictions.to_csv(args.prediction_path, index=False)
    t_fin = time.time()

    print("Durée de la prédiction : {:.2f}".format(t_fin - t_debut))
# Run the pipeline only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()