vValentine7 commited on
Commit
7f5bd75
·
verified ·
1 Parent(s): ecce316

Create BaseModel.py

Browse files
Files changed (1) hide show
  1. BaseModel.py +83 -0
BaseModel.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from prophet import Prophet
3
+ from prophet.diagnostics import cross_validation, performance_metrics
4
+
5
+ class GeoMagModel:
6
+ """
7
+ A class for geomagnetic data forecasting using Prophet.
8
+ """
9
+
10
+ def __init__(self, changepoint_prior_scale=0.1, weekly_seasonality=True):
11
+ """
12
+ Initialize the GeoMagModel with Prophet configuration.
13
+
14
+ Args:
15
+ changepoint_prior_scale (float): Flexibility of changepoint detection.
16
+ weekly_seasonality (bool): Whether to include weekly seasonality.
17
+ """
18
+ self.changepoint_prior_scale = changepoint_prior_scale
19
+ self.weekly_seasonality = weekly_seasonality
20
+ self.model = None
21
+
22
+ @staticmethod
23
+ def prepare_for_prophet(df):
24
+ """
25
+ Prepare the DataFrame for Prophet by renaming columns as required.
26
+
27
+ Args:
28
+ df (pd.DataFrame): The cleaned DataFrame with 'timestamp' and 'Dst' columns.
29
+
30
+ Returns:
31
+ pd.DataFrame: A DataFrame with 'ds' (timestamp) and 'y' (Dst) columns for Prophet.
32
+ """
33
+ return df.rename(columns={'timestamp': 'ds', 'Dst': 'y'})
34
+
35
+ def train(self, df):
36
+ """
37
+ Train the Prophet model on the given DataFrame.
38
+
39
+ Args:
40
+ df (pd.DataFrame): The DataFrame prepared for Prophet.
41
+ """
42
+ df = self.prepare_for_prophet(df)
43
+ self.model = Prophet(interval_width=0.70, changepoint_prior_scale=self.changepoint_prior_scale)
44
+
45
+ if self.weekly_seasonality:
46
+ self.model.add_seasonality(name='weekly', period=7, fourier_order=5, prior_scale=10)
47
+
48
+ self.model.fit(df)
49
+
50
+ def forecast(self, periods=6, freq='h'):
51
+ """
52
+ Make a future forecast with the trained model.
53
+
54
+ Args:
55
+ periods (int): The number of future periods to forecast. Defaults to 6.
56
+ freq (str): The frequency of the forecast ('h' for hours). Defaults to 'h'.
57
+
58
+ Returns:
59
+ pd.DataFrame: The forecast DataFrame with 'ds', 'yhat', 'yhat_lower', and 'yhat_upper' columns.
60
+ """
61
+ if self.model is None:
62
+ raise ValueError("Model has not been trained. Call train() before forecast().")
63
+
64
+ future_dates = self.model.make_future_dataframe(periods=periods, freq=freq)
65
+ forecast = self.model.predict(future_dates)
66
+ return forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]
67
+
68
+ def cross_validate(self, initial='384 hours', period='12 hours', horizon='6 hours'):
69
+ """
70
+ Perform cross-validation on the trained Prophet model.
71
+
72
+ Args:
73
+ initial (str): Initial training period for cross-validation.
74
+ period (str): Frequency of making predictions.
75
+ horizon (str): Forecast horizon for each prediction.
76
+
77
+ Returns:
78
+ pd.DataFrame: Cross-validation results including metrics.
79
+ """
80
+ if self.model is None:
81
+ raise ValueError("Model has not been trained. Call train() before cross_validate().")
82
+
83
+ return cross_validation(self.model, initial=initial, period=period, horizon=horizon)