# energy_forecasting/models/tsai_models.py
# Author: kawaiipeace
# Last commit: 570d1fd ("update model")
import os
import tempfile
import pandas as pd
import matplotlib.pyplot as plt
import torch
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score
from tsai.all import *
# Name -> tsai architecture class; run_tsai_forecast looks models up here by
# their user-facing name.
MODEL_MAP = dict(
    TST=TST,
    InceptionTime=InceptionTime,
    XCM=XCM,
    TSTPlus=TSTPlus,
    ResNetPlus=ResNetPlus,
)
def _make_windows(values, lag, horizon):
    """Slice a ``(time, features)`` array into supervised forecasting samples.

    Returns ``(X, y)`` where ``X`` has shape ``(n_windows, n_features, lag)``
    (the ``[samples, variables, steps]`` layout tsai models expect) and ``y``
    has shape ``(n_windows, horizon)``, holding the next ``horizon`` values of
    the first column (the target).

    Raises:
        ValueError: if ``lag``/``horizon`` are not positive, or the series is
            too short to form a single window.
    """
    if lag < 1 or horizon < 1:
        raise ValueError(
            f"lag and horizon must be >= 1 (got lag={lag}, horizon={horizon})"
        )
    n_windows = len(values) - lag - horizon + 1
    if n_windows < 1:
        raise ValueError(
            f"Need at least lag + horizon = {lag + horizon} rows, got {len(values)}"
        )
    # Each slice is (lag, features); transpose to (features, lag) for tsai.
    X = np.stack([values[i : i + lag].T for i in range(n_windows)])
    # Only the first column is forecast.
    y = np.stack([values[i + lag : i + lag + horizon, 0] for i in range(n_windows)])
    return X, y


def _to_numpy(preds):
    """Return model predictions as a host NumPy array (tensor or array-like)."""
    if hasattr(preds, "detach"):
        return preds.detach().cpu().numpy()
    return np.asarray(preds)


def _rmse(y_true, y_pred):
    """Root-mean-squared error, computed per output column and then averaged.

    Equivalent to the former ``mean_squared_error(..., squared=False)``,
    whose ``squared`` argument no longer exists in recent scikit-learn.
    """
    per_output_mse = mean_squared_error(y_true, y_pred, multioutput="raw_values")
    return float(np.mean(np.sqrt(per_output_mse)))


def _plot_forecast(y_true, y_pred, model_name):
    """Build a line chart of flattened actual vs. predicted validation values."""
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(y_true.flatten(), label="Test Actual")
    ax.plot(y_pred.flatten(), label="Test Predicted", linestyle="--")
    ax.set_title(f"{model_name} Forecast Result")
    ax.legend()
    # Remove the figure from pyplot's global registry so repeated calls don't
    # accumulate open figures. The Figure object itself is still returned and
    # can be drawn/saved by the caller.
    plt.close(fig)
    return fig


def run_tsai_forecast(
    df,
    model_name,
    is_multivariate,
    horizon,
    lag,
    epochs,
    batch_size,
    future_horizon=0,
    device="cpu",
):
    """Train a tsai model on sliding windows of ``df`` and evaluate it.

    The first column of ``df`` is the forecast target. The chronologically
    first 80% of windows are used for training and the remaining 20% for
    validation (labelled "Test" in the outputs).

    Args:
        df: DataFrame with a ``DatetimeIndex``. Every column must be
            convertible to float32. The caller's frame is not modified.
        model_name: key into ``MODEL_MAP``.
        is_multivariate: if True, all columns are model inputs; if False,
            only the first (target) column is used.
        horizon: number of future steps predicted per window.
        lag: number of past steps fed to the model per window.
        epochs: number of ``fit_one_cycle`` epochs.
        batch_size: DataLoader batch size.
        future_horizon: currently unused; the second return value is always
            an empty list.
        device: device passed to the tsai DataLoaders (the model is moved to
            the DataLoaders' device during training).

    Returns:
        ``(y_pred, [], fig, metrics, export_path)``: flattened validation
        predictions as a NumPy array, an empty placeholder list, a matplotlib
        Figure of actual vs. predicted values, a summary string with RMSE and
        R2, and the path of a CSV file with the actual/predicted values.

    Raises:
        ValueError: if the index is not a ``DatetimeIndex``, or there is not
            enough data for at least one training and one validation window.
        KeyError: if ``model_name`` is not in ``MODEL_MAP``.
    """
    if not isinstance(df.index, pd.DatetimeIndex):
        raise ValueError("Data must have datetime index")
    # sort_index returns a new frame, so renaming its columns leaves the
    # caller's DataFrame untouched.
    df = df.sort_index()
    df.columns = df.columns.str.strip()

    values = df.astype("float32").to_numpy()
    if not is_multivariate:
        values = values[:, :1]
    X, y = _make_windows(values, lag, horizon)

    train_size = int(0.8 * len(X))
    if train_size == 0:
        raise ValueError(
            f"Need at least 2 windows for a train/validation split, got {len(X)}"
        )
    splits = (list(range(train_size)), list(range(train_size, len(X))))
    X_valid, y_valid = X[train_size:], y[train_size:]

    # tsai convention: y-transform in tfms, normalisation as a batch transform,
    # and the validation set selected by index splits.
    dls = get_ts_dls(
        X,
        y,
        splits=splits,
        tfms=[None, [TSForecasting()]],
        batch_tfms=TSStandardize(),
        bs=batch_size,
        device=device,
    )
    arch = MODEL_MAP[model_name]
    learn = ts_learner(dls, arch, metrics=mae, cbs=ShowGraphCallback())
    learn.fit_one_cycle(epochs)

    # Prediction: reshape explicitly rather than squeeze, so a single
    # validation sample or horizon=1 still yields a 2-D (n_valid, horizon) array.
    preds = learn.get_X_preds(X_valid)[0]
    y_pred = _to_numpy(preds).reshape(len(y_valid), -1)
    y_true = y_valid

    rmse_val = _rmse(y_true, y_pred)
    r2_val = r2_score(y_true, y_pred)
    metrics = f"Test RMSE: {rmse_val:.3f}, Test R2: {r2_val:.3f}"

    export_df = pd.DataFrame({
        "Test_Actual": y_true.flatten(),
        "Test_Predicted": y_pred.flatten(),
    })
    # Use a fresh directory per call so concurrent runs don't overwrite each
    # other's results. The file basename is unchanged.
    export_path = os.path.join(
        tempfile.mkdtemp(prefix="tsai_"), "tsai_forecast_result.csv"
    )
    export_df.to_csv(export_path, index=False)

    fig = _plot_forecast(y_true, y_pred, model_name)
    return y_pred.flatten(), [], fig, metrics, export_path