{ "cells": [ { "cell_type": "markdown", "id": "11b542d7", "metadata": {}, "source": [ "# Loading Model " ] }, { "cell_type": "code", "execution_count": 3, "id": "e4c98d3f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2021-11-26 18:18:05.997853: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", "2021-11-26 18:18:05.997990: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" ] } ], "source": [ "import os\n", "from tensorflow.keras import backend as K\n", "from tensorflow.keras.optimizers import Adadelta\n", "from tensorflow.keras.models import Sequential\n", "from tensorflow.keras.layers import Dense, Dropout, Flatten, Reshape, Activation\n", "from tensorflow.keras.layers import Conv2D, MaxPooling2D, Conv1D, Lambda\n", "from tensorflow.keras.models import load_model\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "7e0fb6b8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2021-11-26 18:18:19.913350: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory\n", "2021-11-26 18:18:19.913471: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)\n", "2021-11-26 18:18:19.914244: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (LAPTOP-GP5BU3LN): /proc/driver/nvidia/version does not exist\n", "2021-11-26 18:18:19.915420: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in 
performance-critical operations: AVX2 FMA\n", "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" ] } ], "source": [
    "con_win_size = 9  # number of spectrogram frames in one input window\n",
    "num_strings = 6\n",
    "num_classes = 21  # classes per string (matches TabDataReprGen: highest_fret + 2)\n",
    "input_shape = (192, con_win_size, 1)  # (CQT bins, window frames, channels)\n",
    "\n",
    "\n",
    "def softmax_by_string(t):\n",
    "    \"\"\"Softmax over the class axis, applied independently to every string.\n",
    "\n",
    "    `t` is the (batch, num_strings, num_classes) output of the Reshape layer in\n",
    "    build_model. Normalising along the last axis gives the same result as\n",
    "    slicing out each string, applying softmax to it and concatenating the\n",
    "    slices, but it does not depend on the global num_strings.\n",
    "    \"\"\"\n",
    "    return K.softmax(t, axis=-1)\n",
    "\n",
    "\n",
    "def build_model(input_shape=input_shape, num_strings=num_strings, num_classes=num_classes):\n",
    "    \"\"\"Build the CNN that maps one CQT window to per-string class probabilities.\n",
    "\n",
    "    Args:\n",
    "        input_shape: shape of one input window, (CQT bins, frames, channels).\n",
    "        num_strings: number of strings predicted.\n",
    "        num_classes: number of classes per string.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled keras Sequential model whose output has shape\n",
    "        (batch, num_strings, num_classes). In this notebook it is only used\n",
    "        with pretrained weights, so no loss/optimizer is attached.\n",
    "    \"\"\"\n",
    "    model = Sequential()\n",
    "    model.add(Conv2D(32, kernel_size=(3, 3), activation=\"relu\", input_shape=input_shape))\n",
    "    model.add(Conv2D(64, (3, 3), activation=\"relu\"))\n",
    "    model.add(Conv2D(64, (3, 3), activation=\"relu\"))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "    model.add(Dropout(0.25))\n",
    "    model.add(Flatten())\n",
    "    model.add(Dense(128, activation=\"relu\"))\n",
    "    model.add(Dropout(0.5))\n",
    "    model.add(Dense(num_classes * num_strings))  # raw scores; softmax is applied per string below\n",
    "    model.add(Reshape((num_strings, num_classes)))\n",
    "    model.add(Activation(softmax_by_string))\n",
    "    return model\n",
    "\n",
    "\n",
    "model = build_model()"
   ] }, { "cell_type": "code", "execution_count": 5, "id": "2cc33c7b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " conv2d (Conv2D) (None, 190, 7, 32) 320 \n", " \n", " conv2d_1 (Conv2D) (None, 188, 5, 64) 18496 \n", " \n", " conv2d_2 (Conv2D) (None, 186, 3, 64) 36928 \n", " \n", " max_pooling2d (MaxPooling2D (None, 93, 1, 64) 0 \n", " ) \n", " \n", " dropout (Dropout) (None, 93, 1, 64) 0 \n", " \n", " flatten (Flatten) (None, 5952) 0 
\n", " \n", " dense (Dense) (None, 128) 761984 \n", " \n", " dropout_1 (Dropout) (None, 128) 0 \n", " \n", " dense_1 (Dense) (None, 126) 16254 \n", " \n", " reshape (Reshape) (None, 6, 21) 0 \n", " \n", " activation (Activation) (None, 6, 21) 0 \n", " \n", "=================================================================\n", "Total params: 833,982\n", "Trainable params: 833,982\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "model.summary()" ] }, { "cell_type": "markdown", "id": "9baba37d", "metadata": {}, "source": [ "# Load model weights " ] }, { "cell_type": "code", "execution_count": 6, "id": "8949e7cd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "full_val0_75acc_weights.h5\n" ] } ], "source": [ "os.chdir('../h5-model')\n", "!ls" ] }, { "cell_type": "code", "execution_count": 7, "id": "1a0b640c", "metadata": {}, "outputs": [], "source": [ "model.load_weights('full_val0_75acc_weights.h5')" ] }, { "cell_type": "code", "execution_count": 8, "id": "738d20ab", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "code", "execution_count": 9, "id": "2631aefb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[array([[[[-1.14527084e-01, -7.00837746e-02, 1.64513141e-02,\n", " -8.46667960e-02, 7.80746713e-02, -6.10514767e-02,\n", " 1.10172659e-01, -1.30868405e-01, 1.17944069e-01,\n", " 8.05187821e-02, 1.33489624e-01, -6.60827989e-03,\n", " -7.73175359e-02, 3.72690037e-02, 1.03387728e-01,\n", " 8.85068849e-02, -5.47240674e-02, -9.36797783e-02,\n", " 2.81246901e-02, 3.64853963e-02, 1.42038763e-01,\n", " -3.35541889e-02, -2.07621288e-02, 3.52249779e-02,\n", " -1.24825507e-01, -1.00334905e-01, 6.65799752e-02,\n", " -1.16388358e-01, -3.19974236e-02, -1.70573920e-01,\n", " -7.60619119e-02, 
1.46001697e-01]],\n", "\n", " [[-4.05685678e-02, 1.66681096e-01, 1.29631609e-01,\n", " -2.65946500e-02, 5.08399941e-02, -1.38498351e-01,\n", " -6.97663724e-02, -1.22033164e-01, 1.14255841e-03,\n", " 1.18322901e-01, -1.16930827e-01, 1.51325539e-01,\n", " -5.27841933e-02, 6.96947202e-02, -1.35874555e-01,\n", " 1.45532504e-01, 1.40910745e-01, 4.45879586e-02,\n", " 1.05302148e-01, -1.27617968e-02, 1.90328583e-02,\n", " 1.54919162e-01, -8.32226053e-02, 7.67920837e-02,\n", " -1.18180208e-01, -9.67657194e-03, 1.21145966e-02,\n", " 1.09448351e-01, 7.66155794e-02, -2.73343250e-02,\n", " 1.10605858e-01, 1.43061072e-01]],\n", "\n", " [[ 2.03716401e-02, 1.17583640e-01, 3.14034522e-02,\n", " -2.98159532e-02, 1.47711828e-01, -1.24904536e-01,\n", " -1.05287507e-01, 3.88677511e-03, 4.23943391e-03,\n", " -1.54330581e-02, -1.23788953e-01, 1.38949722e-01,\n", " -1.39491573e-01, -1.94043014e-02, 1.83309652e-02,\n", " 2.22375616e-02, -9.14291590e-02, -1.77925061e-02,\n", " -5.52734993e-02, -2.83049466e-03, -3.07116900e-02,\n", " 1.12627268e-01, 5.87593913e-02, 3.89475748e-02,\n", " 2.31335610e-02, -8.66767243e-02, 1.85129251e-02,\n", " 1.07603828e-05, 9.58463252e-02, -1.48589209e-01,\n", " -1.14813887e-01, 5.83817475e-02]]],\n", "\n", "\n", " [[[-1.54552311e-02, 8.86145458e-02, 1.37174979e-01,\n", " -3.54579799e-02, -3.65477434e-04, 1.31149637e-03,\n", " -1.39428571e-01, 3.68981659e-02, 2.25509494e-03,\n", " 1.01817720e-01, -7.05295149e-03, 4.64775227e-02,\n", " 1.07821725e-01, -1.39156682e-02, -8.30163658e-02,\n", " 2.88936030e-02, -1.18742168e-01, 1.35299832e-01,\n", " 1.28781661e-01, 1.42011806e-01, -3.56708653e-02,\n", " 8.54031891e-02, 4.83890326e-04, 3.91381048e-02,\n", " 9.17226002e-02, 1.44389287e-01, 1.64631866e-02,\n", " 2.43344475e-02, 5.15554706e-03, -1.37477908e-02,\n", " 4.20040190e-02, -9.59556773e-02]],\n", "\n", " [[ 1.20128736e-01, 9.06534269e-02, -1.06412463e-01,\n", " -9.61880833e-02, -4.28094938e-02, -8.34638178e-02,\n", " 1.01216726e-01, 1.19200364e-01, 
-6.79962635e-02,\n", " -4.13414948e-02, 1.17787890e-01, 4.02927957e-02,\n", " -1.02881365e-01, 2.34187674e-02, -2.43473239e-02,\n", " 1.37818739e-01, 9.66306180e-02, 1.06404774e-01,\n", " 2.48169787e-02, -1.90081764e-02, -7.08291531e-02,\n", " 1.27087206e-01, 2.25576689e-03, 1.20728359e-01,\n", " 3.59125547e-02, 1.33096740e-01, -2.38511506e-02,\n", " 1.27964720e-01, -4.38287668e-02, -1.68812107e-02,\n", " 5.48647605e-02, -1.01585865e-01]],\n", "\n", " [[-4.38075438e-02, -1.08206250e-01, -5.08552603e-02,\n", " -2.76047569e-02, 5.21482974e-02, -4.43790257e-02,\n", " 6.54815659e-02, 1.41215220e-01, 1.56878546e-01,\n", " 1.57123864e-01, -1.15259722e-01, 3.05716433e-02,\n", " -1.57100260e-02, 5.58211654e-02, 6.03962578e-02,\n", " 1.40839685e-02, -1.24976538e-01, -1.03438301e-02,\n", " 1.49092138e-01, -1.10869348e-01, 4.13206294e-02,\n", " -3.75214815e-02, -6.54979656e-03, 6.61133183e-03,\n", " 1.42542422e-01, 2.22881120e-02, 7.60144815e-02,\n", " 5.29681854e-02, -5.93977571e-02, -2.87397783e-02,\n", " 1.32607371e-01, -6.86246827e-02]]],\n", "\n", "\n", " [[[-2.41188420e-04, 9.70838442e-02, 1.21379890e-01,\n", " -6.55761063e-02, 4.85191382e-02, -1.25394957e-02,\n", " 1.90682467e-02, -1.14000607e-02, 1.12373963e-01,\n", " -8.51348639e-02, 1.17202073e-01, -3.49313840e-02,\n", " 1.42564625e-01, 5.70051037e-02, -7.64764613e-03,\n", " -7.55342767e-02, -1.38341233e-01, 1.18126474e-01,\n", " -8.80356058e-02, 9.80901644e-02, 7.88373724e-02,\n", " 1.37699112e-01, 1.26460463e-01, 5.68267666e-02,\n", " 2.54523661e-03, -4.80788685e-02, 1.13006271e-01,\n", " 6.46024644e-02, 8.26503187e-02, -2.12154686e-01,\n", " 8.43161196e-02, 1.27397224e-01]],\n", "\n", " [[ 1.59892917e-01, -7.33290017e-02, 1.30629882e-01,\n", " -9.62211117e-02, -7.31635466e-02, 1.43488258e-01,\n", " 2.40880195e-02, 4.98163886e-02, 9.17535089e-03,\n", " 1.26280203e-01, 8.00801590e-02, 1.25236632e-02,\n", " 4.40345593e-02, 1.45838484e-01, -8.24865997e-02,\n", " 1.72862317e-02, 1.15627974e-01, 1.23322628e-01,\n", " 
-1.19584709e-01, -4.25971635e-02, -1.03793114e-01,\n", " -2.79876031e-02, 6.98375851e-02, -1.15725398e-01,\n", " 3.00386306e-02, 1.08486868e-01, 4.40593846e-02,\n", " -1.35479784e-02, -1.18479453e-01, -5.82292788e-02,\n", " -6.47029877e-02, -7.74575174e-02]],\n", "\n", " [[ 9.57414731e-02, -1.29820511e-01, -9.98642072e-02,\n", " -1.02473915e-01, -9.89412069e-02, -5.47248535e-02,\n", " -9.99636874e-02, 1.56961426e-01, 7.68359527e-02,\n", " -6.36292174e-02, 1.54039621e-01, 9.62654725e-02,\n", " 1.07983612e-01, -8.79328921e-02, 3.56053412e-02,\n", " 8.89461860e-02, 6.79489970e-02, -9.36115608e-02,\n", " -1.03025272e-01, 1.29710495e-01, -3.59881110e-02,\n", " 1.07964292e-01, -4.12224373e-03, -9.50005278e-02,\n", " 8.87222309e-03, -3.48649500e-03, 1.18317653e-03,\n", " 2.71664131e-02, -1.17610984e-01, -8.80862847e-02,\n", " -1.28909349e-01, -6.64263666e-02]]]], dtype=float32), array([ 0.01282204, 0.00842529, 0.01232481, 0.15097225, 0.0193409 ,\n", " 0.00225031, -0.01009997, 0.01819846, 0.01107948, 0.01677031,\n", " 0.00665016, 0.00886661, 0.00297947, 0.01281719, 0.00087603,\n", " 0.01799514, 0.00463679, 0.00183174, 0.00038497, 0.01992698,\n", " 0.020066 , 0.014112 , 0.00364966, 0.0164441 , 0.00834432,\n", " 0.01856335, 0.00073165, 0.00312082, 0.02126137, 0.19401325,\n", " 0.0191715 , 0.00493639], dtype=float32)], [array([[[[-0.024253 , -0.0422027 , 0.09478629, ..., 0.04256212,\n", " -0.0148557 , 0.02615436],\n", " [-0.07237651, -0.0235385 , -0.05396594, ..., -0.02231066,\n", " -0.0129886 , 0.09288608],\n", " [-0.00604533, -0.03374109, 0.02361795, ..., 0.04941289,\n", " -0.02656637, 0.04705735],\n", " ...,\n", " [ 0.09154902, -0.03423662, -0.0536489 , ..., -0.02788846,\n", " 0.01413645, -0.00983084],\n", " [-0.00106758, 0.03129057, 0.03199857, ..., 0.04893037,\n", " 0.02027882, -0.02536237],\n", " [-0.0399229 , 0.02251323, -0.06572151, ..., 0.05626247,\n", " -0.02962151, 0.03319808]],\n", "\n", " [[-0.04706322, 0.01705749, -0.0130353 , ..., 0.04436897,\n", " 0.05803913, 
0.03344763],\n", " [ 0.05700628, 0.00100709, 0.06773836, ..., 0.01194232,\n", " -0.07893106, 0.00214234],\n", " [-0.03080998, -0.03037829, 0.04785987, ..., -0.04617473,\n", " -0.07698756, 0.08599512],\n", " ...,\n", " [ 0.01612184, 0.00439207, 0.05504813, ..., -0.04482628,\n", " -0.05439194, 0.00540957],\n", " [ 0.0030311 , -0.07972671, -0.0170755 , ..., -0.07501575,\n", " -0.03553087, 0.03290023],\n", " [-0.02717343, -0.04108334, -0.09020455, ..., 0.03700458,\n", " 0.04779815, 0.00623057]],\n", "\n", " [[ 0.04955999, 0.06535412, 0.00848441, ..., -0.01225876,\n", " 0.06316476, -0.0424533 ],\n", " [-0.06809513, -0.04177037, 0.01570524, ..., -0.0829807 ,\n", " -0.06015162, -0.02058587],\n", " [ 0.01344685, -0.05208585, 0.05211343, ..., 0.02825315,\n", " -0.05496167, -0.01574153],\n", " ...,\n", " [-0.02004736, 0.05640408, -0.06429324, ..., 0.02216084,\n", " -0.07524338, 0.05568274],\n", " [-0.0464257 , 0.02633075, -0.04085077, ..., -0.00655695,\n", " -0.05695147, 0.00737705],\n", " [ 0.01410426, -0.06232076, -0.05227138, ..., -0.02583781,\n", " 0.06497865, 0.05954316]]],\n", "\n", "\n", " [[[-0.03931937, -0.05798993, 0.04763195, ..., 0.01779707,\n", " -0.01041569, 0.07967227],\n", " [ 0.07501813, 0.05678853, 0.08064738, ..., -0.05928201,\n", " -0.04068109, -0.038105 ],\n", " [-0.00263615, 0.07535851, 0.06856363, ..., 0.07070952,\n", " -0.05825863, -0.05238011],\n", " ...,\n", " [-0.0432196 , 0.04922615, -0.01277266, ..., 0.09061682,\n", " -0.00326792, -0.05467705],\n", " [ 0.0731143 , -0.0286781 , 0.0418772 , ..., -0.03771072,\n", " -0.00857606, -0.00013882],\n", " [-0.02330487, -0.06271102, 0.05608054, ..., -0.05058254,\n", " 0.01431858, -0.05069762]],\n", "\n", " [[-0.03734201, -0.05405557, 0.0199756 , ..., 0.04369628,\n", " 0.02550276, -0.04005559],\n", " [-0.08419245, 0.07661432, 0.04595702, ..., -0.04621217,\n", " 0.05589479, 0.09150885],\n", " [ 0.01033168, 0.06317565, -0.03273775, ..., 0.04538567,\n", " -0.00335575, 0.04740722],\n", " ...,\n", " [ 0.00403877, 
0.02548648, 0.02838034, ..., 0.05241664,\n", " -0.05179545, 0.07681815],\n", " [ 0.07288977, 0.05695255, 0.08590015, ..., -0.02061978,\n", " -0.04532065, 0.01756267],\n", " [ 0.02887062, -0.01312189, 0.07261 , ..., -0.07375655,\n", " -0.00277509, 0.08386153]],\n", "\n", " [[-0.06550452, 0.00217563, -0.05772574, ..., 0.01401513,\n", " 0.06173348, -0.06505993],\n", " [ 0.02040756, 0.07027661, 0.08179134, ..., 0.00169804,\n", " 0.06531129, -0.06225653],\n", " [-0.05966496, -0.03525139, 0.04835444, ..., 0.07861889,\n", " -0.08056166, 0.0102482 ],\n", " ...,\n", " [-0.04471021, -0.06418537, 0.05379616, ..., 0.09416924,\n", " 0.05024889, -0.00740042],\n", " [ 0.07122994, -0.04815977, 0.00293711, ..., 0.0247791 ,\n", " 0.01953502, -0.04322057],\n", " [-0.02085856, 0.01199629, 0.05656463, ..., 0.0611779 ,\n", " 0.03310826, -0.0012296 ]]],\n", "\n", "\n", " [[[ 0.04200918, -0.03176309, -0.00215876, ..., -0.03419558,\n", " 0.00789071, -0.00887679],\n", " [-0.07210154, -0.07641898, 0.00762249, ..., -0.0815539 ,\n", " 0.00694853, 0.02748295],\n", " [-0.02342687, 0.00460077, 0.06684699, ..., -0.011603 ,\n", " 0.01094546, -0.02282121],\n", " ...,\n", " [-0.00595846, 0.01482918, -0.00146219, ..., -0.02290521,\n", " 0.06868473, -0.01946544],\n", " [-0.0675327 , -0.01317107, -0.0060557 , ..., -0.0756464 ,\n", " 0.03366766, 0.02651644],\n", " [-0.0044026 , 0.02264087, 0.01288516, ..., -0.02590666,\n", " 0.07361539, 0.00624638]],\n", "\n", " [[ 0.00765323, -0.08195467, 0.0399415 , ..., 0.05840761,\n", " -0.06508236, 0.04785883],\n", " [ 0.02622476, 0.06906734, 0.00525846, ..., -0.0820672 ,\n", " 0.03925994, -0.06250729],\n", " [-0.05014937, 0.07326727, 0.05716334, ..., -0.05109742,\n", " 0.04840945, -0.04639239],\n", " ...,\n", " [ 0.08495307, 0.00784765, -0.04938578, ..., -0.06867561,\n", " -0.04457028, -0.05976995],\n", " [-0.00315331, -0.07650689, -0.02901176, ..., 0.02287063,\n", " -0.05147613, -0.07230018],\n", " [ 0.04023301, 0.01832285, 0.00835371, ..., -0.08145549,\n", " 
-0.03039736, 0.08281726]],\n", "\n", " [[ 0.04526403, -0.06862365, 0.01910903, ..., 0.01613895,\n", " 0.01218551, -0.05608582],\n", " [-0.0182588 , -0.00290652, 0.08782344, ..., 0.03537245,\n", " 0.03173613, 0.08155032],\n", " [-0.006832 , -0.06779082, 0.06450103, ..., 0.02297189,\n", " -0.00151633, -0.05260633],\n", " ...,\n", " [-0.06493159, 0.05477048, -0.06155345, ..., -0.01009615,\n", " -0.00583238, -0.03391334],\n", " [-0.07407849, -0.0150975 , -0.02837432, ..., -0.00476507,\n", " -0.0294109 , 0.02647972],\n", " [ 0.06788563, -0.01472817, -0.03993933, ..., -0.05126514,\n", " 0.02446889, -0.02058403]]]], dtype=float32), array([ 0.01062784, 0.00730821, 0.00110831, 0.00529176, 0.0074725 ,\n", " 0.01280306, -0.00082352, -0.00092877, 0.00996368, 0.00504148,\n", " 0.01015213, 0.00161414, 0.00683421, 0.00506695, 0.0152585 ,\n", " 0.00935267, -0.00847481, 0.01573595, 0.00773036, 0.00435062,\n", " 0.00710985, -0.00040831, 0.01054417, 0.00169495, 0.01117416,\n", " 0.00450248, 0.00614199, 0.00296131, 0.00734073, 0.0061549 ,\n", " 0.00523263, 0.00573932, 0.00048997, 0.01583373, -0.01164677,\n", " 0.06779497, 0.01095906, 0.00366794, 0.00746105, 0.0081636 ,\n", " 0.011236 , 0.00170474, 0.00809041, -0.00094938, 0.00053891,\n", " 0.00911564, 0.00970696, 0.00470014, 0.00345552, 0.01346137,\n", " 0.00264542, 0.02158701, 0.01649638, 0.00772496, -0.00046383,\n", " 0.00100259, -0.00143617, 0.01159115, 0.0042935 , 0.01302496,\n", " 0.00762347, 0.0064116 , 0.00162997, 0.00486967], dtype=float32)], [array([[[[-0.00342654, -0.04331037, 0.06396149, ..., 0.03001836,\n", " -0.05896892, -0.05154102],\n", " [-0.06167585, -0.00195316, 0.00229552, ..., 0.05734748,\n", " 0.02626396, -0.03857701],\n", " [-0.04266005, -0.01869453, -0.0315816 , ..., 0.03122951,\n", " -0.04484127, 0.04592305],\n", " ...,\n", " [-0.02760076, -0.01504721, 0.06166425, ..., -0.02279343,\n", " 0.0510645 , 0.05356678],\n", " [ 0.00943752, 0.0712416 , 0.01424259, ..., -0.06107641,\n", " -0.06270019, 0.02656677],\n", " 
[ 0.05959832, -0.06812761, 0.00832332, ..., 0.03421777,\n", " 0.04023626, -0.02532436]],\n", "\n", " [[-0.0045432 , 0.04211973, -0.05109784, ..., 0.01799299,\n", " 0.03073769, 0.05999514],\n", " [-0.06066577, -0.06746326, 0.04165009, ..., 0.05161154,\n", " 0.05618427, 0.05961972],\n", " [ 0.01115724, -0.02185192, -0.0344601 , ..., -0.0575684 ,\n", " 0.02744669, 0.05327303],\n", " ...,\n", " [ 0.05539149, -0.03274316, -0.0128122 , ..., -0.02577713,\n", " -0.06667926, -0.00503341],\n", " [ 0.00290914, 0.01094299, 0.0690598 , ..., -0.0639387 ,\n", " -0.0304285 , -0.05609442],\n", " [ 0.01713718, -0.02896782, -0.0274498 , ..., 0.06216149,\n", " 0.05976059, -0.06856986]],\n", "\n", " [[-0.00985544, 0.07635961, 0.0664077 , ..., -0.05182622,\n", " -0.02440396, -0.05477538],\n", " [ 0.06196726, 0.05144979, -0.06026461, ..., 0.06545223,\n", " -0.00928823, -0.05784639],\n", " [ 0.03207679, 0.05510361, 0.0682982 , ..., 0.02693168,\n", " 0.02768218, 0.06213966],\n", " ...,\n", " [-0.02250849, 0.01917696, 0.05094068, ..., -0.07000176,\n", " -0.0098856 , 0.05024325],\n", " [ 0.06882056, -0.0149803 , 0.00501051, ..., -0.00118458,\n", " 0.06491743, 0.06050596],\n", " [-0.00518594, 0.05167596, -0.01857725, ..., -0.04059281,\n", " 0.03668557, -0.03235075]]],\n", "\n", "\n", " [[[-0.03364386, -0.07194262, -0.06738777, ..., -0.00071171,\n", " -0.03846123, -0.02427467],\n", " [-0.07132583, 0.05083149, 0.02320879, ..., -0.0320206 ,\n", " -0.07138474, -0.02620506],\n", " [ 0.0766857 , 0.05631109, -0.00654307, ..., -0.01280011,\n", " -0.01622096, 0.01766451],\n", " ...,\n", " [ 0.0657953 , -0.05668041, 0.05629045, ..., 0.04353549,\n", " 0.04993413, 0.05720853],\n", " [-0.04759688, -0.00882784, 0.06422342, ..., -0.0569385 ,\n", " -0.00691132, 0.0473676 ],\n", " [-0.00946682, -0.02307444, 0.05863743, ..., 0.04215981,\n", " -0.06196372, -0.07203145]],\n", "\n", " [[-0.02533285, -0.06276552, 0.0557471 , ..., -0.00596738,\n", " -0.0419081 , 0.02013043],\n", " [ 0.01459188, -0.01027497, 
-0.02529136, ..., 0.00577768,\n", " 0.04238577, 0.07653186],\n", " [-0.02096691, -0.01015441, 0.07159476, ..., 0.0100741 ,\n", " 0.05877277, 0.04954444],\n", " ...,\n", " [ 0.04219133, -0.01578625, 0.0034997 , ..., -0.05270585,\n", " 0.01687538, 0.04098083],\n", " [ 0.04595196, -0.05584689, -0.01555275, ..., 0.05472008,\n", " -0.00302994, -0.01570323],\n", " [ 0.07284594, -0.06325058, -0.03903197, ..., -0.00839465,\n", " 0.06461746, -0.02404273]],\n", "\n", " [[ 0.0096987 , -0.04807894, 0.05738699, ..., -0.00678441,\n", " -0.00188825, -0.06907655],\n", " [-0.04330473, -0.00826486, -0.02970097, ..., 0.02027568,\n", " 0.06968693, 0.01775828],\n", " [ 0.03815399, 0.02993638, -0.00304228, ..., -0.00911514,\n", " -0.04824452, 0.04842124],\n", " ...,\n", " [-0.01245415, -0.00307789, 0.01629793, ..., -0.01054224,\n", " -0.06006046, 0.02579017],\n", " [-0.0468111 , 0.03817052, -0.02060169, ..., 0.05800638,\n", " 0.01519652, -0.05498468],\n", " [-0.03234467, 0.05112289, -0.0558816 , ..., 0.0660087 ,\n", " 0.01078841, -0.04599995]]],\n", "\n", "\n", " [[[ 0.06579272, 0.06292238, -0.06205738, ..., -0.01055096,\n", " -0.02026774, 0.03126688],\n", " [-0.02270779, 0.01040752, -0.05449043, ..., 0.05368847,\n", " -0.02615539, -0.01639705],\n", " [-0.06283817, 0.04479095, -0.00947504, ..., 0.00649719,\n", " 0.04338491, 0.03142663],\n", " ...,\n", " [-0.02240729, 0.02466509, -0.0456233 , ..., -0.02917192,\n", " 0.00957553, -0.02503187],\n", " [ 0.03560261, 0.0453327 , 0.04110314, ..., -0.06202294,\n", " -0.05693126, -0.01402239],\n", " [ 0.03543817, 0.03156625, -0.0421279 , ..., -0.05898493,\n", " -0.00786905, -0.06690583]],\n", "\n", " [[ 0.07184749, 0.01087924, -0.06866235, ..., -0.03493822,\n", " 0.03458441, -0.06760138],\n", " [ 0.06523995, 0.01153559, 0.02874224, ..., -0.05268464,\n", " 0.05645235, 0.05533127],\n", " [ 0.0291126 , 0.03678867, -0.07331565, ..., -0.02871324,\n", " -0.00554177, 0.02865294],\n", " ...,\n", " [ 0.02899255, 0.02320939, -0.04319515, ..., 
0.05656745,\n", " 0.0172996 , -0.03161235],\n", " [ 0.05468991, -0.03315257, -0.03777886, ..., -0.01933525,\n", " 0.061138 , 0.03111246],\n", " [-0.03651225, -0.04889263, 0.03472312, ..., -0.00183615,\n", " -0.05811863, -0.01943744]],\n", "\n", " [[-0.02031692, -0.03774923, 0.05279135, ..., -0.05021623,\n", " -0.01255627, -0.02848115],\n", " [ 0.02470684, -0.06591625, -0.06063269, ..., -0.06750314,\n", " 0.05265184, 0.05876762],\n", " [ 0.03412772, 0.04901206, 0.02559118, ..., -0.01527247,\n", " -0.06042558, -0.06003268],\n", " ...,\n", " [ 0.0691038 , -0.0599289 , -0.0148479 , ..., 0.06739101,\n", " 0.06591565, 0.05373253],\n", " [-0.0078099 , 0.02295279, 0.01884985, ..., -0.07014979,\n", " 0.03443056, 0.02859895],\n", " [ 0.07652654, 0.05678432, 0.0142913 , ..., 0.01117795,\n", " -0.00777341, -0.04807463]]]], dtype=float32), array([ 0.00863983, 0.00979703, 0.00029056, 0.00650495, -0.00146489,\n", " 0.00782927, 0.00184375, -0.00037028, 0.00130819, 0.00799961,\n", " 0.00022063, -0.0040557 , 0.00403143, 0.01416676, -0.00095555,\n", " 0.00474288, -0.00177316, 0.01135979, 0.00562979, 0.00588493,\n", " 0.00800752, 0.00385167, -0.00144799, 0.00228242, -0.00025023,\n", " 0.00926757, 0.00601811, 0.00086004, 0.00097395, -0.0001493 ,\n", " 0.01077604, 0.00182313, 0.00465522, 0.00521685, 0.00046088,\n", " 0.00873548, 0.00155117, 0.00512709, 0.00316863, 0.00322584,\n", " 0.01374802, 0.00205585, 0.00394222, 0.00794277, 0.00025733,\n", " 0.00068965, 0.00385011, -0.00408257, 0.00477339, 0.00350911,\n", " 0.00169597, 0.0067974 , 0.00139975, 0.00931548, 0.00708965,\n", " 0.00655331, 0.00590179, 0.0085585 , 0.00126311, 0.00255866,\n", " 0.00632384, 0.00209401, 0.00167473, 0.02259504], dtype=float32)], [], [], [], [array([[-0.00319826, -0.0313651 , -0.00458261, ..., -0.00520918,\n", " 0.01125336, -0.02698919],\n", " [ 0.02926771, -0.02163595, -0.02039317, ..., -0.00414997,\n", " 0.01804318, 0.01336489],\n", " [-0.01380948, 0.00864907, 0.02589456, ..., -0.00306701,\n", " -0.03058925, 
0.00575677],\n", " ...,\n", " [ 0.00666343, -0.00806017, -0.0238623 , ..., -0.01236153,\n", " 0.01327147, -0.01528543],\n", " [ 0.01794382, 0.01397446, 0.00450658, ..., -0.00381193,\n", " -0.0185669 , 0.00763699],\n", " [ 0.01470017, -0.00038448, 0.01900323, ..., -0.02822573,\n", " 0.00447253, -0.03020758]], dtype=float32), array([ 1.21881231e-03, -2.54789903e-03, 1.19442202e-03, -1.34425715e-03,\n", " 1.28064165e-03, 2.38297111e-03, -1.50695952e-04, 3.34231346e-03,\n", " -1.16801728e-03, -3.90314963e-03, 1.14376098e-03, -2.61399924e-04,\n", " -6.27846690e-04, -2.50759674e-03, -2.83996086e-03, 1.94849214e-03,\n", " 1.53332343e-03, -3.28186783e-03, -2.84067611e-03, 1.66528090e-03,\n", " 3.39212851e-03, 2.00245599e-03, -3.89982015e-05, 1.91379630e-03,\n", " 2.89280084e-03, 1.92826532e-03, -2.13255826e-03, -6.53538853e-04,\n", " 3.05190729e-03, 1.80690593e-04, 3.58022517e-04, 9.04452987e-04,\n", " -1.38803397e-03, 8.64794827e-04, -2.70661060e-03, 8.71402968e-04,\n", " -3.01148323e-03, -3.05797643e-04, -1.53289316e-03, -3.18018813e-03,\n", " 9.16826422e-04, -1.54540510e-04, -2.67823669e-03, 5.19396162e-05,\n", " -1.56265486e-03, 9.03457345e-04, 5.33110870e-04, 8.73393321e-04,\n", " -2.19670730e-03, 1.97667652e-03, -1.45599048e-03, 2.40320037e-03,\n", " 3.57997662e-04, -1.85142527e-03, 3.45550198e-03, 3.04877199e-03,\n", " -2.18497775e-03, 1.84163474e-03, -3.10299709e-03, 1.64441078e-03,\n", " 3.76233911e-05, 2.53297994e-03, -2.79979635e-04, 1.99681628e-04,\n", " 1.31242373e-03, -4.32697171e-03, -3.85067379e-03, 6.92477857e-04,\n", " 9.81733319e-04, 1.39035669e-03, 4.38014482e-04, 7.65532313e-04,\n", " 5.72401215e-04, 1.50362845e-04, 1.50697131e-03, 1.80909561e-03,\n", " 1.70004729e-03, 2.42656888e-03, 4.28195024e-04, -8.01270944e-04,\n", " -1.67141808e-03, 1.52487063e-03, -2.88763316e-03, -1.46857360e-06,\n", " 1.59786874e-03, -3.78585886e-04, 1.79942732e-03, 6.71663147e-04,\n", " -1.18505290e-04, 1.09255116e-03, 2.97151529e-03, 1.46413071e-03,\n", " -2.76076468e-03, 
1.92334398e-03, 5.22327842e-04, 1.67542917e-03,\n", " -2.36922293e-04, -5.91301941e-04, 2.12733354e-03, -3.67739424e-03,\n", " -5.31267084e-04, -2.47187470e-03, -2.47332908e-04, -3.97768104e-03,\n", " -1.33028498e-03, 7.63126591e-04, -4.55573580e-04, 4.28782689e-04,\n", " 3.35209392e-04, 2.21631653e-03, 2.85795570e-04, 1.57292362e-03,\n", " -1.35600357e-03, 1.62575534e-03, 1.48851553e-03, -9.38432757e-04,\n", " 1.09718836e-04, 4.95090440e-04, 1.23268273e-03, -3.80894344e-04,\n", " -1.40213815e-03, 2.17806362e-03, 1.34098425e-03, -2.33028550e-03,\n", " 5.85740898e-04, 1.87760510e-03, -4.04093880e-03, -1.93038362e-03],\n", " dtype=float32)], [], [array([[-0.01520345, -0.05956176, -0.10306229, ..., -0.16983195,\n", " 0.0258285 , 0.05963764],\n", " [ 0.02494746, -0.15091033, -0.02940223, ..., 0.1418434 ,\n", " 0.09850448, 0.02380073],\n", " [ 0.08474976, -0.15346092, 0.07144733, ..., -0.17547144,\n", " -0.05838747, 0.0393731 ],\n", " ...,\n", " [ 0.09131451, -0.13077933, -0.02829319, ..., -0.00425629,\n", " -0.09398489, 0.07388706],\n", " [ 0.06928039, 0.03727286, -0.02080682, ..., 0.07141316,\n", " 0.14023592, -0.02679718],\n", " [-0.06331176, -0.12194489, -0.09191213, ..., -0.1400669 ,\n", " 0.08208108, -0.12799302]], dtype=float32), array([ 0.03490328, -0.01235474, 0.01282894, -0.00811613, 0.00024564,\n", " -0.00229365, -0.02019599, 0.000125 , -0.01477893, -0.02732999,\n", " -0.05141335, -0.05286398, -0.06364606, -0.05755578, -0.06141432,\n", " -0.05327236, -0.06219975, -0.05652038, -0.04803422, -0.05056494,\n", " -0.05464143, 0.02384927, -0.00999154, -0.00930119, 0.00256772,\n", " 0.00230656, 0.00612281, -0.00475113, 0.00501595, -0.00940385,\n", " -0.00682612, -0.01799249, -0.03766802, -0.04428675, -0.05001188,\n", " -0.062364 , -0.05280105, -0.05576333, -0.05021756, -0.04794609,\n", " -0.04000037, -0.04536385, 0.02194885, -0.0135642 , -0.02653779,\n", " 0.00452685, 0.00312414, 0.00827176, 0.00347508, 0.00281434,\n", " -0.00448135, -0.00441511, 0.00217725, 
-0.00283862, -0.02372732,\n", " -0.04351533, -0.04282343, -0.06187819, -0.05622285, -0.05070539,\n", " -0.05412994, -0.05883593, -0.05785481, 0.0209148 , -0.01931002,\n", " -0.01383029, -0.00782868, 0.01649438, -0.00548673, 0.0085587 ,\n", " 0.00884054, -0.00823902, 0.00428784, -0.00035183, -0.01080555,\n", " -0.01893978, -0.02529691, -0.03771995, -0.05600805, -0.05347722,\n", " -0.05915058, -0.05887462, -0.05110266, -0.05831904, 0.02405787,\n", " 0.00950599, -0.00500965, -0.0105366 , 0.00339871, -0.01372263,\n", " 0.00801604, -0.00359699, 0.00951477, -0.00672276, -0.02112747,\n", " -0.01138078, -0.00272069, -0.03242434, -0.01675265, -0.05859887,\n", " -0.05299615, -0.04749462, -0.04528466, -0.04247516, -0.05508986,\n", " 0.02977874, -0.00369393, -0.00321471, -0.01700962, 0.00042408,\n", " 0.00259167, -0.00090576, -0.00949067, -0.0183843 , -0.0178394 ,\n", " -0.0050512 , -0.00871942, -0.02500123, -0.01665437, -0.03682591,\n", " -0.03818711, -0.02281728, -0.04489514, -0.04617931, -0.03573886,\n", " -0.06009895], dtype=float32)], [], []]\n", "2\n" ] } ], "source": [ "weights = []\n", "for layer in model.layers:\n", " weights.append(layer.get_weights())\n", "print(weights)\n", "print(len(weights[0]))" ] }, { "cell_type": "markdown", "id": "5f558f99", "metadata": {}, "source": [ "# preprocess experimentmono.wav \n" ] }, { "cell_type": "code", "execution_count": 10, "id": "33d74cf2", "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "import jams\n", "from scipy.io import wavfile\n", "import sys\n", "import librosa\n", "from tensorflow.keras.utils import to_categorical\n", "\n", "\n", "class TabDataReprGen:\n", " def __init__(self, mode=\"c\"):\n", " # file path to the GuitarSet dataset\n", " path = \"data/GuitarSet/\"\n", " self.path_audio = path + \"audio/audio_mic/\"\n", " self.path_anno = path + \"annotation/\"\n", "\n", " # labeling parameters\n", " self.string_midi_pitches = [40, 45, 50, 55, 59, 64]\n", " self.highest_fret = 19\n", " 
self.num_classes = self.highest_fret + 2  # + 2 for the open-string and closed (not played) classes\n",
    "\n",
    "        # representation and its labels storage\n",
    "        self.output = {}\n",
    "\n",
    "        # preprocessing mode: only \"c\" (cqt) is implemented in preprocess_audio;\n",
    "        # m = melspec, cm = cqt + melspec and s = stft are listed for reference only\n",
    "        self.preproc_mode = mode  # Preprocessing mode for the wav file data\n",
    "        self.downsample = True  # Select to lower sample rate of data\n",
    "        self.normalize = True  # Select to normalize data\n",
    "        self.sr_downs = 22050  # Lowered sample rate\n",
    "\n",
    "        # CQT parameters\n",
    "        self.cqt_n_bins = 192  # Number of bins for the constant-Q transform \"c\"\n",
    "        self.cqt_bins_per_octave = 24  # Number of bins per octave\n",
    "\n",
    "        # STFT parameters (hop_length is also the hop used for the CQT)\n",
    "        self.n_fft = 2048  # Length of the FFT window\n",
    "        self.hop_length = 512  # Number of samples between successive frames\n",
    "\n",
    "        # save file path\n",
    "        self.save_path = \"data/spec_repr/\" + self.preproc_mode + \"/\"\n",
    "\n",
    "    def load_rep_and_labels_from_raw_file(self, filename):\n",
    "        \"\"\"\n",
    "        Reads a wav file and stores its spectral representation in self.output.\n",
    "\n",
    "        Here `filename` is the path of the wav file itself. The jams annotation\n",
    "        loading is not used, so no labels are constructed.\n",
    "\n",
    "        Sets self.sr_original and self.sr_curr, and stores the preprocessed\n",
    "        representation, shaped (frames, bins), in self.output[\"repr\"].\n",
    "\n",
    "        Returns the raw samples read from the wav file (NOT the number of\n",
    "        frames; use len(self.output[\"repr\"]) for that).\n",
    "        \"\"\"\n",
    "        self.sr_original, data = wavfile.read(filename)  # sample rate [int] and raw samples\n",
    "        self.sr_curr = self.sr_original\n",
    "\n",
    "        # preprocess audio, swap (bins, frames) -> (frames, bins), store in output dict\n",
    "        self.output[\"repr\"] = np.swapaxes(self.preprocess_audio(data), 0, 1)\n",
    "        return data\n",
    "\n",
    "    def correct_numbering(self, n):\n",
    "        \"\"\"\n",
    "        Shifts a fret number by +1 so that class 0 can mean \"closed\" (not played).\n",
    "\n",
    "        After the shift, 1 is the open string and highest_fret + 1 is the highest\n",
    "        fret, which gives num_classes = highest_fret + 2 classes. Values that fall\n",
    "        outside 0..highest_fret + 1 after the shift are mapped to 0.\n",
    "        \"\"\"\n",
    "        n += 1\n",
    "        # upper bound is highest_fret + 1: the old `n > self.highest_fret` test sent\n",
    "        # the highest fret to class 0 and left the last class unreachable\n",
    "        if n < 0 or n > self.highest_fret + 1:\n",
    "            n = 0\n",
    "        return n\n",
    "\n",
    "    def categorical(self, label):\n",
    "        \"\"\"\n",
    "        One-hot encodes `label` into self.num_classes classes\n",
    "        (highest_fret (19) + 2 for open/closed).\n",
    "        \"\"\"\n",
    "        return to_categorical(label, self.num_classes)\n",
    "\n",
    "    def clean_label(self, label):\n",
    "        \"\"\"\n",
    "        Corrects the numbering of every entry of `label` and one-hot encodes it.\n",
    "        Returns the categorized, cleaned label.\n",
    "        \"\"\"\n",
    "        label = [self.correct_numbering(n) for n in label]\n",
    "        return self.categorical(label)\n",
    "\n",
    "    def clean_labels(self, labels):\n",
    "        \"\"\"\n",
    "        Returns an array of all labels, each cleaned by clean_label.\n",
    "        \"\"\"\n",
    "        return np.array([self.clean_label(label) for label in labels])\n",
    "\n",
    "    def preprocess_audio(self, data):\n",
    "        \"\"\"\n",
    "        Converts raw wav samples into the spectral representation used by the model.\n",
    "        The samples are cast to float, optionally peak-normalized with librosa,\n",
    "        optionally resampled to self.sr_downs, and then transformed according to\n",
    "        self.preproc_mode (only \"c\", the magnitude of the constant-Q transform).\n",
    "        Args:\n",
    "            data ([np.array]): [data created by wavfile.read]\n",
    "        Returns:\n",
    "            [np.ndarray[shape=(n_bins, t)]]: [preprocessed data array]\n",
    "        Raises:\n",
    "            ValueError: if self.preproc_mode is not \"c\"\n",
    "        \"\"\"\n",
    "        data = data.astype(float)\n",
    "        if self.normalize:\n",
    "            data = librosa.util.normalize(data)\n",
    "        if self.downsample:\n",
    "            # sample rates passed by keyword: librosa >= 0.10 rejects them positionally\n",
    "            data = librosa.resample(data, orig_sr=self.sr_original, target_sr=self.sr_downs)\n",
    "            self.sr_curr = self.sr_downs\n",
    "        if self.preproc_mode != \"c\":\n",
    "            # fail fast instead of handing the raw samples on as if they were a spectrogram\n",
    "            raise ValueError(f\"invalid representation mode: {self.preproc_mode!r}\")\n",
    "        # magnitude of the constant-Q transform of the audio signal\n",
    "        return np.abs(\n",
    "            librosa.cqt(data,\n",
    "                        hop_length=self.hop_length,\n",
    "                        sr=self.sr_curr,  # data sample rate\n",
    "                        n_bins=self.cqt_n_bins,\n",
    "                        bins_per_octave=self.cqt_bins_per_octave))\n",
    "\n",
    "    def save_data(self, filename):\n",
    "        \"\"\"\n",
    "        Saves the generated data output dictionary into an npz file\n",
    "        \"\"\"\n",
    "        np.savez(filename, **self.output)\n",
    "\n",
    "    def get_nth_filename(self, n):\n",
    "        \"\"\"\n",
    "        Sorts the jams files in the annotation directory, picks the nth one,\n",
    "        and returns its name without the .jams extension.\n",
    "        Returns:\n",
    "            [str]: [filename]\n",
    "        \"\"\"\n",
    "        filenames = np.sort(np.array(os.listdir(self.path_anno)))\n",
    "        filenames = list(filter(lambda x: x[-5:] == \".jams\", filenames))\n",
    "        print(filenames[n])\n",
    "        return filenames[n][:-5]\n",
    "\n",
    "    def load_and_save_repr_nth_file(self, n):\n",
    "        \"\"\"\n",
    "        Preprocesses the microphone wav of the nth annotated GuitarSet track and\n",
    "        saves self.output as an npz file under self.save_path.\n",
    "        \"\"\"\n",
    "        filename = self.get_nth_filename(n)  # Gets only filename with no .jams extension\n",
    "        print(filename)\n",
    "        # load_rep_and_labels_from_raw_file expects a wav path, so build it from the track name\n",
    "        self.load_rep_and_labels_from_raw_file(self.path_audio + filename + \"_mic.wav\")\n",
    "        num_frames = len(self.output[\"repr\"])\n",
    "        print(\"done: \" + filename + \", \" + str(num_frames) + \" frames\")\n",
    "        save_path = self.save_path\n",
    "        if not os.path.exists(save_path):  # Creates saving path if it does not exist\n", "
os.makedirs(save_path)\n", " self.save_data(save_path + filename + \".npz\") # Saves generated output dictionary in an npz file" ] },
{ "cell_type": "code", "execution_count": 21, "id": "3d270bcc", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_9453/1941861401.py:59: WavFileWarning: Chunk (non-data) not understood, skipping it.\n", " self.sr_original, data = wavfile.read(file_audio) # creates sample rate [int] and data from wav file\n" ] }, { "data": { "text/plain": [ "(711, 192)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "import numpy as np\n", "wav_path = '../raw_data/experimentmono.wav'\n", "genrep = TabDataReprGen()\n", "# Returns the raw wav samples and stores np.swapaxes(preprocess_audio(data), 0, 1) in genrep.output[\"repr\"];\n", "# reuse that array instead of running preprocess_audio (resample + CQT) a second time on the same data.\n", "data = genrep.load_rep_and_labels_from_raw_file(wav_path)\n", "process = genrep.output[\"repr\"]\n", "\n", "process.shape" ] },
{ "cell_type": "code", "execution_count": 18, "id": "cdf92080", "metadata": {}, "outputs": [], "source": [ "halfwin = con_win_size // 2" ] },
{ "cell_type": "code", "execution_count": 22, "id": "ada65089", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(719, 192)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "full_x = np.pad(process, [(halfwin, halfwin), # full x is the entire song padded with halfwin*2 frames\n", " (0, 0)],mode='constant')\n", "full_x.shape" ] },
{ "cell_type": "code", "execution_count": 29, "id": "20dece60", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "719" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "full_x.shape[0]" ] },
{ "cell_type": "code", "execution_count": 42, "id": "d954ae96", "metadata": {}, "outputs": [], "source": [ "# Build one (bins, con_win_size, 1) window per row of full_x, matching the model's input_shape.\n", "# np.append returns a new array and never modifies x_new in place, so each row is assigned directly.\n", "# Only rows 0 .. process.shape[0]-1 are full windows centred on an audio frame; later rows run past\n", "# the end of full_x and are zero-padded partial windows, kept so the row count (and downstream shapes) is unchanged.\n", "x_new = np.zeros((full_x.shape[0], full_x.shape[1], con_win_size, 1))\n", "for frame_idx in range(full_x.shape[0]):\n", "    sample_x = full_x[frame_idx:frame_idx + con_win_size]  # (<= con_win_size, bins)\n", "    x_new[frame_idx, :, :sample_x.shape[0], 0] = sample_x.T" ] },
{ "cell_type": "code", "execution_count": 43, "id": "7f703443", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(719, 192, 9, 1)" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_new.shape" ] },
{ "cell_type": "code", "execution_count": 44, "id": "a602b47c", "metadata": {}, "outputs": [], "source": [ "y_pred = model.predict(x_new)" ] },
{ "cell_type": "code", "execution_count": 47, "id": "e1f0cf12", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(719, 6, 21)" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_pred.shape" ] },
{ "cell_type": "code", "execution_count": 57, "id": "6e185e9d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456789...11121314151617181920
00.9922940.0011950.0004150.0012780.0020040.0008840.0005050.0005570.0003670.000125...0.0000270.0000570.0000300.0000220.0000170.0000200.0000320.0000290.0000260.000038
10.9866000.0005930.0008140.0010530.0017660.0021290.0014800.0031720.0009260.000452...0.0002380.0001030.0000580.0000320.0000460.0000130.0000130.0000210.0000160.000009
20.9706620.0009180.0003660.0011930.0026360.0017880.0022370.0073820.0016750.007143...0.0017300.0005760.0001670.0001660.0000590.0000370.0000240.0000180.0000520.000030
30.9574650.0011250.0015820.0026890.0016050.0036330.0028920.0094900.0055130.003275...0.0029560.0013340.0006850.0006120.0000700.0000960.0000320.0000980.0000490.000054
40.9556030.0015010.0016800.0027550.0014410.0032930.0043680.0087460.0031530.006144...0.0026500.0014810.0010790.0010010.0001170.0003080.0001490.0000690.0000370.000052
50.9815000.0015510.0006320.0017270.0017020.0013640.0009660.0024850.0015950.001890...0.0016960.0006010.0006930.0004890.0000790.0000180.0000820.0000360.0000260.000034
\n", "

6 rows × 21 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6 \\\n", "0 0.992294 0.001195 0.000415 0.001278 0.002004 0.000884 0.000505 \n", "1 0.986600 0.000593 0.000814 0.001053 0.001766 0.002129 0.001480 \n", "2 0.970662 0.000918 0.000366 0.001193 0.002636 0.001788 0.002237 \n", "3 0.957465 0.001125 0.001582 0.002689 0.001605 0.003633 0.002892 \n", "4 0.955603 0.001501 0.001680 0.002755 0.001441 0.003293 0.004368 \n", "5 0.981500 0.001551 0.000632 0.001727 0.001702 0.001364 0.000966 \n", "\n", " 7 8 9 ... 11 12 13 14 \\\n", "0 0.000557 0.000367 0.000125 ... 0.000027 0.000057 0.000030 0.000022 \n", "1 0.003172 0.000926 0.000452 ... 0.000238 0.000103 0.000058 0.000032 \n", "2 0.007382 0.001675 0.007143 ... 0.001730 0.000576 0.000167 0.000166 \n", "3 0.009490 0.005513 0.003275 ... 0.002956 0.001334 0.000685 0.000612 \n", "4 0.008746 0.003153 0.006144 ... 0.002650 0.001481 0.001079 0.001001 \n", "5 0.002485 0.001595 0.001890 ... 0.001696 0.000601 0.000693 0.000489 \n", "\n", " 15 16 17 18 19 20 \n", "0 0.000017 0.000020 0.000032 0.000029 0.000026 0.000038 \n", "1 0.000046 0.000013 0.000013 0.000021 0.000016 0.000009 \n", "2 0.000059 0.000037 0.000024 0.000018 0.000052 0.000030 \n", "3 0.000070 0.000096 0.000032 0.000098 0.000049 0.000054 \n", "4 0.000117 0.000308 0.000149 0.000069 0.000037 0.000052 \n", "5 0.000079 0.000018 0.000082 0.000036 0.000026 0.000034 \n", "\n", "[6 rows x 21 columns]" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(y_pred[467])" ] }, { "cell_type": "code", "execution_count": 52, "id": "49727c4c", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n" ] }, { "cell_type": "code", "execution_count": 60, "id": "f89ec04e", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAW0AAAB/CAYAAAAkaJMGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAJrUlEQVR4nO3ca6hlZR3H8e+vcSwqqXGMMmdoMiKwF6UcpNIiKkyn0IqIkS52AYkSFIowBIneVRRRRGEl3SSl7CJhmN2IXjg2TqOpkzmKkTJpTaF2Ncd/L/aa2J6z99nbOXvtfZ78fuBw1l7rWXv9ec6zf2fd9kpVIUlqwxMWXYAkaXqGtiQ1xNCWpIYY2pLUEENbkhpiaEtSQ47o402POXpDbdu68bDX/91NT55hNZK0/v2Lv/NQ/TuT2vUS2tu2buT6a7Ye9vqvffaLZ1eMJDVgZ/1kqnaeHpGkhhjaktSQqUI7yelJbkuyL8mFfRclSRptYmgn2QB8DjgDOAE4O8kJfRcmSVppmj3tk4F9VXVnVT0EXA6c1W9ZkqRRpgnt44A/DL2+u5snSZqzmV2ITHJukl1Jdv3pwMFZva0kacg0oX0PMHzT9ZZu3qNU1SVVtVRVS8/YvGFW9UmShkwT2r8Cnp/kuUmOBHYAV/VbliRplInfiKyqh5OcB1wDbAAurapbeq9MkrTCVF9jr6qrgat7rkWSNIHfiJSkhhjaktQQQ1uSGtLLo1mL4mA90sdbS9LjmnvaktQQQ1uSGmJoS1JDDG1JaoihLUkNMbQlqSGGtiQ1xNCWpIYY2pLUEENbkhpiaEtSQwxtSWqIoS1JDTG0JakhhrYkNaSX52k/QvHPeqiPt9Y8JGtbv6rt7UvrmHvaktQQQ1uSGmJoS1JDJoZ2kq1Jfpbk1iS3JDl/HoVJklaa5kLkw8AHqmp3kqOAG5JcW1W39lybJGmZiXvaVbW/qnZ30w8Ce4Hj+i5MkrTSYzqnnWQbcCKws5dqJEmrmjq0kzwVuBK4oKoeGLH83CS7kuw6cOCRWdYoSepMFdpJNjII7Muq6juj2lTVJVW1VFVLmzd7U4ok9WGau0cCfBnYW1Wf6r8kSdI40+wSnwK8HXhVkj3dz/ae65IkjTDxlr+q+iWwxodBSJJmwZPPktQQQ1uSGmJoS1JDenmedgH/qoN9vLXmIWv9X77G+/QX+Txun8Wtdc49bUlqiKEtSQ0xtCWpIYa2JDXE0JakhhjaktQQQ1uSGmJoS1JDDG1JaoihLUkNMbQlqSGGtiQ1xNCWpIYY2pLUEENbkhrS2/O01/hEZa3FWp4nPQuLfB629H/OPW1JaoihLUkNMbQlqSFTh3aSDUl+neQHfRYkSRrvsexpnw/s7asQSdJkU4V2ki3A64Av9VuOJGk10+5pfxr4EN7JJ0kLNTG0k7weuK+qbpjQ7twku5LsOnDAbJekPkyzp30KcGaSu4DLgVcl+cbyRlV1SVUtVdXS5s3elCJJfZiYrlX14araUlXbgB3AT6vqbb1XJklawV1iSWrIY3r2SFX9HPh5L5VIkiZyT1uSGmJoS1JDDG1Jakgvz9M+grDpCU/q4601jbU+z7oOzqaOw97+GuuX/o+5py1JDTG0JakhhrYkNcTQlqSGGNqS1BBDW5IaYmhLUkMMbUlqiKEtSQ0xtCWpIYa2JDXE0JakhhjaktQQQ1uSGmJoS1JDUj08uzjJn4Dfr9LkGODPM9/wbKzn2sD61sr61sb6Dt+k2p5TVc+Y9Ca9hPbEjSa7qmpp7huewnquDaxvraxvbazv8M2qNk+PSFJDDG1JasiiQvuSBW13Guu5NrC+tbK+tbG+wzeT2hZyTluSdHg8PSJJDekttJOcnuS2JPuSXDhi+ROTXNEt35lkW1+1jNj21iQ/S3JrkluSnD+izSuT3J9kT/dz8bzq67Z/V5LfdNveNWJ5knym67+bkpw0x9peMNQve5I8kOSCZW3m2n9JLk1yX5Kbh+YdneTaJLd
3vzeNWfecrs3tSc6ZY32fSPLb7u/33SRPH7PuqmOhx/o+kuSeob/h9jHrrvpZ76m2K4bquivJnjHrzqPvRuZJb+Ovqmb+A2wA7gCOB44EbgROWNbmfcAXuukdwBV91DKmvmOBk7rpo4DfjajvlcAP5lXTiBrvAo5ZZfl24IdAgJcAOxdU5wbgjwzuMV1Y/wGvAE4Cbh6a93Hgwm76QuBjI9Y7Griz+72pm940p/pOA47opj82qr5pxkKP9X0E+OAUf/9VP+t91LZs+SeBixfYdyPzpK/x19ee9snAvqq6s6oeAi4HzlrW5izgq930t4FXJ0lP9TxKVe2vqt3d9IPAXuC4eWx7hs4CvlYD1wFPT3LsAup4NXBHVa32ZareVdUvgL8smz08xr4KvGHEqq8Frq2qv1TVX4FrgdPnUV9V/aiqHu5eXgdsmfV2pzWm/6YxzWe9t9q6zHgL8M1ZbvOxWCVPehl/fYX2ccAfhl7fzcpQ/F+bbuDeD2zuqZ6xutMyJwI7Ryx+aZIbk/wwyQvnWxkF/CjJDUnOHbF8mj6ehx2M/8Assv8AnllV+7vpPwLPHNFmvfTjuxkcOY0yaSz06bzu9M2lYw7vF91/Lwfurarbxyyfa98ty5Next/j+kJkkqcCVwIXVNUDyxbvZnDI/yLgs8D35lzeqVV1EnAG8P4kr5jz9idKciRwJvCtEYsX3X+PUoNj0XV5q1SSi4CHgcvGNFnUWPg88DzgxcB+Bqch1puzWX0ve259t1qezHL89RXa9wBbh15v6eaNbJPkCOBpwIGe6lkhyUYGHXxZVX1n+fKqeqCq/tZNXw1sTHLMvOqrqnu63/cB32VwGDpsmj7u2xnA7qq6d/mCRfdf595Dp4y63/eNaLPQfkzyTuD1wFu7D/YKU4yFXlTVvVV1sKoeAb44ZrsL678uN94EXDGuzbz6bkye9DL++grtXwHPT/Lcbm9sB3DVsjZXAYeulL4Z+Om4QTtr3XmwLwN7q+pTY9o869A59iQnM+irufxTSfKUJEcdmmZwwermZc2uAt6RgZcA9w8dis3L2L2cRfbfkOExdg7w/RFtrgFOS7KpO/w/rZvXuySnAx8Czqyqf4xpM81Y6Ku+4Wskbxyz3Wk+6315DfDbqrp71MJ59d0qedLP+Ovxiup2BldR7wAu6uZ9lMEABXgSg8PqfcD1wPF91TKitlMZHKrcBOzpfrYD7wXe27U5D7iFwdXw64CXzbG+47vt3tjVcKj/husL8Lmuf38DLM2rvm77T2EQwk8bmrew/mPwz2M/8B8G5wXfw+AayU+A24EfA0d3bZeALw2t++5uHO4D3jXH+vYxOJ95aAweupvq2cDVq42FOdX39W5s3cQggI5dXl/3esVnve/auvlfOTTehtouou/G5Ukv489vREpSQx7XFyIlqTWGtiQ1xNCWpIYY2pLUEENbkhpiaEtSQwxtSWqIoS1JDfkvDw5E7zJjf24AAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for frame in range(0,719):\n", " plt.imshow(y_pred[frame])" ] }, { "cell_type": "code", "execution_count": null, "id": "636568ec", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }