import os
import pickle
from glob import escape, glob

import torch
# Module-wide torch device: the GPU when PyTorch reports CUDA support,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
class _ModelFileNotFoundError(FileNotFoundError, IndexError):
    """Raised when a model folder lacks a required ``*.pt`` / ``*.pkl`` file.

    It also subclasses IndexError because earlier versions of ``load_model``
    surfaced this situation as the IndexError from ``glob(...)[0]``. Callers
    that catch IndexError therefore keep working.
    """


def _find_first(directory, pattern):
    """Return the lexicographically first file in *directory* matching *pattern*.

    Raises _ModelFileNotFoundError if nothing matches.
    """
    # Escape the directory so characters such as '[' in folder names are
    # matched literally; only *pattern* is treated as a wildcard.
    matches = sorted(glob(os.path.join(escape(directory), pattern)))
    if not matches:
        raise _ModelFileNotFoundError(
            f"No '{pattern}' file found in '{directory}'"
        )
    # glob() returns files in arbitrary, filesystem-dependent order; sorting
    # makes the choice reproducible when several files match.
    return matches[0]


def load_model(folder, base_folder='production_models'):
    """Load a prediction model made of a pickled object plus a torch model.

    The directory ``<base_folder>/<folder>`` must contain a ``*.pkl`` file
    (a pickled object, referred to here as the classifier) and a ``*.pt``
    file (loaded with ``torch.load``). If several files match either
    pattern, the lexicographically first one is used. The torch model is
    mapped onto the module-level ``device`` and attached to the classifier
    as ``clf.model``.

    Args:
        folder: Name of the model's sub-folder inside *base_folder* (str).
        base_folder: Root directory that holds the models. Defaults to
            ``'production_models'``, which is resolved relative to the
            current working directory.

    Returns:
        The unpickled classifier object with its ``model`` attribute set.

    Raises:
        FileNotFoundError: If no ``*.pt`` or no ``*.pkl`` file is found. The
            raised exception is also an IndexError for backward compatibility.

    Warning:
        ``pickle.load`` and ``torch.load`` can execute arbitrary code embedded
        in the files. Only load models from trusted locations.
    """
    model_dir = os.path.join(base_folder, folder)
    # Look up the .pt file first, so a missing .pt is reported before a
    # missing .pkl (same order as before).
    model_path = _find_first(model_dir, '*.pt')
    clf_path = _find_first(model_dir, '*.pkl')

    # SECURITY: unpickling runs arbitrary code; the file must be trusted.
    with open(clf_path, 'rb') as file:
        clf = pickle.load(file)

    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which rejects pickled nn.Module objects. If the .pt file holds a full
    # module rather than tensors, pass weights_only=False here. Confirm
    # against the deployed torch version.
    clf.model = torch.load(model_path, map_location=device)
    return clf