Spaces:
Sleeping
Sleeping
| import os | |
| import tempfile | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import numpy as np | |
| from sklearn.metrics import mean_squared_error, r2_score | |
| from tsai.all import * | |
# Registry of the tsai architectures selectable by name; run_tsai_forecast
# looks up its `model_name` argument here.
MODEL_MAP = dict(
    TST=TST,
    InceptionTime=InceptionTime,
    XCM=XCM,
    TSTPlus=TSTPlus,
    ResNetPlus=ResNetPlus,
)
def run_tsai_forecast(
    df,
    model_name,
    is_multivariate,
    horizon,
    lag,
    epochs,
    batch_size,
    future_horizon=0,
    device="cpu",
):
    """Train a tsai model on sliding windows of ``df`` and evaluate it on a hold-out split.

    The first column of ``df`` is the forecast target. Each sample uses the
    previous ``lag`` rows as input and the next ``horizon`` values of the
    first column as its target. The chronologically first 80% of windows are
    used for training; the remaining windows are used for validation/testing.

    Args:
        df: DataFrame indexed by a ``pd.DatetimeIndex``; all columns must be
            convertible to float32. The caller's frame is not modified.
        model_name: Key of ``MODEL_MAP`` selecting the architecture.
        is_multivariate: If true, every column is used as an input variable;
            otherwise only the first (target) column is used.
        horizon: Number of future steps predicted per window (>= 1).
        lag: Number of past steps in each input window (>= 1).
        epochs: Number of ``fit_one_cycle`` epochs.
        batch_size: Mini-batch size for the tsai data loaders.
        future_horizon: Currently unused; kept for interface compatibility
            (the second return value is always an empty list).
        device: Device passed to the tsai data loaders.

    Returns:
        Tuple ``(predictions, [], figure, metrics_text, csv_path)``:
        ``predictions`` is a flat numpy array of validation predictions,
        ``figure`` is a matplotlib figure of actual vs. predicted values,
        ``metrics_text`` summarises RMSE and R2, and ``csv_path`` points to
        a CSV with ``Test_Actual`` / ``Test_Predicted`` columns.

    Raises:
        ValueError: If the index is not a DatetimeIndex, there are no columns,
            ``lag``/``horizon`` are not positive, or there are too few rows to
            form at least one training and one validation window.
        KeyError: If ``model_name`` is not a key of ``MODEL_MAP``.
    """
    # Work on a new frame so the caller's column labels are left untouched.
    df = df.set_axis(df.columns.str.strip(), axis=1)
    if not isinstance(df.index, pd.DatetimeIndex):
        raise ValueError("Data must have datetime index")
    df = df.sort_index().astype("float32")
    if lag < 1 or horizon < 1:
        raise ValueError("lag and horizon must both be >= 1")

    values = df.to_numpy()
    if values.shape[1] == 0:
        raise ValueError("Data must contain at least one column")
    if not is_multivariate:
        # Univariate: the target column (column 0) is the only input variable.
        values = values[:, :1]

    n_windows = len(values) - lag - horizon + 1
    train_size = int(0.8 * n_windows)
    if train_size < 1 or n_windows - train_size < 1:
        raise ValueError(
            f"Not enough rows ({len(values)}) for lag={lag} and horizon={horizon}: "
            f"need at least {lag + horizon + 1} rows to form training and "
            f"validation windows"
        )

    # sliding_window_view over axis 0 yields (windows, n_vars, lag), which is
    # exactly the (samples, variables, steps) layout tsai expects for X.
    X = np.lib.stride_tricks.sliding_window_view(values, lag, axis=0)[:n_windows].copy()
    # Target: the `horizon` values of the first column right after each window.
    y = np.lib.stride_tricks.sliding_window_view(values[lag:, 0], horizon)[:n_windows].copy()

    X_valid, y_valid = X[train_size:], y[train_size:]
    # Chronological split: no shuffling across the train/validation boundary.
    splits = (list(range(train_size)), list(range(train_size, n_windows)))
    dls = get_ts_dls(
        X,
        y,
        splits=splits,
        tfms=[None, TSForecasting()],
        batch_tfms=TSStandardize(),
        bs=batch_size,
        device=device,
    )

    try:
        arch = MODEL_MAP[model_name]
    except KeyError:
        raise KeyError(
            f"Unknown model_name {model_name!r}; expected one of {sorted(MODEL_MAP)}"
        ) from None
    learn = ts_learner(dls, arch, metrics=mae, cbs=ShowGraphCallback())
    learn.fit_one_cycle(epochs)

    # Prediction
    preds = learn.get_X_preds(X_valid)[0]
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    # One row per validation window, one column per forecast step.
    y_pred = np.asarray(preds).reshape(len(X_valid), -1)
    y_true = y_valid.reshape(len(X_valid), -1)

    # Per-step RMSE averaged over steps (what `squared=False` used to compute);
    # the `squared` argument was removed from scikit-learn in 1.6.
    rmse_val = float(
        np.sqrt(mean_squared_error(y_true, y_pred, multioutput="raw_values")).mean()
    )
    r2_val = r2_score(y_true, y_pred)
    metrics = f"Test RMSE: {rmse_val:.3f}, Test R2: {r2_val:.3f}"

    export_df = pd.DataFrame({
        "Test_Actual": y_true.ravel(),
        "Test_Predicted": y_pred.ravel(),
    })
    # Private directory per call so concurrent runs cannot overwrite each other's CSV.
    export_dir = tempfile.mkdtemp(prefix="tsai_forecast_")
    export_path = os.path.join(export_dir, "tsai_forecast_result.csv")
    export_df.to_csv(export_path, index=False)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(y_true.ravel(), label="Test Actual")
    ax.plot(y_pred.ravel(), label="Test Predicted", linestyle="--")
    ax.set_title(f"{model_name} Forecast Result")
    ax.legend()
    return y_pred.ravel(), [], fig, metrics, export_path