Spaces:
Build error
Build error
| import pickle | |
| import os | |
| import numpy as np | |
| from datetime import datetime | |
| import pandas as pd | |
| from src.utils.mail import Mail | |
| from config import Config | |
# Attribute namespace of ``Config`` exposed as a mapping, so that settings can
# be looked up by string key (``send_mail`` reads ``MJ_APIKEY_PUBLIC`` and
# ``MJ_APIKEY_PRIVATE`` from it).
# NOTE(review): ``vars`` only returns attributes stored directly on the
# object; anything ``Config`` inherits would be missing -- confirm that
# ``Config`` has no settings-bearing base class.
config = vars(Config)
def save_models(models, model_type, directory):
    """Pickle each model to ``<directory>/<model_type>_FOLD_<k>.pkl``.

    Fold numbers ``k`` start at 1 and follow the order of ``models``.

    Args:
        models: Iterable of picklable model objects, one per fold.
        model_type: Prefix used in every generated file name.
        directory: Folder that receives the pickle files. It is created
            (including missing parents) when it does not exist yet.
    """
    print('Saving models...')
    # Writing into a missing folder would raise FileNotFoundError on the
    # first fold, so make sure the destination exists up front.
    if directory:
        os.makedirs(directory, exist_ok=True)
    for fold, model in enumerate(models, start=1):
        # os.path.join keeps path handling consistent with load_models.
        target = os.path.join(directory, f'{model_type}_FOLD_{fold}.pkl')
        with open(target, 'wb') as handle:
            pickle.dump(model, handle)
def load_models(directory):
    """Load every ``*.pkl`` file found directly inside ``directory``.

    Files are read in sorted (lexicographic) file-name order so the returned
    list is the same on every run and platform; ``os.listdir`` alone gives
    no ordering guarantee. Note that lexicographic order places
    ``..._FOLD_10.pkl`` before ``..._FOLD_2.pkl``.

    Args:
        directory: Folder to scan. Sub-folders are not searched.

    Returns:
        A list with one unpickled object per ``.pkl`` file; empty when the
        folder contains none.
    """
    print('Loading models...')
    models = []
    pkl_names = sorted(
        name for name in os.listdir(directory) if name.endswith('.pkl')
    )
    for name in pkl_names:
        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # Only point this function at directories whose contents are trusted.
        with open(os.path.join(directory, name), 'rb') as handle:
            models.append(pickle.load(handle))
    return models
def get_predictions(models, X_test):
    """Average the post-processed predictions of all ``models`` on ``X_test``.

    Each model's ``predict(X_test)`` output is passed through
    ``post_process`` and the results are averaged element-wise across models.

    Args:
        models: Iterable of objects exposing ``predict(X_test)``.
        X_test: Feature data forwarded unchanged to every model.

    Returns:
        The mean over models (``np.mean`` along axis 0) of the processed
        predictions.
    """
    print('Forecasting test data...')
    per_model = [post_process(estimator.predict(X_test)) for estimator in models]
    return np.mean(per_model, axis=0)
def post_process(predictions):
    """Return ``predictions`` with every value below zero raised to zero.

    Args:
        predictions: Object providing a ``clip`` method (e.g. a NumPy array).

    Returns:
        The result of ``predictions.clip(0)``.
    """
    return predictions.clip(0)
def save_results(dataframe, file_name):
    """Write ``dataframe`` to a dated Excel file under ``demand_predictions/``.

    The file is named ``<file_name>_<YYYY-MM-DD>.xlsx`` using today's local
    date, and is written without the index column.

    Args:
        dataframe: Object exposing ``to_excel(path, index=...)``.
        file_name: Base name of the output file, without extension.
    """
    print('Saving results...')
    # The export fails if the output folder is missing, so create it on
    # demand (relative to the current working directory).
    os.makedirs('demand_predictions', exist_ok=True)
    today_date = datetime.now().strftime("%Y-%m-%d")
    dataframe.to_excel(f'demand_predictions/{file_name}_{today_date}.xlsx', index=False)
def save_parquet(dataframe, path):
    """Write ``dataframe`` to ``path`` in Parquet format, omitting the index.

    Args:
        dataframe: Object exposing ``to_parquet(path, index=...)``.
        path: Destination passed straight through to ``to_parquet``.
    """
    dataframe.to_parquet(path, index=False)
def load_parquet(path):
    """Read the Parquet file at ``path`` and return it as a DataFrame.

    Args:
        path: Location passed straight through to ``pd.read_parquet``.

    Returns:
        The object returned by ``pd.read_parquet(path)``.
    """
    frame = pd.read_parquet(path)
    return frame
def send_mail(**kwargs):
    """Send a forecast e-mail through ``Mail`` and report whether it succeeded.

    A ``Mail`` client is built from the ``MJ_APIKEY_PUBLIC`` and
    ``MJ_APIKEY_PRIVATE`` entries of ``config``. Every keyword below is
    optional; a missing one is forwarded as ``None``.

    Keyword Args:
        name: Forwarded as ``fromName``.
        email: Forwarded as ``fromMail``.
        subject: Forwarded as ``fromSubject``.
        message: Forwarded as ``fromMsg``.
        dataframe: Forwarded as ``dataframe``.
        forecast_start_date: First element of the ``forecast_dates`` tuple.
        forecast_end_date: Second element of the ``forecast_dates`` tuple.
        toMail: Forwarded as ``toMail``.

    Returns:
        ``True`` when ``Mail.send_mail`` returns a status equal to 200,
        otherwise ``False``.
    """
    mailer = Mail(config['MJ_APIKEY_PUBLIC'], config['MJ_APIKEY_PRIVATE'])
    forecast_window = (
        kwargs.get('forecast_start_date'),
        kwargs.get('forecast_end_date'),
    )
    response_code = mailer.send_mail(
        fromName=kwargs.get('name'),
        fromMail=kwargs.get('email'),
        fromSubject=kwargs.get('subject'),
        fromMsg=kwargs.get('message'),
        dataframe=kwargs.get('dataframe'),
        forecast_dates=forecast_window,
        toMail=kwargs.get('toMail'),
    )
    # bool() keeps the result a plain True/False regardless of the status type.
    return bool(response_code == 200)