Forecast / app.py
msn-enginenova21's picture
Update app.py
af6986b
import matplotlib
matplotlib.use('Agg') # Set the backend to Agg
from flask import Flask, request
from flask import render_template
from components.model_prediction import Prediction
from Support_module_dir.support_function_plot import combine_plot_function
import io
import base64
# Function to load the model and make predictions
def load_and_predict_model(forecast_days):
try:
saved_model_path = "Saved_Model_dir/2023-12-23_00_42_22/SARIMAX_FORCAST_MODEL.joblib"
# Call the function to load and predict
predictions_instance, dataframe_instance = Prediction(saved_model_path).model_prediction(
forecast_days) # called class Prediction
# Reset the index and move it into a new column
dataframe_instance = dataframe_instance[-7:].reset_index()
# Assuming predictions_instance is your DataFrame
predictions_instance['Mean_Price'] = (predictions_instance['Lower_Bound'] + predictions_instance[
'Upper_Bound']) / 2
# Create an in-memory buffer
buffer = io.BytesIO()
# calling plot function
combine_plot_function(dataframe_instance, predictions_instance, buffer)
# getting image decode string
plot_img_str = base64.b64encode(buffer.getvalue()).decode()
return predictions_instance, dataframe_instance, plot_img_str
except FileNotFoundError as file_error:
raise FileNotFoundError(f"Error loading model: {file_error}")
except Exception as e:
raise e
# Initialize Flask app
app = Flask(__name__)
# Route for the home page
@app.route('/', methods=['GET'])
def home():
return render_template("about.html")
# Route for the prediction page
@app.route('/predict', methods=['POST'])
def predict():
input_data = int(request.form['forecast'])
predictions_instance, dataframe_instance, plot_img_str = load_and_predict_model(forecast_days=input_data)
try:
return render_template('index.html', predictions=predictions_instance.to_dict(orient="records"),
dataframe=dataframe_instance.to_dict(orient="records"),
plot_img_str=plot_img_str)
except Exception as e:
return render_template('error.html', error_message=str(e))
# Run the app
if __name__ == '__main__':
app.run(host='0.0.0.0', port=7860)