Navya-Sree committed on
Commit
d2c53bb
·
verified ·
1 Parent(s): e03ff34

Create src/modeling/advanced_models.py

Browse files
Files changed (1) hide show
  1. src/modeling/advanced_models.py +122 -0
src/modeling/advanced_models.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simplified modeling for Hugging Face compatibility.
3
+ """
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import DataLoader, TensorDataset
9
+ from prophet import Prophet
10
+ from pmdarima import auto_arima
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
class LSTMForecaster(nn.Module):
    """Stacked-LSTM network for multi-step time series forecasting.

    Runs the input sequence through an LSTM stack, keeps only the final
    time step's output, applies dropout, and projects it to
    ``output_size`` values with a single linear layer.
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int,
                 output_size: int, dropout: float = 0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map a (batch, seq_len, input_size) tensor to (batch, output_size)."""
        sequence_output, _ = self.lstm(x)
        last_step = sequence_output[:, -1, :]  # keep only the final time step
        return self.linear(self.dropout(last_step))
34
+
35
class AdvancedModelTrainer:
    """Trainer for advanced forecasting models (LSTM, Prophet, auto-ARIMA).

    ``config`` is expected to hold per-model sub-dicts, e.g.
    ``config['lstm']`` with keys ``batch_size``, ``hidden_size``,
    ``num_layers``, ``dropout``, ``learning_rate`` and ``epochs``, and an
    optional ``config['prophet']`` dict of Prophet hyperparameters.
    """

    def __init__(self, config: dict):
        self.config = config

    def train_lstm(self, X_train: np.ndarray, y_train: np.ndarray,
                   X_val: np.ndarray = None,
                   y_val: np.ndarray = None) -> nn.Module:
        """Train an LSTM forecaster.

        Args:
            X_train: Training inputs of shape (samples, seq_len, features).
            y_train: Training targets of shape (samples, horizon).
            X_val: Optional validation inputs. When supplied together with
                ``y_val``, a validation loss is computed and logged.
            y_val: Optional validation targets.

        Returns:
            The trained ``LSTMForecaster`` module.
        """
        model_config = self.config['lstm']

        # Convert to PyTorch tensors
        train_dataset = TensorDataset(
            torch.FloatTensor(X_train),
            torch.FloatTensor(y_train)
        )
        train_loader = DataLoader(train_dataset,
                                  batch_size=model_config['batch_size'],
                                  shuffle=True)

        # Fix: the original signature accepted X_val/y_val but never used
        # them. Prepare validation tensors once, outside the epoch loop.
        has_val = X_val is not None and y_val is not None
        if has_val:
            X_val_t = torch.FloatTensor(X_val)
            y_val_t = torch.FloatTensor(y_val)

        # Initialize model
        model = LSTMForecaster(
            input_size=X_train.shape[2],
            hidden_size=model_config['hidden_size'],
            num_layers=model_config['num_layers'],
            output_size=y_train.shape[1],
            dropout=model_config['dropout']
        )

        # Training setup
        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=model_config['learning_rate'])

        # Training loop
        for epoch in range(model_config['epochs']):
            model.train()
            epoch_loss = 0.0

            for batch_X, batch_y in train_loader:
                optimizer.zero_grad()
                predictions = model(batch_X)
                loss = criterion(predictions, batch_y)
                loss.backward()

                # Gradient clipping stabilizes LSTM training.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()
                epoch_loss += loss.item()

            if epoch % 10 == 0:
                # Lazy %-style args avoid formatting when INFO is disabled.
                train_loss = epoch_loss / len(train_loader)
                if has_val:
                    model.eval()
                    with torch.no_grad():
                        val_loss = criterion(model(X_val_t), y_val_t).item()
                    logger.info('Epoch %d, Loss: %.4f, Val Loss: %.4f',
                                epoch, train_loss, val_loss)
                else:
                    logger.info('Epoch %d, Loss: %.4f', epoch, train_loss)

        return model

    def train_prophet(self, df: pd.DataFrame,
                      date_col: str,
                      value_col: str) -> Prophet:
        """Train a Facebook Prophet model.

        Args:
            df: Input frame containing at least ``date_col`` and ``value_col``.
            date_col: Name of the datestamp column (mapped to Prophet's 'ds').
            value_col: Name of the value column (mapped to Prophet's 'y').

        Returns:
            The fitted ``Prophet`` model.
        """
        # Prophet requires exactly the column names 'ds' and 'y'.
        prophet_df = df[[date_col, value_col]].rename(
            columns={date_col: 'ds', value_col: 'y'}
        )

        model = Prophet(
            changepoint_prior_scale=self.config['prophet'].get('changepoint_prior_scale', 0.05),
            seasonality_prior_scale=self.config['prophet'].get('seasonality_prior_scale', 10),
            yearly_seasonality=self.config['prophet'].get('yearly_seasonality', True),
            weekly_seasonality=self.config['prophet'].get('weekly_seasonality', True),
            daily_seasonality=self.config['prophet'].get('daily_seasonality', False)
        )

        model.fit(prophet_df)
        return model

    def train_auto_arima(self, series: pd.Series) -> object:
        """Train an auto-ARIMA model via stepwise order search.

        Args:
            series: Univariate time series to fit.

        Returns:
            The fitted pmdarima ARIMA model.
        """
        model = auto_arima(
            series,
            start_p=1,
            start_q=1,
            max_p=3,
            max_q=3,
            seasonal=True,
            m=7,  # weekly seasonal period — assumes daily data; TODO confirm
            stepwise=True,
            suppress_warnings=True,
            error_action='ignore'
        )
        return model