# Uploaded by kleisar ("Upload 6 files", commit 97869d5, verified).
from datetime import datetime
import pickle
import pandas as pd
from prophet import Prophet
class ProphetModel:
    """Thin wrapper around a Prophet forecaster: train from CSV, predict, save/load."""

    # CSV column -> Prophet column. Prophet requires a 'ds' (timestamp) and a
    # 'y' (value) column; CSVs that already use 'ds'/'y' pass through unchanged.
    _COLUMN_MAP = {'time': 'ds', 'volume': 'y'}

    def __init__(self):
        # Fitted Prophet instance (or whatever load_model unpickled); None until
        # train_model() or load_model() succeeds.
        self.model = None

    def _require_trained(self):
        """Raise RuntimeError if no model has been trained or loaded yet."""
        if self.model is None:
            raise RuntimeError(
                "Model has not been trained. Call train_model() or load_model() first."
            )

    def train_model(self, csv_file):
        """
        Train a new Prophet model using the data from the given CSV file.

        :param csv_file: Path (or file-like object) of a CSV containing a
            'time' and a 'volume' column (or Prophet's native 'ds' and 'y').
        :return: self, so calls can be chained.
        :raises ValueError: if the required columns are missing.
        """
        df = pd.read_csv(csv_file).rename(columns=self._COLUMN_MAP)
        missing = [src for src, dst in self._COLUMN_MAP.items() if dst not in df.columns]
        if missing:
            raise ValueError(
                f"{csv_file}: missing required column(s) {missing}; expected "
                "'time' and 'volume' (or Prophet's native 'ds' and 'y')."
            )
        model = Prophet()
        model.fit(df)
        # Only replace the current model once fitting succeeded, so a failed
        # training run does not leave an unfitted model behind.
        self.model = model
        return self

    def predict(self, datetime):
        """
        Query the forecast for a single point in time.

        :param datetime: The timestamp to forecast, e.g. a datetime object or a
            date string such as 'YYYY-MM-DD' / 'YYYY-MM-DD HH:MM:SS'.
            (The name shadows the ``datetime`` import inside this method; it is
            kept so existing keyword callers keep working.)
        :return: JSON string holding a one-element list of records with keys
            'ds', 'yhat', 'yhat_lower' and 'yhat_upper' (pandas ``to_json``
            with ``orient='records'``, so 'ds' is serialized as epoch ms).
        :raises RuntimeError: if no model has been trained or loaded.
        """
        self._require_trained()
        future = pd.DataFrame({'ds': [datetime]})
        forecast = self.model.predict(future)
        # Keep only the point forecast and its uncertainty interval, as JSON.
        return forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].to_json(orient='records')

    def save_model(self, model_file):
        """
        Save the trained model to a file with pickle.

        :param model_file: Path to the file where the model will be saved.
        :raises RuntimeError: if no model has been trained or loaded.
        """
        self._require_trained()
        # NOTE: Prophet's docs recommend prophet.serialize.model_to_json over
        # pickle, as pickles may not load across Prophet versions. Pickle is
        # kept here so existing .pkl files stay loadable.
        with open(model_file, 'wb') as fout:
            pickle.dump(self.model, fout)

    def load_model(self, model_file):
        """
        Load a previously saved model from a pickle file.

        :param model_file: Path to the file containing the saved model.
        :return: self, so calls can be chained.
        """
        # SECURITY: pickle.load can execute arbitrary code; only load model
        # files from trusted sources.
        with open(model_file, 'rb') as fin:
            self.model = pickle.load(fin)
        return self
# prophet_model = ProphetModel()
# prophet_model.train_model('data/traffic_data.csv')
# print(type(prophet_model))
# predict_time = datetime.strptime('2023-10-01 09:00:00', "%Y-%m-%d %H:%M:%S")
# forecast = prophet_model.predict(predict_time)
# prophet_model.save_model('saved_models/forecast_model.pkl')
# print(forecast)
# prophet_model = prophet_model.load_model('saved_models/forecast_model.pkl')
# print(type(prophet_model))
# forecast = prophet_model.predict(predict_time)
# print(forecast)