{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library.\n",
    "import logging\n",
    "import time\n",
    "\n",
    "# Third-party. Each name is imported exactly once; the original relative\n",
    "# order of the third-party imports is kept.\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n",
    "from statsmodels.tools.eval_measures import mse as sm_mse, rmse as sm_rmse, meanabs as sm_mae\n",
    "from sklearn.model_selection import TimeSeriesSplit, GridSearchCV\n",
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.models import NBEATS, NHITS, TFT, PatchTST, TCN, DLinear, RNN, LSTM, Autoformer, BiTCN, DeepAR, DeepNPTS, DilatedRNN, FEDformer, GRU, HINT, Informer, iTransformer, KAN, MLP, MLPMultivariate, NBEATSx, NLinear, RMoK, SOFTS, StemGNN, TiDE, TimeMixer, TimeLLM, TimesNet, TSMixer, TSMixerx, VanillaTransformer\n",
    "from neuralforecast.losses.pytorch import MAE, MQLoss\n",
    "from neuralforecast.utils import AirPassengersDF\n",
    "import torch\n",
    "import plotly.graph_objects as go\n",
    "from sklearn.utils import check_random_state\n",
    "\n",
    "# Float32 matmul precision trade-off: 'highest' keeps full float32,\n",
    "# 'high' permits faster reduced-precision (TF32 / bfloat16-based) kernels,\n",
    "# and 'medium' is faster still but LESS precise (bfloat16).\n",
    "torch.set_float32_matmul_precision('high')\n",
    "\n",
    "# Set up logging. A single call is enough: basicConfig() does nothing once\n",
    "# the root logger already has a handler.\n",
    "logging.basicConfig(level=logging.INFO)\n",
    "\n",
    "# Set seeds for reproducibility: NumPy's global RNG, PyTorch, and a\n",
    "# RandomState built from the same seed via check_random_state.\n",
    "seed = 42\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "random_state = check_random_state(seed)\n"
   ]
  },
  { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| \n", " | year | \n", "naics | \n", "emp | \n", "pay | \n", "prode | \n", "prodh | \n", "prodw | \n", "vship | \n", "matcost | \n", "vadd | \n", "... | \n", "equip | \n", "plant | \n", "piship | \n", "pimat | \n", "piinv | \n", "pien | \n", "dtfp5 | \n", "tfp5 | \n", "dtfp4 | \n", "tfp4 | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "1958 | \n", "336111 | \n", "146.1 | \n", "868.4 | \n", "116.2 | \n", "229.7 | \n", "650.6 | \n", "5007.9 | \n", "3411.3 | \n", "1563.8 | \n", "... | \n", "3291.1 | \n", "11718.8 | \n", "0.314 | \n", "0.238 | \n", "0.180 | \n", "0.147 | \n", "NaN | \n", "0.553 | \n", "NaN | \n", "0.550 | \n", "
| 1 | \n", "1959 | \n", "336111 | \n", "160.6 | \n", "1072.6 | \n", "131.0 | \n", "283.2 | \n", "827.2 | \n", "6422.2 | \n", "4306.7 | \n", "2143.4 | \n", "... | \n", "3457.5 | \n", "11415.6 | \n", "0.322 | \n", "0.243 | \n", "0.184 | \n", "0.145 | \n", "0.052 | \n", "0.583 | \n", "0.053 | \n", "0.580 | \n", "
| 2 | \n", "1960 | \n", "336111 | \n", "176.1 | \n", "1183.7 | \n", "144.7 | \n", "302.8 | \n", "925.7 | \n", "7239.0 | \n", "4883.5 | \n", "2336.2 | \n", "... | \n", "3673.0 | \n", "11375.1 | \n", "0.318 | \n", "0.241 | \n", "0.190 | \n", "0.150 | \n", "0.030 | \n", "0.601 | \n", "0.030 | \n", "0.598 | \n", "
| 3 | \n", "1961 | \n", "336111 | \n", "152.4 | \n", "1035.7 | \n", "123.2 | \n", "252.9 | \n", "789.5 | \n", "6214.2 | \n", "4134.9 | \n", "2053.6 | \n", "... | \n", "3794.2 | \n", "11179.9 | \n", "0.317 | \n", "0.242 | \n", "0.190 | \n", "0.146 | \n", "-0.007 | \n", "0.596 | \n", "-0.009 | \n", "0.592 | \n", "
| 4 | \n", "1962 | \n", "336111 | \n", "168.2 | \n", "1223.6 | \n", "138.4 | \n", "297.1 | \n", "954.7 | \n", "7855.3 | \n", "5187.9 | \n", "2671.2 | \n", "... | \n", "3908.1 | \n", "11141.2 | \n", "0.315 | \n", "0.246 | \n", "0.193 | \n", "0.147 | \n", "0.077 | \n", "0.644 | \n", "0.078 | \n", "0.641 | \n", "
5 rows × 24 columns
\n", "| \n", " | naics | \n", "emp | \n", "pay | \n", "prode | \n", "prodh | \n", "prodw | \n", "y | \n", "matcost | \n", "vadd | \n", "invest | \n", "... | \n", "piship | \n", "pimat | \n", "piinv | \n", "pien | \n", "dtfp5 | \n", "tfp5 | \n", "dtfp4 | \n", "tfp4 | \n", "ds | \n", "unique_id | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | \n", "336111 | \n", "160.6 | \n", "1072.6 | \n", "131.0 | \n", "283.2 | \n", "827.2 | \n", "6422.2 | \n", "4306.7 | \n", "2143.4 | \n", "93.0 | \n", "... | \n", "0.322 | \n", "0.243 | \n", "0.184 | \n", "0.145 | \n", "0.052 | \n", "0.583 | \n", "0.053 | \n", "0.580 | \n", "1959-01-01 | \n", "336111 | \n", "
| 2 | \n", "336111 | \n", "176.1 | \n", "1183.7 | \n", "144.7 | \n", "302.8 | \n", "925.7 | \n", "7239.0 | \n", "4883.5 | \n", "2336.2 | \n", "111.2 | \n", "... | \n", "0.318 | \n", "0.241 | \n", "0.190 | \n", "0.150 | \n", "0.030 | \n", "0.601 | \n", "0.030 | \n", "0.598 | \n", "1960-01-01 | \n", "336111 | \n", "
| 3 | \n", "336111 | \n", "152.4 | \n", "1035.7 | \n", "123.2 | \n", "252.9 | \n", "789.5 | \n", "6214.2 | \n", "4134.9 | \n", "2053.6 | \n", "91.8 | \n", "... | \n", "0.317 | \n", "0.242 | \n", "0.190 | \n", "0.146 | \n", "-0.007 | \n", "0.596 | \n", "-0.009 | \n", "0.592 | \n", "1961-01-01 | \n", "336111 | \n", "
| 4 | \n", "336111 | \n", "168.2 | \n", "1223.6 | \n", "138.4 | \n", "297.1 | \n", "954.7 | \n", "7855.3 | \n", "5187.9 | \n", "2671.2 | \n", "118.3 | \n", "... | \n", "0.315 | \n", "0.246 | \n", "0.193 | \n", "0.147 | \n", "0.077 | \n", "0.644 | \n", "0.078 | \n", "0.641 | \n", "1962-01-01 | \n", "336111 | \n", "
| 5 | \n", "336111 | \n", "160.6 | \n", "1231.7 | \n", "132.2 | \n", "293.9 | \n", "972.0 | \n", "8254.4 | \n", "5518.5 | \n", "2748.9 | \n", "144.9 | \n", "... | \n", "0.313 | \n", "0.245 | \n", "0.195 | \n", "0.147 | \n", "0.014 | \n", "0.653 | \n", "0.014 | \n", "0.649 | \n", "1963-01-01 | \n", "336111 | \n", "
5 rows × 25 columns
\n", "