{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "5IM6CZzW_CH0" }, "source": [ "# Informer Demo" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "b5GFng7v7Eq0" }, "outputs": [], "source": [ "import sys\n", "\n", "# if not 'Informer2020' in sys.path:\n", "# sys.path += ['Informer2020']" ] }, { "cell_type": "markdown", "metadata": { "id": "rIjZdN5e_SWe" }, "source": [ "## Experiments: Train and Test" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RPdt-Kwc_RRZ" }, "outputs": [], "source": [ "from utils.tools import dotdict\n", "from exp.exp_informer import Exp_Informer\n", "import torch\n", "import numpy as np\n", "import pandas as pd\n", "import os\n", "from utils.ipynb_helpers import args_from_setting, setting_from_args, handle_gpu" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6mx2dnwY9dWi" }, "outputs": [], "source": [ "args = dotdict()\n", "args.des = \"full_1h\"\n", "\n", "args.model = \"informer\" # model of experiment, options: [informer, informerstack, informerlight(TBD)]\n", "\n", "args.data = \"custom\" # data\n", "args.root_path = \"./data/stock/\" # root path of data file\n", "\n", "\n", "args.data_path = \"full_1h.csv\" # data file\n", "args.features = \"MS\" # forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate\n", "args.target = \"XOM_pctchange\" # target feature in S or MS task\n", "args.freq = \"h\" # freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h\n", "args.checkpoints = \"./checkpoints\" # location of model checkpoints\n", "\n", "args.seq_len = 16 # input sequence length of Informer encoder\n", "args.label_len = 4 # start token length of Informer decoder\n", "args.pred_len = 1 # prediction sequence length\n", "# Informer decoder input: concat[start token series(label_len), zero padding series(pred_len)]\n", "\n", "args.cols = [\n", " \"XOM_open\",\n", " \"XOM_high\",\n", " \"XOM_low\",\n", " \"XOM_close\",\n", " \"XOM_volume\",\n", " \"XOM_pctchange\",\n", " \"XOM_shortsma\",\n", "] # [\"XOM_close\", \"BP_close\", \"CVX_close\", \"WTI_close\"]\n", "args.enc_in = 7 # 13 # encoder input size\n", "args.dec_in = 7 # 13 # decoder input size\n", "args.c_out = 1 # output size\n", "args.factor = 5 # probsparse attn factor\n", "args.d_model = 64 # 512 # dimension of model\n", "args.n_heads = 8 # num of heads\n", "args.e_layers = 4 # 2 # num of encoder layers\n", "args.d_layers = 2 # 1 # num of decoder layers\n", "args.d_ff = 2048 # dimension of fcn in model\n", "args.dropout = 0.05 # dropout\n", "args.attn = \"prob\" # attention used in encoder, options:[prob, full]\n", "args.t_embed = \"timeF\" # time features encoding, options:[timeF, fixed, learned]\n", "args.activation = \"gelu\" # activation\n", "args.distil = True # whether to use distilling in encoder\n", "args.output_attention = False # whether to output attention in encoder\n", "args.mix = True\n", "args.padding = 0\n", "\n", "args.batch_size = 64\n", "args.learning_rate = 0.00001\n", "args.loss = \"mse\"\n", "args.lradj = \"type1\"\n", "args.use_amp = False # whether to use automatic mixed precision training\n", "\n", "args.num_workers = 0\n", "args.itr = 1 # number of runs\n", "args.max_epochs = 15\n", "args.patience = 3\n", "\n", "\n", "args.scale = True # True # True\n", "args.inverse = True # True # Defaultly False but @Zac thinks it should be True\n", "\n", "\n", "args.date_start = None # \"2021-01-01\"\n", "args.date_end = None\n", "args.date_test = \"2022-04-01\" # None\n", "\n", "handle_gpu(args, None)\n", "\n", "# idk what this is for\n", "args.detail_freq = args.freq\n", "args.freq = args.freq[-1:]\n", "\n", "print(\"Args in experiment:\")\n", "print(args)\n", "Exp = Exp_Informer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train & Test *args.itr* models" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "928tzaA2AA2g", "outputId": "c19f673a-02d1-4f4d-91c3-d0f25e600443" }, "outputs": [], "source": [ "exp = None\n", "setting = None\n", "for ii in range(args.itr):\n", " # setting record of experiments\n", " setting = setting_from_args(args, ii)\n", "\n", " # set experiments\n", " exp = Exp(args)\n", "\n", " # train\n", " print(f\">>>>>>>start training : {setting}>>>>>>>>>>>>>>>>>>>>>>>>>>\")\n", " exp.train(setting)\n", "\n", " # test\n", " print(f\">>>>>>>testing : {setting}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\")\n", " exp.test(setting)\n", "\n", " torch.cuda.empty_cache()" ] }, { "cell_type": "markdown", "metadata": { "id": "CDHF-HerAE3u" }, "source": [ "## Prediction" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "nTkluNNcyMJt", "outputId": "780767fe-6321-4081-e827-6701daeb375b" }, "outputs": [], "source": [ "# If you already have a trained model, you can set the arguments and model path, then initialize a Experiment and use it to predict\n", "# Prediction is a sequence which is adjacent to the last date of the data, and does not exist in the data\n", "# If you want to get more information about prediction, you can refer to code `exp/exp_informer.py function predict()` and `data/data_loader.py class Dataset_Pred`\n", "\n", "# args = dotdict(model='informer', data='WTH', root_path='./data/ETT/', data_path='WTH.csv', features='M', target='WetBulbCelsius', freq='h', checkpoints='./checkpoints/', seq_len=96, label_len=48, pred_len=24, enc_in=12, dec_in=12, c_out=12, d_model=512, n_heads=8, e_layers=2, d_layers=1, s_layers=[3, 2, 1], d_ff=2048, factor=5, padding=0, distil=True, dropout=0.05, attn='prob', t_embed='timeF', activation='gelu', output_attention=False, do_predict=False, mix=True, cols=None, num_workers=0, itr=2, max_epochs=6, batch_size=32, patience=3, learning_rate=0.0001, des='test', loss='mse', lradj='type1', use_amp=False, inverse=False, use_gpu=True, gpu=0, use_multi_gpu=False, devices='0,1,2,3', detail_freq='h')\n", "\n", "manual = False\n", "\n", "if manual:\n", " setting = \"informer_custom_ftMS_sl256_ll64_pl16_ei1_di1_co1_iFalse_dm512_nh8_el2_dl1_df2048_atprob_fc5_ebtimeF_dtTrue_mxTrue_exp_0\"\n", " args = args_from_setting(setting, args)\n", "\n", " exp = Exp(args)\n", "\n", "path = os.path.join(args.checkpoints, setting, \"checkpoint.pth\")\n", "\n", "exp.predict(setting, True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "KBCPbjGuzAZb", "outputId": "945dc447-88e8-4b08-b7e5-f0a0b486d138" }, "outputs": [], "source": [ "# the prediction will be saved in ./results/{setting}/real_prediction.npy\n", "\n", "prediction = np.load(f\"./results/{setting}/real_prediction.npy\")\n", "\n", "prediction.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "5yFuVkTV30_j" }, "source": [ "### More details about Prediction - prediction function" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Sv9AR_Aw030r" }, "outputs": [], "source": [ "# here is the detailed code of function predict\n", "\n", "\n", "def predict(exp, setting, load=False):\n", " pred_data, pred_loader = exp._get_data(flag=\"pred\")\n", "\n", " if load:\n", " path = os.path.join(exp.args.checkpoints, setting)\n", " best_model_path = path + \"/\" + \"checkpoint.pth\"\n", " exp.model.load_state_dict(torch.load(best_model_path))\n", "\n", " exp.model.eval()\n", "\n", " preds = []\n", "\n", " for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(pred_loader):\n", " batch_x = batch_x.float().to(exp.device)\n", " batch_y = batch_y.float()\n", " batch_x_mark = batch_x_mark.float().to(exp.device)\n", " batch_y_mark = batch_y_mark.float().to(exp.device)\n", "\n", " # decoder input\n", " if exp.args.padding == 0:\n", " dec_inp = torch.zeros(\n", " [batch_y.shape[0], exp.args.pred_len, batch_y.shape[-1]]\n", " ).float()\n", " elif exp.args.padding == 1:\n", " dec_inp = torch.ones(\n", " [batch_y.shape[0], exp.args.pred_len, batch_y.shape[-1]]\n", " ).float()\n", " else:\n", " dec_inp = torch.zeros(\n", " [batch_y.shape[0], exp.args.pred_len, batch_y.shape[-1]]\n", " ).float()\n", " dec_inp = (\n", " torch.cat([batch_y[:, : exp.args.label_len, :], dec_inp], dim=1)\n", " .float()\n", " .to(exp.device)\n", " )\n", " # encoder - decoder\n", " if exp.args.use_amp:\n", " with torch.cuda.amp.autocast():\n", " if exp.args.output_attention:\n", " outputs = exp.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]\n", " else:\n", " outputs = exp.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)\n", " else:\n", " if exp.args.output_attention:\n", " outputs = exp.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]\n", " else:\n", " outputs = exp.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)\n", " f_dim = -1 if exp.args.features == \"MS\" else 0\n", " batch_y = batch_y[:, -exp.args.pred_len :, f_dim:].to(exp.device)\n", "\n", " pred = outputs.detach().cpu().numpy() # .squeeze()\n", "\n", " preds.append(pred)\n", "\n", " preds = np.array(preds)\n", " preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])\n", "\n", " # result save\n", " folder_path = \"./results/\" + setting + \"/\"\n", " if not os.path.exists(folder_path):\n", " os.makedirs(folder_path)\n", "\n", " np.save(folder_path + \"real_prediction.npy\", preds)\n", "\n", " return preds" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "tVLWZL2a1pwB", "outputId": "421e9ae1-f024-42b6-c8cb-ed1d38c864cd" }, "outputs": [], "source": [ "# you can also use this prediction function to get result\n", "prediction = predict(exp, setting, True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 269 }, "id": "NwtZmQC71uc8", "outputId": "eec9d116-f122-42d9-8e02-c893ff764db0" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.figure()\n", "plt.plot(prediction[0, :, -1])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "EnePVyrW4I14" }, "source": [ "### More details about Prediction - prediction dataset\n", "\n", "You can give a `root_path` and `data_path` of the data you want to forecast, and set `seq_len`, `label_len`, `pred_len` and other arguments as other Dataset. The difference is that you can set a more detailed freq such as `15min` or `3h` to generate the timestamp of prediction series.\n", "\n", "`Dataset_Pred` only has one sample (including `encoder_input: [1, seq_len, dim]`, `decoder_token: [1, label_len, dim]`, `encoder_input_timestamp: [1, seq_len, date_dim]`, `decoder_input_timstamp: [1, label_len+pred_len, date_dim]`). It will intercept the last sequence of the given data (seq_len data) to forecast the unseen future sequence (pred_len data)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "ZpXhNGp34Hf4" }, "outputs": [], "source": [ "from data_provider.data_loader import Dataset_Pred\n", "from torch.utils.data import DataLoader" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "j4Rpd1q74T8N" }, "outputs": [], "source": [ "Data = Dataset_Pred\n", "timeenc = 0 if args.t_embed != \"timeF\" else 1\n", "flag = \"pred\"\n", "shuffle_flag = False\n", "drop_last = False\n", "batch_size = 1\n", "\n", "freq = args.detail_freq\n", "\n", "data_set = Data(args, flag=flag, freq=freq, timeenc=timeenc)\n", "\n", "data_loader = DataLoader(\n", " data_set,\n", " batch_size=batch_size,\n", " shuffle=shuffle_flag,\n", " num_workers=args.num_workers,\n", " drop_last=drop_last,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "42C84BfY6UPV", "outputId": "f5ccc428-db92-4708-e104-f5d29aa5adf9" }, "outputs": [], "source": [ "len(data_set), len(data_loader)" ] }, { "cell_type": "markdown", "metadata": { "id": "cNhEP_7sAgqC" }, "source": [ "## Visualization" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "vMRk8VkQ2Iko", "outputId": "bbf3cd10-7294-472d-e330-21e00f20963a" }, "outputs": [], "source": [ "# When we finished exp.train(setting) and exp.test(setting), we will get a trained model and the results of test experiment\n", "# The results of test experiment will be saved in ./results/{setting}/pred.npy (prediction of test dataset) and ./results/{setting}/true.npy (groundtruth of test dataset)\n", "\n", "preds = np.load(f\"./results/{setting}/pred.npy\")\n", "trues = np.load(f\"./results/{setting}/true.npy\")\n", "\n", "# [samples, pred_len, dimensions]\n", "preds.shape, trues.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZEGhDOmxAeAb" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import seaborn as sns" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 265 }, "id": "kyPuOPGAAjl3", "outputId": "8554f6f8-c13a-43e1-b04b-5f27823445d0" }, "outputs": [], "source": [ "# draw OT prediction\n", "plt.figure()\n", "plt.plot(trues[0, :, -1], label=\"GroundTruth\")\n", "plt.plot(preds[0, :, -1], label=\"Prediction\")\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(trues.shape)\n", "print(preds.shape)\n", "MSE = np.square(np.subtract(trues, preds)).mean()\n", "RMSE = np.sqrt(MSE)\n", "\n", "print(\"against preds\", MSE, RMSE)\n", "\n", "\n", "MSE = np.square(np.subtract(preds, np.zeros(preds.shape))).mean()\n", "RMSE = np.sqrt(MSE)\n", "print(\"against 0s\", MSE, RMSE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 265 }, "id": "43MIgWfpMYIB", "outputId": "327f64b7-363c-44f9-c7c8-1f654911068c" }, "outputs": [], "source": [ "# draw HUFL prediction\n", "plt.figure()\n", "plt.plot(trues[0, :, 0], label=\"GroundTruth\")\n", "plt.plot(preds[0, :, 0], label=\"Prediction\")\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hKmqhCfmt0xd" }, "outputs": [], "source": [ "from data_provider.data_loader import Dataset_Custom\n", "from torch.utils.data import DataLoader\n", "\n", "Data = Dataset_Custom\n", "timeenc = 0 if args.t_embed != \"timeF\" else 1\n", "flag = \"test\"\n", "shuffle_flag = False\n", "drop_last = True\n", "batch_size = 1\n", "data_set = Data(args, flag=flag, freq=freq, timeenc=timeenc)\n", "\n", "data_loader = DataLoader(\n", " data_set,\n", " batch_size=batch_size,\n", " shuffle=shuffle_flag,\n", " num_workers=args.num_workers,\n", " drop_last=drop_last,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "iflTTl0quCoK", "outputId": "3708fc91-517e-4c83-e133-059381bde271" }, "outputs": [], "source": [ "import os\n", "\n", "args.output_attention = True\n", "\n", "exp = Exp(args)\n", "\n", "model = exp.model\n", "\n", "path = os.path.join(args.checkpoints, setting, \"checkpoint.pth\")\n", "\n", "print(model.load_state_dict(torch.load(path)))\n", "\n", "df = pd.read_csv(os.path.join(args.root_path, args.data_path))\n", "df[args.cols].head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Attention Visualization" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lDdzqm9HAk2C" }, "outputs": [], "source": [ "idx = 0\n", "for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(data_loader):\n", " if i != idx:\n", " continue\n", " batch_x = batch_x.float().to(exp.device)\n", " batch_y = batch_y.float()\n", "\n", " batch_x_mark = batch_x_mark.float().to(exp.device)\n", " batch_y_mark = batch_y_mark.float().to(exp.device)\n", "\n", " dec_inp = torch.zeros_like(batch_y[:, -args.pred_len :, :]).float()\n", " dec_inp = (\n", " torch.cat([batch_y[:, : args.label_len, :], dec_inp], dim=1)\n", " .float()\n", " .to(exp.device)\n", " )\n", "\n", " outputs, attn = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "hWef23vWAmUz", "outputId": "021eca83-e12f-402c-c87e-4fffa643d2f1" }, "outputs": [], "source": [ "attn[0].shape, attn[1].shape # , attn[2].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "iZDH1fZgAnrl", "outputId": "991cae95-04a2-402d-f179-777e962f46fe" }, "outputs": [], "source": [ "layers = [0, 1]\n", "distil = \"Distil\" if args.distil else \"NoDistil\"\n", "for layer in layers:\n", " print(\"\\n\\n==========================\")\n", " print(\"Showing attention layer\", layer)\n", " print(\"==========================\\n\\n\")\n", " for h in range(0, args.n_heads):\n", " plt.figure(figsize=[10, 8])\n", " plt.title(f\"Informer, {distil}, attn:{args.attn} layer:{layer} head:{h}\")\n", " A = attn[layer][0, h].detach().cpu().numpy()\n", " ax = sns.heatmap(A, vmin=0, vmax=A.max() + 0.01)\n", " plt.show()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "former", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6 (main, Oct 24 2022, 16:07:47) [GCC 11.2.0]" }, "vscode": { "interpreter": { "hash": "44e5710a47a66ec240c2a0834fd7c20e15c61536e70be6891d892a39679ad994" } } }, "nbformat": 4, "nbformat_minor": 0 }