# energy_forecasting/models/pytorch_forecasting_models.py
# Last commit: "update" (de78237) by kawaiipeace
import pandas as pd
import numpy as np
import os
import tempfile
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score, root_mean_squared_error
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, DeepAR, NHiTS, Baseline
from pytorch_forecasting.data import GroupNormalizer
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
import torch
from torch.utils.data import DataLoader
def prepare_forecasting_data(df, max_prediction_length, max_encoder_length, *, batch_size=64, num_workers=0):
    """Build pytorch-forecasting datasets and dataloaders for a single time series.

    The frame's index is moved into a column, which is then overwritten with a
    0-based integer ``time_idx``; the original index values are discarded. If the
    frame has exactly one value column, that column is renamed to ``target``;
    otherwise the remaining column names are kept and a ``target`` column must
    already be present. Every row is assigned to one constant group ``"series_1"``.

    The last ``max_prediction_length`` time steps are excluded from the training
    dataset. The validation dataset is derived from the training dataset with
    ``predict=True`` over the full frame (per pytorch-forecasting, this yields the
    final prediction window of each group).

    Args:
        df: Input frame; its index is treated as the time axis but its values are
            replaced by consecutive integers.
        max_prediction_length: Forecast horizon (decoder length).
        max_encoder_length: Look-back (encoder) length.
        batch_size: Batch size used for both dataloaders.
        num_workers: Number of worker processes for both dataloaders.

    Returns:
        Tuple ``(training, train_dataloader, val_dataloader, df)`` where ``df`` is
        the prepared copy (with ``time_idx``, ``target`` and ``group`` columns).
    """
    # reset_index() returns a new frame, so the caller's df is never mutated.
    df = df.reset_index()
    if df.shape[1] == 2:
        # Single value column: name it `target` so the dataset spec below applies.
        df.columns = ['time_idx', 'target']
    else:
        df.columns = ['time_idx'] + df.columns.tolist()[1:]
    df['group'] = "series_1"
    # Replace the original index values with a gap-free integer time index.
    df['time_idx'] = np.arange(len(df))

    # Hold out the last `max_prediction_length` steps from training.
    training_cutoff = df["time_idx"].max() - max_prediction_length
    training = TimeSeriesDataSet(
        df[lambda x: x.time_idx <= training_cutoff],
        time_idx="time_idx",
        target="target",
        group_ids=["group"],
        max_encoder_length=max_encoder_length,
        max_prediction_length=max_prediction_length,
        static_categoricals=[],
        static_reals=[],
        time_varying_known_categoricals=[],
        time_varying_known_reals=["time_idx"],
        time_varying_unknown_categoricals=[],
        time_varying_unknown_reals=["target"],
        target_normalizer=GroupNormalizer(groups=["group"]),
    )
    validation = TimeSeriesDataSet.from_dataset(training, df, predict=True, stop_randomization=True)

    train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=num_workers)
    val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=num_workers)
    return training, train_dataloader, val_dataloader, df
def run_pytorch_forecasting(df, model_type, horizon, lag, future_horizon=0, device='cpu'):
    """Train a pytorch-forecasting model on one series and evaluate its last-window forecast.

    Data is prepared with ``prepare_forecasting_data(df, horizon, lag)``. The model is
    trained for up to 30 epochs with early stopping on ``val_loss``. The best checkpoint
    is then reloaded and its prediction for the validation window is compared with the
    actual values. A comparison plot and a CSV export are produced.

    Args:
        df: Input frame, passed unchanged to ``prepare_forecasting_data``.
        model_type: One of ``"TFT"``, ``"DeepAR"`` or ``"NHiTS"``.
        horizon: Number of steps to forecast (``max_prediction_length``).
        lag: Look-back / encoder length (``max_encoder_length``).
        future_horizon: Currently unused; kept for interface compatibility.
        device: ``"GPU"`` (case-insensitive; ``"cuda"`` also accepted) trains on CUDA
            when it is available. Any other value trains on CPU.

    Returns:
        Tuple ``(actuals, predictions, fig, metrics, export_path)``:
        ``actuals``/``predictions`` are 1-D numpy arrays over the forecast window,
        ``fig`` is the matplotlib figure comparing them, ``metrics`` is a formatted
        RMSE/R2 string, and ``export_path`` is the CSV written to the temp directory.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    model_classes = {
        "TFT": TemporalFusionTransformer,
        "DeepAR": DeepAR,
        "NHiTS": NHiTS,
    }
    # Fail fast, before any data preparation or training work.
    if model_type not in model_classes:
        raise ValueError(f"Unsupported model type: {model_type}")
    model_cls = model_classes[model_type]

    torch.set_float32_matmul_precision('high')
    use_cuda = str(device).lower() in ('gpu', 'cuda') and torch.cuda.is_available()

    training, train_dataloader, val_dataloader, _ = prepare_forecasting_data(df, horizon, lag)

    # `gpus=` was removed in Lightning 2.0; accelerator/devices works on 1.x and 2.x.
    # NOTE(review): pytorch-forecasting >= 1.0 builds its models on `lightning.pytorch`;
    # if that is the installed version, confirm this `pytorch_lightning.Trainer` accepts them.
    pl_trainer = Trainer(
        max_epochs=30,
        accelerator="gpu" if use_cuda else "cpu",
        devices=1,
        logger=CSVLogger("lightning_logs"),
        enable_checkpointing=True,
        callbacks=[
            EarlyStopping(monitor="val_loss", patience=3, mode="min"),
            ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1),
        ],
        gradient_clip_val=0.1,
    )

    model = model_cls.from_dataset(training, learning_rate=0.03)
    pl_trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

    # best_model_path is "" when no checkpoint was saved; fall back to the trained model.
    best_model_path = pl_trainer.checkpoint_callback.best_model_path
    best_model = model_cls.load_from_checkpoint(best_model_path) if best_model_path else model

    # pytorch-forecasting dataloaders yield (x, y) batches with the target tensor at y[0].
    actuals = torch.cat([y[0] for _, y in iter(val_dataloader)])
    predictions = best_model.predict(val_dataloader)

    # Both are (n_windows, horizon). Flatten to 1-D on CPU so sklearn, matplotlib and
    # pandas treat them as one series rather than `horizon` one-sample columns.
    actuals_np = actuals.detach().cpu().reshape(-1).numpy()
    predictions_np = predictions.detach().cpu().reshape(-1).numpy()

    # root_mean_squared_error already returns the RMSE; it has no `squared` argument.
    rmse = root_mean_squared_error(actuals_np, predictions_np)
    r2 = r2_score(actuals_np, predictions_np)
    metrics = f"Test RMSE: {rmse:.3f}, Test R2: {r2:.3f}"

    fig = plt.figure(figsize=(12, 6))
    plt.plot(actuals_np, label="Actual")
    plt.plot(predictions_np, label="Predicted")
    plt.title(f"{model_type} Forecast Result")
    plt.legend()

    export_df = pd.DataFrame({"Actual": actuals_np, "Predicted": predictions_np})
    export_path = os.path.join(tempfile.gettempdir(), f"{model_type}_forecast.csv")
    export_df.to_csv(export_path, index=False)
    return actuals_np, predictions_np, fig, metrics, export_path