{ "cells": [ { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
CulmenLengthCulmenDepthFlipperLengthBodyMassSpecies
18748.416.322.054.001
30349.519.020.038.002
19650.515.922.255.501
1638.719.019.534.500
8541.320.319.435.500
23447.414.621.247.251
14539.018.718.536.500
2137.718.718.036.000
3039.516.717.832.500
33042.517.318.733.502
\n", "
" ], "text/plain": [ " CulmenLength CulmenDepth FlipperLength BodyMass Species\n", "187 48.4 16.3 22.0 54.00 1\n", "303 49.5 19.0 20.0 38.00 2\n", "196 50.5 15.9 22.2 55.50 1\n", "16 38.7 19.0 19.5 34.50 0\n", "85 41.3 20.3 19.4 35.50 0\n", "234 47.4 14.6 21.2 47.25 1\n", "145 39.0 18.7 18.5 36.50 0\n", "21 37.7 18.7 18.0 36.00 0\n", "30 39.5 16.7 17.8 32.50 0\n", "330 42.5 17.3 18.7 33.50 2" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "# load the training dataset (excluding rows with null values)\n", "penguins = pd.read_csv('/Users/johnnydevriese/projects/data/penguins.csv').dropna()\n", "\n", "# Deep Learning models work best when features are on similar scales\n", "# In a real solution, we'd implement some custom normalization for each feature, but to keep things simple\n", "# we'll just rescale the FlipperLength and BodyMass so they're on a similar scale to the bill measurements\n", "penguins['FlipperLength'] = penguins['FlipperLength']/10\n", "penguins['BodyMass'] = penguins['BodyMass']/100\n", "\n", "# The dataset is too small to be useful for deep learning\n", "# So we'll oversample it to increase its size\n", "for i in range(1,3):\n", " penguins = penguins.append(penguins)\n", "\n", "# Display a random sample of 10 observations\n", "sample = penguins.sample(10)\n", "sample" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['CulmenLength' 'CulmenDepth' 'FlipperLength' 'BodyMass' 'Species'] SpeciesName\n", "[ 42.7 18.3 19.6 40.75 0 ] Adelie\n", "[ 37.9 18.6 19.3 29.25 0 ] Adelie\n", "[ 39.0 17.1 19.1 30.5 0 ] Adelie\n", "[ 50.2 18.8 20.2 38.0 2 ] Chinstrap\n", "[ 45.2 14.8 21.2 52.0 1 ] Gentoo\n", "[ 45.7 13.9 21.4 44.0 1 ] Gentoo\n", "[ 38.8 20.0 19.0 39.5 0 ] Adelie\n", "[ 43.8 13.9 20.8 43.0 1 ] Gentoo\n", "[ 46.0 18.9 19.5 41.5 2 ] Chinstrap\n", "[ 49.4 15.8 21.6 49.25 1 ] Gentoo\n" ] } ], "source": [ "penguin_classes = ['Adelie', 'Gentoo', 'Chinstrap']\n", "print(sample.columns[0:5].values, 'SpeciesName')\n", "for index, row in penguins.sample(10).iterrows():\n", " print('[',row[0], row[1], row[2],row[3], int(row[4]), ']',penguin_classes[int(row[-1])])" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training Set: 957, Test Set: 411 \n", "\n", "Sample of features and labels:\n", "[51.1 16.5 22.5 52.5] 1 (Gentoo)\n", "[50.7 19.7 20.3 40.5] 2 (Chinstrap)\n", "[49.5 16.2 22.9 58. ] 1 (Gentoo)\n", "[39.3 20.6 19. 36.5] 0 (Adelie)\n", "[42.5 20.7 19.7 45. ] 0 (Adelie)\n", "[50. 15.3 22. 55.5] 1 (Gentoo)\n", "[50.2 18.7 19.8 37.75] 2 (Chinstrap)\n", "[50.7 19.7 20.3 40.5] 2 (Chinstrap)\n", "[49.1 14.5 21.2 46.25] 1 (Gentoo)\n", "[43.2 16.6 18.7 29. ] 2 (Chinstrap)\n", "[38.8 17.6 19.1 32.75] 0 (Adelie)\n", "[37.8 17.1 18.6 33. ] 0 (Adelie)\n", "[45.8 14.2 21.9 47. ] 1 (Gentoo)\n", "[43.8 13.9 20.8 43. ] 1 (Gentoo)\n", "[36. 17.1 18.7 37. ] 0 (Adelie)\n", "[43.3 13.4 20.9 44. ] 1 (Gentoo)\n", "[36. 18.5 18.6 31. ] 0 (Adelie)\n", "[41.1 19. 18.2 34.25] 0 (Adelie)\n", "[33.1 16.1 17.8 29. ] 0 (Adelie)\n", "[40.9 13.7 21.4 46.5] 1 (Gentoo)\n", "[45.2 17.8 19.8 39.5] 2 (Chinstrap)\n", "[48.4 14.6 21.3 58.5] 1 (Gentoo)\n", "[43.6 13.9 21.7 49. ] 1 (Gentoo)\n", "[38.5 17.9 19. 33.25] 0 (Adelie)\n" ] } ], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "features = ['CulmenLength','CulmenDepth','FlipperLength','BodyMass']\n", "label = 'Species'\n", " \n", "# Split data 70%-30% into training set and test set\n", "x_train, x_test, y_train, y_test = train_test_split(penguins[features].values,\n", " penguins[label].values,\n", " test_size=0.30,\n", " random_state=0)\n", "\n", "print ('Training Set: %d, Test Set: %d \\n' % (len(x_train), len(x_test)))\n", "print(\"Sample of features and labels:\")\n", "\n", "# Take a look at the first 25 training features and corresponding labels\n", "for n in range(0,24):\n", " print(x_train[n], y_train[n], '(' + penguin_classes[y_train[n]] + ')')\n", "\n" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Libraries imported - ready to use PyTorch 1.10.0\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.utils.data as torch_data\n", "\n", "# Set random seed for reproducability\n", "torch.manual_seed(0)\n", "\n", "print(\"Libraries imported - ready to use PyTorch\", torch.__version__)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Ready to load data\n" ] } ], "source": [ "# Create a dataset and loader for the training data and labels\n", "train_x = torch.Tensor(x_train).float()\n", "train_y = torch.Tensor(y_train).long()\n", "train_ds = torch_data.TensorDataset(train_x,train_y)\n", "train_loader = torch_data.DataLoader(train_ds, batch_size=20,\n", " shuffle=False, num_workers=1)\n", "\n", "# Create a dataset and loader for the test data and labels\n", "test_x = torch.Tensor(x_test).float()\n", "test_y = torch.Tensor(y_test).long()\n", "test_ds = torch_data.TensorDataset(test_x,test_y)\n", "test_loader = torch_data.DataLoader(test_ds, batch_size=20,\n", " shuffle=False, num_workers=1)\n", "print('Ready to load data')\n", "\n" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PenguinNet(\n", " (fully_connected1): Linear(in_features=4, out_features=10, bias=True)\n", " (fully_connected2): Linear(in_features=10, out_features=10, bias=True)\n", " (fully_connected3): Linear(in_features=10, out_features=3, bias=True)\n", ")\n" ] } ], "source": [ "# Number of hidden layer nodes\n", "hl = 10\n", "hidden_layer_nodes = 10\n", "initial_input_feature_dimension = len(features)\n", "output_feature_dimension = len(penguin_classes)\n", "\n", "# Define the neural network\n", "class PenguinNet(nn.Module):\n", " def __init__(self):\n", " super(PenguinNet, self).__init__()\n", " self.fully_connected1 = nn.Linear(in_features=len(features), out_features=hl, bias=True) # bias=True is default\n", " self.fully_connected2 = nn.Linear(hl, hl)\n", " self.fully_connected3 = nn.Linear(hl, len(penguin_classes))\n", "\n", " def forward(self, x):\n", " x = torch.relu(self.fully_connected1(x))\n", " x = torch.relu(self.fully_connected2(x))\n", " x = torch.relu(self.fully_connected3(x))\n", " return x\n", "\n", "# Create a model instance from the network\n", "model = PenguinNet()\n", "print(model)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0\n", "Training set: Average loss: 1.118814\n", "Validation set: Average loss: 1.023595, Accuracy: 148/411 (36%)\n", "\n", "Epoch: 1\n", "Training set: Average loss: 1.010274\n", "Validation set: Average loss: 0.983460, Accuracy: 163/411 (40%)\n", "\n", "Epoch: 2\n", "Training set: Average loss: 0.965314\n", "Validation set: Average loss: 0.934165, Accuracy: 191/411 (46%)\n", "\n", "Epoch: 3\n", "Training set: Average loss: 0.911513\n", "Validation set: Average loss: 0.867269, Accuracy: 250/411 (61%)\n", "\n", "Epoch: 4\n", "Training set: Average loss: 0.817720\n", "Validation set: Average loss: 0.742112, Accuracy: 272/411 (66%)\n", "\n", "Epoch: 5\n", "Training set: Average loss: 0.733329\n", "Validation set: Average loss: 0.691639, Accuracy: 302/411 (73%)\n", "\n", "Epoch: 6\n", "Training set: Average loss: 0.696301\n", "Validation set: Average loss: 0.661350, Accuracy: 312/411 (76%)\n", "\n", "Epoch: 7\n", "Training set: Average loss: 0.671731\n", "Validation set: Average loss: 0.640087, Accuracy: 327/411 (80%)\n", "\n", "Epoch: 8\n", "Training set: Average loss: 0.653092\n", "Validation set: Average loss: 0.624311, Accuracy: 338/411 (82%)\n", "\n", "Epoch: 9\n", "Training set: Average loss: 0.638097\n", "Validation set: Average loss: 0.610605, Accuracy: 345/411 (84%)\n", "\n", "Epoch: 10\n", "Training set: Average loss: 0.625696\n", "Validation set: Average loss: 0.598022, Accuracy: 345/411 (84%)\n", "\n", "Epoch: 11\n", "Training set: Average loss: 0.614685\n", "Validation set: Average loss: 0.588183, Accuracy: 353/411 (86%)\n", "\n", "Epoch: 12\n", "Training set: Average loss: 0.605506\n", "Validation set: Average loss: 0.578678, Accuracy: 358/411 (87%)\n", "\n", "Epoch: 13\n", "Training set: Average loss: 0.597361\n", "Validation set: Average loss: 0.569911, Accuracy: 361/411 (88%)\n", "\n", "Epoch: 14\n", "Training set: Average loss: 0.590228\n", "Validation set: Average loss: 0.562248, Accuracy: 361/411 (88%)\n", "\n", "Epoch: 15\n", "Training set: Average loss: 0.583250\n", "Validation set: Average loss: 0.556146, Accuracy: 372/411 (91%)\n", "\n", "Epoch: 16\n", "Training set: Average loss: 0.576846\n", "Validation set: Average loss: 0.549725, Accuracy: 375/411 (91%)\n", "\n", "Epoch: 17\n", "Training set: Average loss: 0.571098\n", "Validation set: Average loss: 0.544390, Accuracy: 382/411 (93%)\n", "\n", "Epoch: 18\n", "Training set: Average loss: 0.565975\n", "Validation set: Average loss: 0.540335, Accuracy: 384/411 (93%)\n", "\n", "Epoch: 19\n", "Training set: Average loss: 0.561476\n", "Validation set: Average loss: 0.536972, Accuracy: 389/411 (95%)\n", "\n", "Epoch: 20\n", "Training set: Average loss: 0.557517\n", "Validation set: Average loss: 0.532509, Accuracy: 390/411 (95%)\n", "\n", "Epoch: 21\n", "Training set: Average loss: 0.553931\n", "Validation set: Average loss: 0.529417, Accuracy: 396/411 (96%)\n", "\n", "Epoch: 22\n", "Training set: Average loss: 0.550773\n", "Validation set: Average loss: 0.528216, Accuracy: 397/411 (97%)\n", "\n", "Epoch: 23\n", "Training set: Average loss: 0.547976\n", "Validation set: Average loss: 0.523656, Accuracy: 397/411 (97%)\n", "\n", "Epoch: 24\n", "Training set: Average loss: 0.545466\n", "Validation set: Average loss: 0.521025, Accuracy: 397/411 (97%)\n", "\n", "Epoch: 25\n", "Training set: Average loss: 0.543647\n", "Validation set: Average loss: 0.519855, Accuracy: 400/411 (97%)\n", "\n", "Epoch: 26\n", "Training set: Average loss: 0.542047\n", "Validation set: Average loss: 0.517385, Accuracy: 398/411 (97%)\n", "\n", "Epoch: 27\n", "Training set: Average loss: 0.540234\n", "Validation set: Average loss: 0.515388, Accuracy: 400/411 (97%)\n", "\n", "Epoch: 28\n", "Training set: Average loss: 0.538977\n", "Validation set: Average loss: 0.512899, Accuracy: 401/411 (98%)\n", "\n", "Epoch: 29\n", "Training set: Average loss: 0.537303\n", "Validation set: Average loss: 0.512066, Accuracy: 404/411 (98%)\n", "\n", "Epoch: 30\n", "Training set: Average loss: 0.536062\n", "Validation set: Average loss: 0.511284, Accuracy: 404/411 (98%)\n", "\n", "Epoch: 31\n", "Training set: Average loss: 0.534580\n", "Validation set: Average loss: 0.508444, Accuracy: 404/411 (98%)\n", "\n", "Epoch: 32\n", "Training set: Average loss: 0.533200\n", "Validation set: Average loss: 0.507806, Accuracy: 404/411 (98%)\n", "\n", "Epoch: 33\n", "Training set: Average loss: 0.532376\n", "Validation set: Average loss: 0.505557, Accuracy: 404/411 (98%)\n", "\n", "Epoch: 34\n", "Training set: Average loss: 0.531220\n", "Validation set: Average loss: 0.503028, Accuracy: 404/411 (98%)\n", "\n", "Epoch: 35\n", "Training set: Average loss: 0.529759\n", "Validation set: Average loss: 0.502396, Accuracy: 404/411 (98%)\n", "\n", "Epoch: 36\n", "Training set: Average loss: 0.528576\n", "Validation set: Average loss: 0.501712, Accuracy: 404/411 (98%)\n", "\n", "Epoch: 37\n", "Training set: Average loss: 0.527694\n", "Validation set: Average loss: 0.499238, Accuracy: 404/411 (98%)\n", "\n", "Epoch: 38\n", "Training set: Average loss: 0.526515\n", "Validation set: Average loss: 0.498586, Accuracy: 404/411 (98%)\n", "\n", "Epoch: 39\n", "Training set: Average loss: 0.525752\n", "Validation set: Average loss: 0.496938, Accuracy: 404/411 (98%)\n", "\n", "Epoch: 40\n", "Training set: Average loss: 0.524745\n", "Validation set: Average loss: 0.496314, Accuracy: 405/411 (99%)\n", "\n", "Epoch: 41\n", "Training set: Average loss: 0.524034\n", "Validation set: Average loss: 0.494481, Accuracy: 404/411 (98%)\n", "\n", "Epoch: 42\n", "Training set: Average loss: 0.523150\n", "Validation set: Average loss: 0.492949, Accuracy: 404/411 (98%)\n", "\n", "Epoch: 43\n", "Training set: Average loss: 0.522167\n", "Validation set: Average loss: 0.492328, Accuracy: 404/411 (98%)\n", "\n", "Epoch: 44\n", "Training set: Average loss: 0.521537\n", "Validation set: Average loss: 0.490820, Accuracy: 401/411 (98%)\n", "\n", "Epoch: 45\n", "Training set: Average loss: 0.521010\n", "Validation set: Average loss: 0.489736, Accuracy: 401/411 (98%)\n", "\n", "Epoch: 46\n", "Training set: Average loss: 0.520252\n", "Validation set: Average loss: 0.489686, Accuracy: 404/411 (98%)\n", "\n", "Epoch: 47\n", "Training set: Average loss: 0.519929\n", "Validation set: Average loss: 0.488752, Accuracy: 401/411 (98%)\n", "\n", "Epoch: 48\n", "Training set: Average loss: 0.519249\n", "Validation set: Average loss: 0.488609, Accuracy: 405/411 (99%)\n", "\n", "Epoch: 49\n", "Training set: Average loss: 0.518899\n", "Validation set: Average loss: 0.487255, Accuracy: 401/411 (98%)\n", "\n" ] } ], "source": [ "# Specify the loss criteria (we'll use CrossEntropyLoss for multi-class classification)\n", "loss_criteria = nn.CrossEntropyLoss()\n", "\n", "def train(model, data_loader, optimizer):\n", " # Set the model to training mode\n", " model.train()\n", " train_loss = 0\n", " \n", " for batch, tensor in enumerate(data_loader):\n", " data, target = tensor\n", " #feedforward\n", " optimizer.zero_grad()\n", " out = model(data)\n", " loss = loss_criteria(out, target)\n", " train_loss += loss.item()\n", "\n", " # backpropagate\n", " loss.backward()\n", " optimizer.step()\n", "\n", " #Return average loss\n", " avg_loss = train_loss / (batch+1)\n", " print('Training set: Average loss: {:.6f}'.format(avg_loss))\n", " return avg_loss\n", "\n", "def test(model, data_loader):\n", " # Switch the model to evaluation mode (so we don't backpropagate)\n", " model.eval()\n", " test_loss = 0\n", " correct = 0\n", "\n", " with torch.no_grad():\n", " batch_count = 0\n", " for batch, tensor in enumerate(data_loader):\n", " batch_count += 1\n", " data, target = tensor\n", " # Get the predictions\n", " out = model(data)\n", "\n", " # calculate the loss\n", " test_loss += loss_criteria(out, target).item()\n", "\n", " # Calculate the accuracy\n", " _, predicted = torch.max(out.data, 1)\n", " correct += torch.sum(target==predicted).item()\n", " \n", " # Calculate the average loss and total accuracy for this epoch\n", " avg_loss = test_loss/batch_count\n", " print('Validation set: Average loss: {:.6f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", " avg_loss, correct, len(data_loader.dataset),\n", " 100. * correct / len(data_loader.dataset)))\n", " \n", " # return average loss for the epoch\n", " return avg_loss\n", "\n", "# Use an \"Adam\" optimizer to adjust weights\n", "# (see https://pytorch.org/docs/stable/optim.html#algorithms for details of supported algorithms)\n", "learning_rate = 0.001\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", "optimizer.zero_grad()\n", "\n", "# We'll track metrics for each epoch in these arrays\n", "epoch_nums = []\n", "training_loss = []\n", "validation_loss = []\n", "\n", "# Train over 50 epochs\n", "epochs = 50\n", "# for epoch in range(1, epochs + 1):\n", "for epoch in range(epochs):\n", "\n", " # print the epoch number\n", " print('Epoch: {}'.format(epoch))\n", " \n", " # Feed training data into the model to optimize the weights\n", " train_loss = train(model, train_loader, optimizer)\n", " \n", " # Feed the test data into the model to check its performance\n", " test_loss = test(model, test_loader)\n", " \n", " # Log the metrics for this epoch\n", " epoch_nums.append(epoch)\n", " training_loss.append(train_loss)\n", " validation_loss.append(test_loss)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAyQElEQVR4nO3deXxcdb3/8ddnlmSy72mztEn3vU3TtFQKtKUFCgJlE8qigGgF9aJeRcCrInq9l/u7XC7wkMWCICoXrOxqKVgsWynQdKX7vqRpm31fZ+b7++NM1iZpkmYySebzfDzmcWbOnHPyPWjzznc5368YY1BKKRW8bIEugFJKqcDSIFBKqSCnQaCUUkFOg0AppYKcBoFSSgU5R6AL0FOJiYkmMzMz0MVQSqlBZePGjUXGmKSOvht0QZCZmUlubm6gi6GUUoOKiBzp7DttGlJKqSCnQaCUUkFOg0AppYLcoOsjUEoNLY2NjeTl5VFXVxfoogwJLpeL9PR0nE5nt8/RIFBKBVReXh5RUVFkZmYiIoEuzqBmjKG4uJi8vDxGjRrV7fO0aUgpFVB1dXUkJCRoCPQBESEhIaHHtSsNAqVUwGkI9J3e/LcMmiDYc7KS/1y1i6p6d6CLopRSA0rQBMGxkhp+++FB9pysCHRRlFIDSFlZGU8++WSPz7vssssoKyvr8pif//znrFmzppcl6z9BEwSTUqMB2HmiMsAlUUoNJJ0Fgcfj6fK8VatWERsb2+Uxv/zlL1m8ePHZFK9fBE0QpMa4iHY52H1CawRKqRb33XcfBw4cICsri9mzZ7Nw4UJuuukmpk2bBsBVV13FrFmzmDJlCitWrGg+LzMzk6KiIg4fPsykSZP45je/yZQpU7j44oupra0F4LbbbuOVV15pPv6BBx4gOzubadOmsXv3bgAKCwu56KKLyM7O5lvf+hYZGRkUFRX1638Dvw0fFZHngMuBAmPM1A6+nwg8D2QD/2aMedhfZfH9PCalRLNLg0CpAevBv+5gZ37f/hudnBrNA1dM6fT7hx56iO3bt7Nlyxbef/99vvzlL7N9+/bm4ZfPPfcc8fHx1NbWMnv2bK699loSEhLaXGPfvn289NJLPPPMM1x//fW8+uqr3HLLLaf9rMTERDZt2sSTTz7Jww8/zLPPPsuDDz7IhRdeyP3338/q1avbhE1/8WeN4PfAki6+LwHuBvwaAK1NSolm98lKvF5dp1kp1bE5c+a0GYP/+OOPM2PGDObOncuxY8fYt2/faeeMGjWKrKwsAGbNmsXhw4c7vPY111xz2jEff/wxy5YtA2DJkiXExcX13c10k99qBMaYD0Uks4vvC4ACEfmyv8rQ3qSUKGoaPBwtqSEzMaK/fqxSqpu6+su9v0REtPxueP/991mzZg3r168nPDycBQsWdDhGPzQ0tPm93W5vbhrq7Di73Y7bbY1gNCbwf5gOij4CEVkuIrkikltYWNjr60xKsTqMtXlIKdUkKiqKysqOB5GUl5cTFxdHeHg4u3fv5tNPP+3zn3/eeeexcuVKAN59911KS0v7/GecyaAIAmPMCmNMjjEmJympw3UVumX8sChsokGglGqRkJDAvHnzmDp1Kvfcc0+b75YsWYLb7Wb69On87Gc/Y+7cuX3+8x944AHeffddsrOzefvtt0lJSSEqKqrPf05XxJ/VEl/T0N866ixudcwvgKrudhbn5OSYs1mYZvEjHzAqMYJnvpbT62sopfrOrl27mDRpUqCLETD19fXY7XYcDgfr16/nrrvuYsuWLWd1zY7+m4rIRmNMh7/4gm7SuYnDo9hyrCzQxVBKKQCOHj3K9ddfj9frJSQkhGeeeabfy+DP4aMvAQuARBHJAx4AnADGmKdFZDiQC0QDXhH5PjDZGOPXdptJKdH8bdsJKuoaiXZ1f5pWpZTyh3HjxrF58+aAlsGfo4ZuPMP3J4F0f/38zkz2dRjvPlHJnFHx/f3jlVJqwBkUncV9SUcOKaVUW0EXBMOiQ4kLd7JbJ59TSikgCINARJg4PFonn1NKKZ+gCwKwmof2nKzAo1NNKKV6KDIyEoD8/Hyuu+66Do9ZsGABZxrm/uijj1JTU9P8uTvTWvtLkAZBFHWNXg4XVwe6KEqpQSo1NbV5ZtHeaB8E3ZnW2l+CNAi0w1gpZbn33nvbrEfwi1/8ggcffJBFixY1Txn95ptvnnbe4cOHmTrVela2traWZcuWMX36dG644YY2cw3ddddd5OTkMGXKFB544AHAmsguPz+fhQsXsnDhQqBlWmuARx55hKlTpzJ16lQeffTR5p/X2XTXZyvoHigDGJscid0m7DpRweXTUwNdHKVUk7fvg5Nf9O01h0+DSx/q9Otly5bx/e9/n29/+9sArFy5ktWrV/ODH/yA6OhoioqKmDt3LldeeWWn6wE/9dRThIeHs23bNrZt20Z2dnbzd7/+9a+Jj4/H4/GwaNEitm3bxt13380jjzzC2rVrSUxMbHOtjRs38vzzz/PZZ59hjOGcc85h/vz5xMXFdXu6654KyhqBy2lnTFIEu7XDWKmgN3PmTAoKCsjPz2fr1q3ExcWRkpLCT37yE6ZPn87ixYs5fvw4p06d6vQaH374YfMv5OnTpzN9+vTm71auXEl2djYzZ85kx44d7Ny5s8vyfPzxx1x99dVEREQQGRnJNddcw0cffQR0f7rrngrKGgFYzUMbDpUEuhhKqda6+Mvdn6677jpeeeUVTp48ybJly3jxxRcpLCxk48aNOJ1OMjMzO5x+urWOaguHDh3i4YcfZsOGDcTFxXHbbbed8Tpdzf/W3emueyooawRgBUF+eR1lNQ2BLopSKsCWLVvGyy+/zCuvvMJ1111HeXk5ycnJOJ1O1q5dy5EjR7o8/4ILLuDFF18EYPv27Wzbtg2AiooKIiIiiImJ4dSpU7z99tvN53Q2/fUFF1zAG2+8QU1NDdXV1bz++uucf/75fXi3pwvqGgHArhOVfGlMwhmOVkoNZVOmTKGyspK0tDRSUlK4+eabueKKK8jJySErK4uJEyd2ef5dd93F7bffzvTp08nKymLOnDkAzJgxg5kzZzJlyhRGjx7NvHnzms9Zvnw5l156KSkpKaxdu7Z5f3Z2NrfddlvzNb7xjW8wc+bMPmsG6ohfp6H2h7OdhrpJQUUdc/7jPX5++WS+ft6oM5+glPKLYJ+G2h96Og110DYNJUWFkhARolNNKKWCXtAGgYgwKSWaXTpySCkV5II2CMB6wnjPqUrcHm+gi6JUUBtsTdQDWW/+WwZ5EETT4PZyqEinmlAqUFwuF8XFxRoGfcAYQ3FxMS6Xq0fnBe2oIYCJw62RQztPVDBuWP8uFq2UsqSnp5OXl0dhYWGgizIkuFwu0tN7tuZX8ATBkU/gvV/BTX8GlxUAY5MjcdqFXScqWZoV2OIpFaycTiejRunIvUAKnqYhZxgc/QQ+X9G8K8RhY0xSpI4cUkoFteAJgtSZMO4SWP8bqG8ZKTQ5JVpnIVVKBbXgCQKA+fdCbSl8/kzzrokpUZyqqKekWqeaUEoFp+AKgvRZMHaxr1ZQBejaBEop5bcgEJHnRKRARLZ38r2IyOMisl9EtolIdkfH9bn590FNMeT+DoBpaTGE2G28u+Nkv/x4pZQaaPxZI/g9sKSL7y8Fxvley4Gn/FiWFiNmw5gLYd3j0FBNbHgIV8xI5S8b8yivbeyXIiil1EDityAwxnwIdDXh/1LgD8byKRArIin+Kk8b8++FmiLIfR6AO84bRU2Dh5c/P9ovP14ppQaSQPYRpAHHWn3O8+07jYgsF5FcEcntk4dORs6FUfNh3WPQUMPk1Gi+NDqBFz45rNNNKKWCTiCDoKPFPzt8xtwYs8IYk2OMyUlKSuqbnz7/XqgugE0vAFatIL+8jre3a1+BUiq4BDII8oARrT6nA/n99tMz50Hm+fDxo9BYx4UTkxmVGMHvPj7Ub0VQSqmBIJBB8BbwNd/ooblAuTHmRL+WYP69UHUSNr2AzSbcPi+TLcfK2HiktF+LoZRSgeTP4aMvAeuBCSKSJyJ3iMidInKn75BVwEFgP/AM8G1/laVTmefByHPh4/+FxjquzU4n2uXgOa0VKKWCiN8mnTPG3HiG7w3wHX/9/G4RgQX3wh+WwraXiZh1GzfOGckzHx0kr7SG9LjwgBZPKaX6Q3A9WdyRUfMheTLkPgfAredmIiK88MnhwJZLKaX6iQaBCMy6HU5sheObSI0N49Kpw3n582NU1bsDXTqllPI7DQKAGTeAMxw2tjxgVlnv5i+5x85wolJKDX4aBACuGJh6DXzxKtRVMHNkHNkjY3l+3WE8Xl0+Tyk1tGkQNJn1dWishi/+AsAd543maEkNa3adCnDBlFLKvzQImqRlw/BpVvOQMVwyZRipMS5e/EznH1JKDW0aBE2aOo1PfgHHN+Gw27gmO52P9xVyqqIu0KVTSim/0SBobdpXwBkBG62hpNdkp+E18OaW4wEumFJK+Y8GQWuuaJh2HWx/DerKGZ0UycyRsby68TjW829KKTX0aBC0l3M7NNbAtpUAXJudzp5TlezI16UslVJDkwZBe6kzISXLWrTGGC6fnkKI3carm/ICXTKllPILDYKO5NwOBTsgbwOx4SEsmpTMW1vyadRFa5RSQ5AGQUemXgchUc1LWV6bnU5xdQMf7OmD1dGUUmqA0SDoSGgkTP8K7HgNakuZPyGJhIgQXtuszUNKqaFHg6Azs24Hdx1s/TNOu40rs1JZs7OAspqGQJdMKaX6lAZBZ1KmW08a73wTsJqHGjxe/ratfxdRU0opf9Mg6Mr4S+HYp1BTwpTUaCYMi9LRQ0qpIUeDoCvjl4Dxwv73EBGuyU5j89EyDhZWBbpkSinVZzQIupI6EyKSYO9qAK6amYZN4LVNOuWEUmro0CDois0G4y6B/f8Aj5th0S7OG5fE65uP49V1CpRSQ4QGwZmMvwTqyuHYZwBcm53G8bJaPj1UHOCCKaVU39AgOJMxC8HmbG4eunjycCJDHby6UZuHlFJDgwbBmYRGQeY82PsOAGEhdhZNSubj/fqUsVJqaPBrEIjIEhHZIyL7ReS+Dr6PE5HXRWSbiHwuIlP9WZ5eG78EivZAyUEAZo6I5VRFPSfLdcEapdTg57cgEBE78ARwKTAZuFFEJrc77CfAFmPMdOBrwGP+Ks9ZGX+Jtd37LgBZI+MA2HKsNFAlUkqpPuPPGsEcYL8x5qAxpgF4GVja7pjJwHsAxpjdQKaIDPNjmXonfjQkjm/uJ5iUEoXTLmw5Vh7ggiml1NnzZxCkAcdafc7z7WttK3ANgIjMATKA9PYXEpHlIpIrIrmFhQFqmx9/CRz+GOorCXXYmZwSzdZjZYEpi1JK9SF/BoF0sK/94PuHgDgR2QL8C7AZcJ92kjErjDE5xpicpKSkPi9ot4xfAt5GOLAWgBkjYvnieDkefZ5AKTXI+TMI8oARrT6nA/mtDzDGVBhjbjfGZGH1ESQBh/xYpt4bcQ64YppHD81Ij6Wq3q3TTSilBj1/BsEGYJyIjBKREGAZ8FbrA0Qk1vcdwDeAD40xA3NxYLsTxi6Gfe+A18uMEbEAbNHmIaXUIOe3IDDGuIHvAu8Au4CVxpgdInKniNzpO2wSsENEdmONLvqev8rTJ8YvgepCyN/M6MQIolwODQKl1KDn8OfFjTGrgFXt9j3d6v16YJw/y9Cnxi4GscHe1djSZzEjPZateWWBLpVSSp0VfbK4J8Ljrb4C3zDSGSNi2H2ikrpGT4ALppRSvadB0FPjL4GT26Ainxnpsbi9hh35A7NbQymlukODoKfGL7G2e98hy9dhrM8TKKUGMw2CnkqaCLEjYe87JEe7SIlxaT+BUmpQ0yDoKRHImAcntgDW8wQ6ckgpNZhpEPRG4nioPAF1FcwYEcuR4hpKqxsCXSqllOoVDYLeSJpgbYv2tfQTaPOQUmqQ0iDojcSmINjDtPQYRGCrzkSqlBqkNAh6Iy7TWr6yaC+RoQ7GJUdqjUApNWhpEPSG3QEJY6BwL9DSYWyMzkSqlBp8NAh6K3EcFPmCYEQsJdUN5JXWBrhQSinVcxoEvZU4wVrD2N3Q3GGsw0iVUoORBkFvJY4H44GSg0wYHkWow6ZPGCulBiUNgt5KGm9ti/bitNuYmhajHcZKqUFJg6C3EnyzZxftAawO4y+Ol+P2eANYKKWU6jkNgt4KjYTodCjaB1hTUtc1etlzqjLABVNKqZ7RIDgbSeOh0KoRtMxEqg+WKaUGFw2Cs5E43qoReL2MjA8nNtypHcZKqUGnW0EgIt8TkWix/E5ENonIxf4u3ICXOB4aq6EyHxHRpSuVUoNSd2sEXzfGVAAXA0nA7cBDfivVYNE0+Vyr5qG9pyqpqncHsFBKKdUz3Q0C8W0vA543xmxttS94JTYNIbU6jHMy4/Aa2HSkNICFUkqpnuluEGwUkXexguAdEYkCdJxkRBK4YpuHkM4cGYdNIPdwSWDLpZRSPdDdILgDuA+YbYypAZxYzUNdEpElIrJHRPaLyH0dfB8jIn8Vka0iskNEznjNAUXEqhX4Jp+LDHUwOTWaDYe1RqCUGjy6GwRfAvYYY8pE5Bbgp0CX4yRFxA48AVwKTAZuFJHJ7Q77DrDTGDMDWAD8j4iE9KD8gZc0vnnyOYCcjHg2HyulUR8sU0oNEt0NgqeAGhGZAfwYOAL84QznzAH2G2MOGmMagJeBpe2OMUCUiAgQCZQAg6unNXECVBdArVULmDMqnrpGLzvyKwJcMKWU6p7uBoHbWJPtLwUeM8Y8BkSd4Zw04Firz3m+fa39BpgE5ANfAN8zxpz2p7SILBeRXBHJLSws7GaR+0n7DuOMOAA2HNJ+AqXU4NDdIKgUkfuBrwJ/9zX7OM9wTkejitqv3HIJsAVIBbKA34hI9GknGbPCGJNjjMlJSkrqZpH7SdPkc74hpMnRLjISwtmgHcZKqUGiu0FwA1CP9TzBSay/7P/7DOfkASNafU7H+su/tduB14xlP3AImNjNMg0MsRlgD20eOQRWP0HukVJdsUwpNSh0Kwh8v/xfBGJE5HKgzhhzpj6CDcA4ERnl6wBeBrzV7pijwCIAERkGTAAO9qD8gWezQ8LY5qYhgNmZcZRUN3CwqDqABVNKqe7p7hQT1wOfA18Brgc+E5HrujrHGOMGvgu8A+wCVhpjdojInSJyp++wXwHnisgXwHvAvcaYot7dSgAljmtuGgKYPSoe0H4CpdTg4Ojmcf+G9QxBAYCIJAFrgFe6OskYswpY1W7f063e52NNWzG4JU2AXW9BYx04XYxOjCA+IoQNh0tZNmdkoEunlFJd6m4fga0pBHyKe3Du0Jc4HowXSg4AICLkZMSRe0RrBEqpga+7v8xXi8g7InKbiNwG/J12f+kHtcS2I4cAZmfGc6S4hoKKugAVSimluqe7ncX3ACuA6cAMYIUx5l5/FmxQSRgLSJsO45xM63mCXJ2ATik1wHW3jwBjzKvAq34sy+AVEg6xI9oMIZ2aFoPLaePzQyVcNi0lgIVTSqmudRkEIlLJ6Q+BgfWwmDHGnPbwV9BKnNBmziGn3cbMEdpPoJQa+LpsGjLGRBljojt4RWkItJM0AYr2g7dlhozZmXHszK/QhWqUUgOajvzpK4njwF0L5Uebd+VkxuM1sPmo9hMopQYuDYK+kuhbtrJVh3F2hrVQja5PoJQayDQI+koHQ0ibF6rRJ4yVUgOYBkFfiUiA8IQ2HcagC9UopQY+DYK+1G7kEFgPlulCNUqpgUyDoC81TT7Xavrp2U0Plun6BEqpAUqDoC+lzYLaEsjf1LxLF6pRSg10GgR9acrV4AyHjS+02Z2TEU/uYV2oRik1MGkQ9CVXtBUG21+F+qrm3bMz4yiubmDPqcoAFk4ppTqmQdDXsm+FhirY8VrzrkWThhHmtPPk2gMBLJhSSnVMg6CvjZgDSRPbNA8lRYVyx3mjeGtrPtuPlwewcEopdToNgr4mAtlfg+O5cGpH8+7l80cTF+7kv1bvDmDhlFLqdBoE/jB9GdhDYNMfmndFu5x8Z+FYPtpXxCf7B9+yzEqpoUuDwB8iEmDSFbD1ZWsdY59b5maQGuPiodW7dQSRUmrA0CDwl+yvQV0Z7Ppr8y6X084PLhrPtrxyVn1xMnBlU0qpVjQI/CXzAojLhE1tnym4Jjud8cMiefjdPTr/kFJqQPBrEIjIEhHZIyL7ReS+Dr6/R0S2+F7bRcQjIvH+LFO/sdmsWsHhj6C4Zdio3Sb8+JKJHCqqZmXusQAWUCmlLH4LAhGxA08AlwKTgRtFZHLrY4wx/22MyTLGZAH3Ax8YY4bOXAxZN4PY23QaAyyalExORhyPrdlHbYMnQIVTSimLP2sEc4D9xpiDxpgG4GVgaRfH3wi85Mfy9L+o4TB+CWz5P/A0Nu8WEe67dCIFlfU8t+5QAAuolFL+DYI0oHXbR55v32lEJBxYArzayffLRSRXRHILCwv7vKB+NetWqC6Avavb7M7JjGfxpGSefv8ApdUNASqcUkr5Nwikg32djZm8AljXWbOQMWaFMSbHGJOTlJTUZwXsF2MWQVTqaRPRAdxzyUSqG9z89I3tOpxUKRUw/gyCPGBEq8/pQH4nxy5jqDULNbE7YOYtsH8NlBxs89WE4VH8eMlE/v7FCZ76QOchUkoFhj+DYAMwTkRGiUgI1i/7t9ofJCIxwHzgTT+WJbByboeQSPjbD9osWgPwrQtGc+WMVP77nT2s3V0QoAIqpYKZ34LAGOMGvgu8A+wCVhpjdojInSJyZ6tDrwbeNcZU+6ssARedChc9CAffh81/bPOViPBf105ncko0d7+8mYOFVR1fQyml/EQGW9t0Tk6Oyc3NDXQxes7rhT9cCSe2wnc+s8KhlbzSGq78zTriwp288Z15RLmcASqoUmooEpGNxpicjr7TJ4v7i80GVz5uDSPtoIkoPS6cJ27K5nBxDT/48xa83sEV0EqpwUuDoD/Fj4ZFP7OGkn7xl9O+/tKYBH5++WTW7Crg0TV7A1BApVQw0iDob+fcCemz4e0fQ9XpncNf+1IG1+ek8/g/9/O3bZ0NslJKqb6jQdDfbHZY+gQ0VMOqe077WkT41VVTycmI43svb+GNzccDUEilVDDRIAiEpAmw4D7Y+QbsPH3UbKjDzu+/PodzRsXzg5Vb+OP6w/1eRKVU8NAgCJRz74bh0+HvP4Sa0x+ojgx18Nxts1k0MZmfvbmDJ9/fH4BCKqWCgQZBoNidVhNRbSn8+RZorD3tEJfTzlO3zGJpVir/b/UeHnpbVzZTSvU9DYJASpkOV/8WjnwCK29tM0NpE6fdxv9en8XN54zk6Q8O8NM3tuvQUqVUn3IEugBBb9p1UFcOf/9XeOMuuHqF9cxBKzab8O9XTSXK5eTpDw5QVtvI/7t2OhGh+j+fUurs6W+SgWD2Hdb6xu/9EkKj4cv/A9J28tamNQziwp08tHo3e05W8uTN2YwfFhWYMiulhgxtGhoozvtXqwM593fwz191eti35o/hT3ecQ1lNA0t/s45XNub1YyGVUkORBsFAIQIX/RKyb4WP/gfWPd7pofPGJrLq7vOZMSKGH/1lKz9+ZasueamU6jUNgoFEBC7/X5hyNfzjZ/DZb0+bk6hJcrSLP91xDv9y4Vj+sjGPq59cxwGduVQp1QsaBAONzW51GI+/1JqG4q93g7u+w0Mddhs/vHgCv799DgWV9Vz++Mc8+9FB3B5vPxdaKTWYaRAMRI4QWPYinP9D2PQHeP4yKO98qon545NYdff5nDsmgX//+y6u/M06th4r67/yKqUGNQ2Cgcpmh0U/h+v/CIW7YcV8OLyu08OHx7h49tYcnro5m+Lqeq56ch2/eGsHlXWnP5uglFKtaRAMdJOvhG+8B64Ya2GbLvoNRIRLp6Ww5l/n87W5Gbyw/jCLH/mAt784oU8kK6U6pUEwGCRPhG/+E8ZeZPUbvLbcegitE1EuJw8uncrr355HfEQod724iZuf/Yztxzs/RykVvHSpysHE64WPHob3H7KWurzqKRh1fpenuD1eXvzsKI+u2UtZbSNXz0zjRxdPIDU2rJ8KrZQaCLpaqlKDYDDKy7VqBSUH4EvfhQt/Bk5Xl6dU1DXy5NoDPLfuEAJ84/xR3Dl/jK6NrFSQ0CAYihqq4R8/hw3PQtJEuGYFpMw442l5pTU8/M4e3tiST0JECHfOH8ON54wkUuctUmpI0yAYyvavgTe/C9WFMP8+OO/71hTXZ7Atr4z/Wr2bdfuLiXI5+OrcDG6fN4qkqFD/l1kp1e8CFgQisgR4DLADzxpjHurgmAXAo4ATKDLGzO/qmhoEHagpgVU/gu2vwrCpcMXjkD6rW6duPVbGbz88wNvbT+K027g2O53lF4xmVGKEnwutlOpPAQkCEbEDe4GLgDxgA3CjMWZnq2NigU+AJcaYoyKSbIw5fUX3VjQIurDrb9Y6yJUn4Jw74cJ/g9DuzU56qKiaZz46yCsb82j0eFk8aRhfnZvBeWMTsdnkzBdQSg1ogQqCLwG/MMZc4vt8P4Ax5j9bHfNtINUY89PuXleD4AzqKqzprDc8C9Fp1pTWE5Z0+/SCyjpe+OQwL31+jJLqBjISwrn5nJF8ZdYI4iJC/FhwpZQ/dRUE/nyOIA041upznm9fa+OBOBF5X0Q2isjX/Fie4OCKhi8/DHe8a9UGXrrBWv2s5FC3Tk+OcnHPJRNZf/+FPLYsi+SoUP5j1W7O+c/3+Nc/byH3cIk+nKbUEOPPoSIdtSe0/w3iAGYBi4AwYL2IfGqM2dvmQiLLgeUAI0eO9ENRh6ARc+BbH8Inj8OH/w27/gozboQLfgjxo894eqjDztKsNJZmpbH7ZAUvfnqU1zcf57XNxxkZH85VWalcNTON0UmR/XAzSil/CnTT0H2AyxjzC9/n3wGrjTF/6ey62jTUCxUnYN1jsPF5a13k6TfABT+ChDE9ukxVvZvV20/yxubjrDtQhDEwIz2Gq2emcfmMVBIjdcSRUgNVoPoIHFidxYuA41idxTcZY3a0OmYS8BvgEiAE+BxYZozZ3tl1NQjOQuVJa8Gb3OfAUw/TvgLn/wiSxvf4UifL6/jr1nxe33ycnScqsNuE88YmsjQrlYunDNfnEpQaYAI5fPQyrKGhduA5Y8yvReROAGPM075j7gFuB7xYQ0wf7eqaGgR9oKrAqiHkPgeNtdbEduf/sFsPpHVkz8lK3thynLe25HO8rJZQh43Fk4exdEYq8yckEeqw9/ENKKV6Sh8oUx2rLoJPn4TPn4H6Chh3sRUII+f26nLGGDYdLeXNLfn8bdsJSqobiHY5WDx5GJdOTeH8cYm4nBoKSgWCBoHqWl25FQafPgk1xZBxHpz3Axi7yFo+sxcaPV7W7S/ira35rNl5ioo6NxEhdhZMTObSqcNZMCFZm4+U6kcaBKp7GqqtFdHWPQ6V+ZA4AebeBTOWgbP3s5U2uL2sP1jM6u0n+cfOkxRVNRDisDFvTAILJyazcEIyI+LD+/BGlFLtaRConnE3wI7XYP0TcHIbhMVDztdhzjchavhZXdrjNeQeLuHt7SdZu6eAI8U1AIxJiuBCXyjkZMYT4tClMpTqSxoEqneMgSPrYP2TsGcV2Bww5WqYdStkzOt1s1HL5Q2HiqpZu6eQ9/cU8NnBEho8XsKcdmaPiufcMQmcOyaBKakx2HWaC6XOigaBOnvFB6xlMre+ZHUsx4+GrJsh6yZrkZw+UF3vZt3+ItbtL+KTA8XsK6gCIMrlYO7oBOaOTiAnI47JqdE47VpjUKonNAhU32mogV1vwaY/wpGPQWzWEpozb4GxiyGk79r6CyrrWH+gmPUHivnkQDFHS6xmJJfTxoz0WGZlxDErI47skXE6D5JSZ6BBoPyj+ABs/hNs+T+oOgmOMBizECZcBuOXQGRSn/64/LJaNh0tZeORUjYdKWVHfgVur/X/3/S4MKalxTA1LYZpvpeGg1ItNAiUf3nccPgjqx9h9yqoyAPEmu9owmUw6YoeT2fRHbUNHr44Xs6mo6V8cbyc7cfLmzufAdJiw5icGs2klGgmDY9iYko0GfHhOq22CkoaBKr/GGONNNrzNuz+u/UerAVzJl0Bk66E5Eln3dHcmfKaRrbnlzcHw64TFRwqqsZXcSDMaWfC8CjGD4tkbHIkY5KsbXpcuHZIqyFNg0AFTtlRa8GcXW/B0U8BAwljrUAYfwmkzerW0ppno67Rw75TVew6WcGuExXsPlHJvoIqiqrqm48JcdgYnRjBmKRIMhLCyUyIsLaJESRHhSJ+Ci6l+osGgRoYKk/Bbl8oHPoIjAdComDU+TB6odW/kDDWb7WF9sprGtlfWMn+gioOFFazv6CKQ0XVHCupae57AKsWkZEQTnpcOOlxYaTHhTEivul9ONEuhwaFGvA0CNTAU1MChz6Eg2vhwFooO2Ltj06HzHmQmg2pM2H4tD4didQdbo+X/LI6DhVXc6S4msNFNRwtqSavtJZjJTVUN3jaHB/qsJEYGUpSVKhvG0JSZCjxESHER4YSHx5CfEQICZEhxIWH6MNyKiA0CNTAV3LQCoSDa+HY51B1ytovdqtPITULUrIgebL1OTw+IMU0xlBW00heaS15pTUcL6ulsLLeelVZ26KqBkqq6/F28k8ryuVoCY3mAAkhPiKUmDAn0WEOol1O33sn0S4HDn1uQp0lDQI1+FScgPxNkL/Zeh3fBLUlLd9HJFuBkDzJCof0HEiaCLaBMbupx2soq2mgpLrlVdy0rbLCwgoNK0Aq69xdXi8mzElCZAgJESEkRIQS73vfEhYtARLtchIT7tQmK9WGBoEa/IyBinwo3AUFu6FgV8v7xmrrmNBoq/N5xDkwYjak5UBYbECL3V11jR5KaxqoqHVTUddIRW0j5bXWtqy2sTlIiqvqKa7yhUtNA13987XbhLhwJ7HhIadtY8KcxIY7iQ0Laa6FhIfYcTnthIc4CHPacTltGiRDiAaBGrq8Xig9BHm5cOwzq1mpYAcYr/V9zEhIHNfqNd56RSSDbXA3t3i8hqp6N5V1jacFSHltI6U1DZTWNFJa3UBpTQNlNY3Wq7aBukZvt35GeIidyFAHUS4HUS4nUS6r1hHlchAT7iTOFy4xYdY2zldLiXJZYaJBMnB0FQQ6Ibwa3Gw262G1hDEw4wZrX30lHN8IeRugcC8U7YVNn7bUHAAQcEVDaAy4Yqz3rhiITrNqFCPPgZgR/TaCqTfsNiEmzOpLIK5n59Y1epprG2U1VoDUNnqobfBQ2+ihpsFDbYOb6gYPVXVuKusbqaxzU1nnJr+sloo6N+U1jTR4Og8Uh03aBEhkqIOIUAdhIXbCnfbm9y1B4yAq1NnmnLAQu692YtfnPPxIawQqODQ1LRXthaJ9UF1oLchTX2Ft63zbkoMtgRGVagXCiLlWH0T8aAiLG9Dh0J+MMdQ0eJprG001kIraptBo2VbUuamqd1Pb4KGmwU1Ng8f3ctPo6d7voBCHjTCnFQzhIXYiXQ4iQqxwiQy1gqWpecvltBPqsDW/Dw+xnxZAESFNW3tQdMZrjUApEYhJs15jFnZ+nMdtNS0d/QyOfWptd7ze8n1oNMRlQFwmxGZYr7A4X+0ium3tIiRySIeGiBDh+ys/vYc1ktbq3Z7m2kZTM1dlXSOV9W7qWtVSahs91Dd6qfHVVKrr3VTXuzleVtv8vqbBQ53b02XfSUdCHbbmGkt4iBUWLqeNUIe9eRvqsBHqsBHS9LLbcTqEELu1P9TZUnsJc9oJC7H5AsmOwybYbYLDLjhsNhy+96EOO067BLwJTWsESp1JeR6c2AqlR6D0cNuXp77z80KiIHbk6a/oNGuBn8hkvz9VHYyMMTR4vNQ1eqlr9Fhh0tzcZQVI02crQKyaSZUvTJpCpt7tpd5thU9d07bRQ6PHur6ns/HBPWQTmgOnqSbjsFth4bTbsNsEp90KkqVZadw4Z2Svfo7WCJQ6GzHp1qs9r7eDJqZW7yvyrSk2yo5aC/zUV7S7gEBEohUKUSnWug6xI301DV9oDIFO7f4mIr6/4O1W/4mfeLyGBreXBo+XBl9oWDUYb3MNpimIvMbQ6DF4vAa31+DxeGn0GN853uZtXaOHercXt9eL22Md6/Ya3B6vdV4fhU97GgRK9ZbNBlHDrFd31JZZT1BXnIDKE9ZDc5UnoPKktT2+CWqK2p5jD4XoFKuZyRluPWXtjLDWkG793hnu2/reR6dA/Bir9qFB4hd2m1id2QyMZ1fOhgaBUv0lLNZ6pczo/JiGaqspquyo1fRUdtQKiYYaqxO7sRaqi633DTXW58Ya8DZ2fD17KMSPskIhfhREDgOHC5wua+sItbah0VaNJDpVm6uCkF+DQESWAI8BduBZY8xD7b5fALwJHPLtes0Y80t/lkmpAS0kApImWK+e8DT6QqEWGqqsMCk5YC0eVHLIer9/Tdd9GgCIFRYxaVZtImaE1UQVl9HSbBUa2evbUwOT34JAROzAE8BFQB6wQUTeMsbsbHfoR8aYy/1VDqWCgt1pvVzRwDDruYrR89se4/VaNQl3vRUY7npw11nbulIoPw4Vx33bPOvp7f1rrBpHa2HxEDvCqkW0rlW0qWl08Dk8vqXvwxXTb/9p1Jn5s0YwB9hvjDkIICIvA0uB9kGglOoPNhuERlmv7jIGqot8nd6HfU1WR6D8mNU0VVPcNlDctS1BYzydX9cV01LDiBruC40wX6j4ts5wKzwikiA8wepYH+JDcgPFn0GQBhxr9TkPOKeD474kIluBfOBHxpgd7Q8QkeXAcoCRI3s3dEop1Qsi1trTkUmQPqtn5zY1V7nrrG1NccsoqqZX8QE48klLiJgzTH1hD/UFQkTbmkjTe1esr7zDrOG5EcnW+4hE6zvtOO+QP4Ogo9huP/ZpE5BhjKkSkcuAN4Bxp51kzApgBVjPEfRxOZVS/tDUXEW09TkuA9KyOz/eGPC6W5qtGqut8Kgutobp1hRZtZOaYqu5qnVNpLYUGuugrswajeXtYDZXsVnNWuEJvle81XnfVANpDhZfrSQsDsLjrGObzguJGJI1En8GQR4wotXndKy/+psZYypavV8lIk+KSKIxpt0YOqXUkCfSKjwAkqwnuHvK620JhKoC61VdYC2GVFtiBUlNiTWdSG2Z1YHeWNfNGkmI1azV1MQWGt3y3h7Sch/WG+u9zWkFSEikb+t7OcOtcGp++Y4Xu+/p9FgrjEKj/V6T8WcQbADGicgo4DiwDLip9QEiMhw4ZYwxIjIHsAHFfiyTUmqos9msv/bD4631KnrC425pyqotbRscNcXW57oKa2LDplfZMagvt85tavQwpuW9p8HqTznjiK1OiM03dUkszL4Dzv2X3l2nC34LAmOMW0S+C7yDNXz0OWPMDhG50/f908B1wF0i4gZqgWVmsM15oZQaOuwOsEdaQ2Qjk/r22p5G6zmRpldjtVUDMfi2vpfXbQVMbalVs6kttWoutaVWf4cf6FxDSikVBLqaa0i70JVSKshpECilVJDTIFBKqSCnQaCUUkFOg0AppYKcBoFSSgU5DQKllApyGgRKKRXkBt0DZSJSCBzp5emJQLDOYxSs9673HVz0vjuXYYzp8HHpQRcEZ0NEcjt7sm6oC9Z71/sOLnrfvaNNQ0opFeQ0CJRSKsgFWxCsCHQBAihY713vO7joffdCUPURKKWUOl2w1QiUUkq1o0GglFJBLmiCQESWiMgeEdkvIvcFujz+IiLPiUiBiGxvtS9eRP4hIvt827hAltEfRGSEiKwVkV0iskNEvufbP6TvXURcIvK5iGz13feDvv1D+r6biIhdRDaLyN98n4f8fYvIYRH5QkS2iEiub99Z3XdQBIGI2IEngEuBycCNIjI5sKXym98DS9rtuw94zxgzDnjP93mocQM/NMZMAuYC3/H9bzzU770euNAYMwPIApaIyFyG/n03+R6wq9XnYLnvhcaYrFbPDpzVfQdFEABzgP3GmIPGmAbgZWBpgMvkF8aYD4GSdruXAi/43r8AXNWfZeoPxpgTxphNvveVWL8c0hji924sVb6PTt/LMMTvG0BE0oEvA8+22j3k77sTZ3XfwRIEacCxVp/zfPuCxTBjzAmwfmECyQEuj1+JSCYwE/iMILh3X/PIFqAA+IcxJijuG3gU+DHgbbUvGO7bAO+KyEYRWe7bd1b37ejjAg5U0sE+HTc7BIlIJPAq8H1jTIVIR//TDy3GGA+QJSKxwOsiMjXARfI7EbkcKDDGbBSRBQEuTn+bZ4zJF5Fk4B8isvtsLxgsNYI8YESrz+lAfoDKEginRCQFwLctCHB5/EJEnFgh8KIx5jXf7qC4dwBjTBnwPlYf0VC/73nAlSJyGKup90IR+RND/74xxuT7tgXA61hN32d138ESBBuAcSIySkRCgGXAWwEuU396C7jV9/5W4M0AlsUvxPrT/3fALmPMI62+GtL3LiJJvpoAIhIGLAZ2M8Tv2xhzvzEm3RiTifXv+Z/GmFsY4vctIhEiEtX0HrgY2M5Z3nfQPFksIpdhtSnageeMMb8ObIn8Q0ReAhZgTUt7CngAeANYCYwEjgJfMca071Ae1ETkPOAj4Ata2ox/gtVPMGTvXUSmY3UO2rH+sFtpjPmliCQwhO+7NV/T0I+MMZcP9fsWkdFYtQCwmvb/zxjz67O976AJAqWUUh0LlqYhpZRSndAgUEqpIKdBoJRSQU6DQCmlgpwGgVJKBTkNAqX6kYgsaJopU6mBQoNAKaWCnAaBUh0QkVt88/xvEZHf+iZ2qxKR/xGRTSLynogk+Y7NEpFPRWSbiLzeNBe8iIwVkTW+tQI2icgY3+UjReQVEdktIi9KMEyIpAY0DQKl2hGRScANWJN7ZQEe4GYgAthkjMkGPsB6ahvgD8C9xpjpWE82N+1/EXjCt1bAucAJ3/6ZwPex1sYYjTVvjlIBEyyzjyrVE4uAWcAG3x/rYViTeHmBP/uO+RPwmojEALHGmA98+18A/uKbDybNGPM6gDGmDsB3vc+NMXm+z1uATOBjv9+VUp3QIFDqdAK8YIy5v81OkZ+1O66r+Vm6au6pb/Xeg/47VAGmTUNKne494DrffO9N68FmYP17uc53zE3Ax8aYcqBURM737f8q8IExpgLIE5GrfNcIFZHw/rwJpbpL/xJRqh1jzE4R+SnWKlA2oBH4DlANTBGRjUA5Vj8CWNP+Pu37RX8QuN23/6vAb0Xkl75rfKUfb0OpbtPZR5XqJhGpMsZEBrocSvU1bRpSSqkgpzUCpZQKclojUEqpIKdBoJRSQU6DQCmlgpwGgVJKBTkNAqWUCnL/H8IFD/4snB5TAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "from matplotlib import pyplot as plt\n", "\n", "plt.plot(epoch_nums, training_loss)\n", "plt.plot(epoch_nums, validation_loss)\n", "plt.xlabel('epoch')\n", "plt.ylabel('loss')\n", "plt.legend(['training', 'validation'], loc='upper right')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "fully_connected1.weight \n", " [[-0.00374341 0.2682218 -0.41152257 -0.3679695 ]\n", " [-0.17916061 -0.08960593 0.11843108 0.5180272 ]\n", " [-0.04437202 0.13230628 -0.15110654 -0.09828269]\n", " [-0.47767425 -0.33114105 -0.20611155 0.01852179]\n", " [ 0.22086579 0.5711509 -0.40086356 -0.18697421]\n", " [ 0.31580442 0.24776897 -0.20200174 0.39890492]\n", " [-0.08059168 0.05290705 0.4527381 -0.46383518]\n", " [-0.3545517 -0.15797205 -0.23337851 0.39141223]\n", " [-0.32408983 -0.23016644 -0.34932023 -0.4682805 ]\n", " [-0.47349784 0.8002842 0.30180416 0.15444154]]\n", "fully_connected1.bias \n", " [ 0.02629578 -0.20744474 0.08459234 -0.46684736 -0.35585782 -0.45410082\n", " 0.31546897 0.25728968 -0.22174752 0.24439509]\n", "fully_connected2.weight \n", " [[ 0.20224687 0.3143725 0.12550515 0.04272011 0.21202639 -0.18619564\n", " 0.05892715 -0.24517313 -0.21917307 -0.16335806]\n", " [ 0.14308453 0.08098809 -0.18731831 0.09553465 0.7475572 -0.01170831\n", " 0.01207405 0.03671877 0.19618031 0.71772873]\n", " [-0.24369258 -0.09592994 0.12428063 0.2620103 0.44033986 0.32761905\n", " 0.06293392 -0.24256472 0.02909058 -0.6438864 ]\n", " [-0.29470977 0.4369507 0.2404469 -0.31544605 -0.65187347 -0.03367811\n", " -0.05203882 -0.09720274 0.12160733 -0.44794998]\n", " [ 0.11592636 0.15991893 0.22637847 0.11824107 -0.31298175 -0.20513597\n", " 0.15789726 0.0661869 -0.24668422 -0.1820901 ]\n", " [ 0.29749104 0.33983657 -0.13788326 -0.07958971 -1.0037647 0.04011776\n", " -0.23813814 -0.21048178 -0.01742402 -0.21410409]\n", " [-0.12950484 0.18764248 -0.19243696 0.2869356 0.21671084 -0.26666948\n", " -0.07870413 0.01426902 0.04613796 0.07500109]\n", " [ 0.12409672 0.01894209 -0.15429662 0.1496355 -0.30334112 -0.1874303\n", " -0.07916126 -0.15403877 -0.11062703 -0.25918713]\n", " [-0.06726643 0.16598707 -0.20601156 -0.01622862 -0.10633215 -0.07815906\n", " 0.00878868 0.00450952 0.06399861 0.4654336 ]\n", " [ 0.29954556 0.20082232 0.3002309 -0.02287012 -0.2840742 -0.14991638\n", " 0.21532115 -0.00204995 -0.15717986 -0.24232906]]\n", "fully_connected2.bias \n", " [-0.2959424 -0.09140179 -0.24091302 0.11557585 0.17096573 -0.3224678\n", " 0.19725719 -0.24745122 0.03521875 -0.1282217 ]\n", "fully_connected3.weight \n", " [[-0.06091028 -0.06208903 -0.28376698 -0.27304304 -0.04948315 0.0040895\n", " -0.14365433 0.11912274 -0.28462344 -0.02134135]\n", " [ 0.27809682 -0.41300255 0.27310103 0.7309681 -0.2853832 0.6525562\n", " -0.03649095 -0.14116624 -0.0045454 -0.25554216]\n", " [ 0.03393281 -0.19290853 0.71934235 -0.31080088 0.15194914 -0.3314264\n", " -0.07604478 -0.06650442 -1.1165304 0.17134616]]\n", "fully_connected3.bias \n", " [ 0.25107792 0.10447468 -0.24180876]\n" ] } ], "source": [ "for param_tensor in model.state_dict():\n", " print(param_tensor, \"\\n\", model.state_dict()[param_tensor].numpy())" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWIAAAElCAYAAADeAeiuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAnqklEQVR4nO3dfbzlY73/8dd7xn0IDY7cRBoVTqEhpRtRkgrdGp1K5aQ6pXTzC906lXNUKqU6miLqaCRRKickNyV3Y9yLKGIyYUK5Z3j//riuXctu773W7L32fNda+/3s8X3sta7vd13fa63MZ13r872+1yXbREREc6Y13YCIiKkugTgiomEJxBERDUsgjohoWAJxRETDlmm6Af1Gy6xoLbdK083oWVs8fYOmm9DzHs1IpbYuu2T+IttrTqSO6as+yV58f9vjfP/tp9reeSLnmqgE4iWk5VZh+ae+vulm9Kxzzz+86Sb0vPsefKTpJvS8Gass+8eJ1uHFD7D802a3Pe6BSw6fMdFzTVQCcUQMJgFS063oSHLEETG4NK391kk10lGSbpN05bDyfSVdK+kqSZ9rKT9Q0vV130vb1Z8ecUQMru71iI8Gvgp85x9V60XAbsAzbD8oaa1avikwG9gMeCLwC0mb2B41J5UecUQMKMG06e23Dtg+B7hjWPG7gENsP1iPua2W7wYcZ/tB2zcA1wPbjFV/AnFEDCbRaWpihqR5Lds+HZ5hE+D5ki6QdLakrWv5usDNLcctqGWjSmoiIgaUOk1NLLI9axwnWAZYHdgW2Bo4XtKTy4n/yZhjFhOII2JwdXgxbpwWACe6TGF5oaRHgRm1fP2W49YDbhmroqQmImJwSe238fsRsEM5jTYBlgMWAScDsyUtL2kjYCZw4VgVpUccEYNJ6vhiXPuqNBfYnpJPXgB8EjgKOKoOaXsI2Kv2jq+SdDxwNbAYePdYIyYggTgiBlmXUhO29xxl1xtHOf5g4OBO608gjogBpcnOEXdNAnFEDK5p/XGLcwJxRAymoXHEfSCBOCIGV59M+pNAHBEDqnujJiZbAnFEDK6kJiIiGjTxGzaWmgTiiBhc6RFHRDQsPeKIiCblYl1ERLMyjjgiomm5xTkionnJEUdENCw94oiIhqVHHBHRoC5ODD/ZEogjYmCpT3rE/ZFAiYhYQqIE4nZbR3VJR0m6rS6LNHzfhyRZ0oyWsgMlXS/pWkkvbVd/AnFEDCZ1uHXmaGDnfzqFtD7wEuCmlrJNgdnAZvU1X5c0Zo4kgTgiBlT73nCnPWLb5wB3jLDrS8CHAbeU7QYcZ/tB2zcA1wPbjFV/csQRMbA6DLQzJM1reT7H9pwO6t4V+JPty4adZ13g/JbnC2rZqHq+RyzpVTX/8rRR9p8laVabOv5+jKRTJK02CU2NiB4zbdq0thuwyPaslq2TILwS8FHgEyPtHqHMI5T9o52dvJmG7Qn8mpJzmTDbu9i+qxt1RUQP626OeLiNgY2AyyTdCKwHzJf0L5Qe8Potx64H3DJWZT0diCWtDGwH7E0NxJJWlHScpMslfR9YseX4nSSdJ2m+pB/U1w+v88ahq5uS3ijpQkmXSvpGu4R6RPQPdTFHPJztK2yvZXtD2xtSgu9Wtv8MnAzMlrS8pI2AmcCFY9XX04EY2B34ue3fAXdI2gp4F3Cf7WcABwPPAqjB9WPAi21vBcwDPjBaxZKeDuwBbGd7C+AR4N9GOXYfSfMkzfPi+7v13iJiknVx+Npc4DzgqZIWSNp7tGNtXwUcD1wN/Bx4t+1Hxqq/1y/W7QkcVh8fV5/PBL4CYPtySZfX/dsCmwLn1g93OcoHN5odKUH8onr8isBtIx1Yc0ZzAKattNaYuZ6I6B3duqHD9p5t9m847PnBlI5iR3o2EEt6ArADsLkkA9MpCe9LGDnxLeD0dh/YsOOPsX1gN9obET1GoGm5s26iXgt8x/aTah5mfeAGYD41hSBpc+AZ9fjzge0kPaXuW0nSJmPUfwbwWklr1ePXkPSkSXovEdGAycoRd1vP9ogpaYhDhpX9ENgSWLGmJC6lJsFt3y7pLcBcScvX4z8G/G6kym1fLeljwGmSpgEPA+8G/tjl9xERDRi6WNcPejYQ295+hLKvtHnNL4Gtx6qrNZdj+/vA9yfQzIjoYQnEERFN6484nEAcEQNK6RFHRDSu3sLc8xKII2Ig5WJdREQv6I84nEAcEQMqOeKIiOYlEEdENCyBOCKiYf0y10QCcUQMpF6aS6KdBOKIGFgJxBERDUsgjohoWn/E4QTiiBhQ6p9bnPujlRERS0iA1H7rqC7pKEm3Sbqypezzkq6pCxmfJGm1ln0HSrpe0rWSXtqu/gTiiBhQXV3F+Whg52FlpwOb14WMfwccCCBpU8qq85vV13y93QrxCcQRMbC61SO2fQ5wx7Cy02wvrk/PB9arj3cDjrP9oO0bgOuBbcaqP4E4IgZWhz3iGZLmtWz7jONUbwP+rz5eF7i5Zd+CWjaqXKyLiMHUeY93ke1Z4z6N9FFgMXDsP878T0Zaef7vEogjYiAJmD59csevSdoLeAWwo+2hYLsAWL/lsPWAW8aqJ6mJiBhYXbxYN1LdOwP7A7vavq9l18nAbEnLS9oImEldbX406RFHxGBagotxbauS5gLbU/LJC4BPUkZJLA+cXgP6+bbfafsqSccDV1NSFu+2/chY9ScQR8RAKuOIuxOJbe85QvGRYxx/MHBwp/UnEEfEgMrsaxERjeuTOJxAHBEDSjAtE8NHRDSnmzniyZZAHBEDq0/icAJxRAyu9IgjIhrWJ3E4gTgiBpNysW5wbfn0DTj3gq823YyetforD2u6CT3vzp/s13QTpoj+GUfcdq4JSRtLWr4+3l7Se1tnoo+I6FXdmo94snUy6c8PgUckPYVyS99GwPcmtVUREV0wmZP+dFMngfjROgv9q4DDbL8fWGdymxURMUEd9IZ7JA53lCN+WNKewF7AK2vZspPXpIiIieunGzo66RG/FXgOcLDtG+r8mv87uc2KiJi4adPUdusFbXvEtq+WtD+wQX1+A3DIZDcsImKiBqZHLOmVwKXAz+vzLSSdPMntioiYmD7KEXeSmjiIshT0XQC2L6WMnIiI6Fmi/YiJXukxdxKIF9v+67CyMVckjYjoBd3qEUs6StJtkq5sKVtD0umSrqt/V2/Zd6Ck6yVdK+ml7ervJBBfKekNwHRJMyUdDvyms+ZHRDRn+jS13Tp0NLDzsLIDgDNszwTOqM+RtCkwG9isvubrkqaPVXkngXjfWuGDwFzgb8B+nbY+IqIJpcfbndSE7XOAO4YV7wYcUx8fA+zeUn6c7Qfr4IbrKendUXUyauI+4KN1i4joGx12eGdImtfyfI7tOR28bm3bCwFsL5S0Vi1fFzi/5bgFtWxUowZiSYfZ3k/STxghJ2x71w4aGhHRmA57vItsz+rmaUcoG/O62lg94u/Wv4eOuzkREQ2a5EERt0pap/aG1wFuq+ULgPVbjlsPuGWsikYNxLYvrg/nAffbfhSgJp2XH2/LIyKWBlGGsE2ikylTPxxS//64pfx7kr4IPBGYCVw4VkWdXKw7A1ip5fmKwC+WsMEREUuX2o+Y6HTUhKS5wHnAUyUtkLQ3JQC/RNJ1wEvqc2xfBRwPXE25Ee7dth8Zq/5OJv1ZwfY9Q09s3yNppbFeEBHRC7qVmrC95yi7dhzl+IOBgzutv5Me8b2Sthp6IulZwP2dniAiogkCpkltt17QSY94P+AHkoaSzesAe0xaiyIiuqRH4mxbnYwjvkjS04CnUr5krrH98KS3LCJignplLol2Opl9bSVgf+B9tq8ANpT0iklvWUTEBHQyz0SvxOlOcsTfBh6iTA4PZYzcZyatRRERXTJdarv1gk4C8ca2Pwc8DGD7fka+cyQioqf0yzSYnVyse0jSitRb9CRtTJkAKCKiZ5VRE023ojOdBOJPUgYlry/pWGA74C2T2aiIiAnroR5vO52Mmjhd0nxgW8qXzPtsL5r0lkVETFCfxOGOesQALwSeR0lPLAucNGktiojoAsGSTPzeqLaBWNLXgadQJoUHeIekF9t+96S2LCJiggYmNUHpDW9ue+hi3THAFZPaqoiILuiPMNxZIL4W2AD4Y32+PnD5pLUoIqILJHpmLol2OgnETwB+K2loPs2tgfMknQxZqSMielefxOGOAvEnJr0VERGTYGByxLbPBpD0BOAFwE0tq3d0laS1gS9RhsrdSbm1+nO2l3iUhqT9KIsA3tfVRkZEXxCdT/zetFFvcZb0U0mb18frAFcCbwO+W4NcV6l8df0IOMf2k20/C5hNWe9pPPbjsSuLRMRUMiCT/mxk+8r6+K3A6bZfCTybEpC7bQfgIdtHDBXY/qPtwyVNl/R5SRdJulzSOwAkbS/pLEknSLpG0rEq3ktZK+pMSWfWY/eUdIWkKyV9dugco5VHRP/rl7kmxgrErXMO7wicAmD7buDRSWjLZsD8UfbtDfzV9taUi4Vvl7RR3bclpfe7KfBkYDvbX6Gsmvoi2y+S9ETgs5RgvwWwtaTdRysffnJJ+0iaJ2ne7Ytu78Z7jYilYFoHWyckvV/SVbXDNlfSCpLWkHS6pOvq39Un0s7R3CxpX0mvAraizDdBnQBo2fGesFOSvibpMkkXATsBb5Z0KXABZSTHzHrohbYX1FWmLwU2HKG6rYGzbN9uezFwLCXfPVr5Y9ieY3uW7Vlrzlizq+8zIiaH6E6PWNK6wHuBWbY3B6ZT0qYHAGfYnklZZPmA8bZ1rEC8N6WX+hZgD9t31fJtKXMUd9tVlIAPQL1zb0dgTcpnuq/tLeq2ke3T6qGtM8E9wsgXIEf7tHvjd0lETIppar91aBlgRUnLUK493QLsBhxT9x8D7D7udo62w/Zttt9pe7eWoIftM20fOt4TjuGXwAqS3tVSNnSx7VTgXZKWBZC0iaTHtanvbmCV+vgC4IWSZkiaDuwJnD1GeUT0OanMNdFuA2YMpR7rtk9rPbb/BBwK3AQspKRJTwPWtr2wHrMQWGu8be100p9JZ9s1P/slSR8GbgfupSzT9ANKymF+HV1xO+2/feYA/ydpYc0THwicSekFn2L7xwCjlUdE/+uwx7vI9qzRdtbc727ARsBdlMWU39iN9g3pmUAMf/9WmT3K7o/UrdVZdRt6/XtaHh8OHN7y/HvA90Y454jlEdH/ujQo4sXADbZvL3XqROC5wK2S1rG9sA7xvW28J+j0omFERF8pK3So7daBm4BtJa1Uf5HvCPwWOBnYqx6zFzDuX9Oj9oglHU5dHmkktt873pNGRCwN3ehp2r5A0gmU4bWLgUsoqc+VgeMl7U0J1q8b7znGSk3MG2+lERFNk7p3i7PtT1KWjWv1IKV3PGGjBmLbx4y2LyKiH/TIjXNtdbJCx5qUkQubAisMldveYRLbFRExYX0y509HKZRjKYnpjYD/BG4ELprENkVETFgXL9ZNuk4C8RNsHwk8bPts22+j3F0XEdHT+mX2tU7GEQ9N/rNQ0sspt/aNd2rKiIilY8luYW5UJ4H4M5IeD3yQcoPEqsD7J7VVERETJGB6r3R52+hkhY6f1od/BV40uc2JiOiegekRS/o2I9zYUXPFERE9q1cmfm+nk9TET1serwC8ipInjojoWWXURNOt6EwnqYkftj6XNBf4xaS1KCKiG3poVEQ745l9bSawQbcbEhHRbb0yTridTnLEd/PYHPGfKXfaRUT0LAHT+2R+yU5SE6u0OyYioveIaX2yGlrb7wtJZ3RSFhHRS8rioX1+Z52kFShrxs2oS4UMNXlV4IlLoW0REeM3IHfWvQPYjxJ0L+YfgfhvwNcmt1kRERPX9xfrbH8Z+LKkfev6bxERfaNcrOuPQNzJNcVHJa029ETS6pL+Y/KaFBHRHd3KEUtaTdIJkq6R9FtJz5G0hqTTJV1X/64+3nZ2EojfbvuuoSe27wTePt4TRkQsDaIEuHZbh74M/Nz204BnUuZoPwA4w/ZM4Iz6fFw6acc0tdywLWk6sNx4TxgRsVSozDXRbmtbjbQq8ALgSADbD9XO6W7A0JJyxwC7j7epnQTiUykrle4oaQdgLvDz8Z4wImJpUQcbZWTYvJZtn2HVPBm4Hfi2pEskfUvS44C1bS8EqH/XGm87O7nFeX9gH+Bdtd2nAd8c7wkjIpaGoaWSOrDI9qwx9i8DbAXsa/sCSV9mAmmIkbTtEdt+1PYRtl9r+zXAVZQJ4iMieto0td86sABYYPuC+vwESmC+VdI6APXvbeNuZycHSdpC0mcl3Qh8GrhmvCeMiFg62ueHO8kR2/4zcLOkp9aiHYGrgZOBvWrZXsCPx9vSse6s2wSYDewJ/AX4PiDbWaUjInre0KiJLtkXOFbScsAfgLfW6o+XtDdwE/C68VY+Vo74GuBXwCttXw8gKWvVRUTf6NYKHbYvBUbKI+/YjfrH+sJ4DWXKyzMlfVPSjtAnUxlFRNDxqInGjXWL80nASXWYxu6UlZvXlvQ/wEm2T1s6TYx+cudP9mu6CT1vn+9f1nQTpgb1z5p1nYyauNf2sbZfAawHXEqXh25ERHSbgOlS260XLFEu2/Ydtr9he4fJalBERLf0fWoiIqLf9UiHt60E4ogYSGX4Wn9E4gTiiBhY6RFHRDRK/b9CR0REP0tqIiKiaT20SnM7CcQRMbASiCMiGqakJiIimlMmhm+6FZ1JII6IgZVRExERDUtqIiKiQUlNREQ0TukRR0Q0qo/GEXdxSaeIiN7R7fmIJU2XdImkn9bna0g6XdJ19e/q421rAnFEDKwuz0f8PuC3Lc8PAM6wPRM4gwksmJFAHBGDq0uRWNJ6wMuBb7UU7wYcUx8fQ1lSblySI46IgdXhxboZkua1PJ9je86wYw4DPgys0lK2tu2FALYXSlprvO1MII6IgdVhCniR7Vmj16FXALfZvljS9t1p2WMlEEfEwOrSoIntgF0l7QKsAKwq6X+BWyWtU3vD6wC3jfcEyRFHxEASIKnt1o7tA22vZ3tDYDbwS9tvBE4G9qqH7QX8eLxtTY84IgbT5I8jPgQ4XtLewE3A68ZbUQJxRAysbsdh22cBZ9XHfwF27Ea9CcQRMbj65M66BOKIGFCZayIionH9MtdEAnFEDKQyaqLpVnQmgTgiBlZSExERDUuPOCKiYX0Shyf3zjpJ/yLpOEm/l3S1pFMk7TM0n+cIx39L0qbjOM8W9fbDiIiik5nXeiRST1ogVrl38CTgLNsb294U+Aiw9mivsf3vtq8ex+m2AEYMxJLS64+YgsqadWq79YLJ7BG/CHjY9hFDBbYvBX4FrCzpBEnXSDq2Bm0knSVpVn18j6SDJV0m6XxJa9fy10m6spafI2k54FPAHpIulbSHpIMkzZF0GvAdSRtK+pWk+XV7bq1r+1rHSbXHfoSkzL8RMSD6pEM8qYF4c+DiUfZtCewHbAo8mTK70XCPA863/UzgHODttfwTwEtr+a62H6pl37e9he3v1+OeBexm+w2UWZFeYnsrYA/gKy3n2Qb4IPCvwMbAq4c3pKZT5kmad/ui2zt68xHRA/okEjfV+7vQ9gLbjwKXAhuOcMxDwFAu+eKWY84Fjpb0dmD6GOc42fb99fGywDclXQH8gPIF0NqWP9h+BJgLPG94Rbbn2J5le9aaM9bs5P1FRA9QB//rBZOZP70KeO0o+x5sefzIKO142LaHH2P7nZKeTVm25FJJW4xyjntbHr8fuBV4JuXL54GWfeaxhj+PiD7VIyngtiazR/xLYPnacwVA0tbACydSqaSNbV9g+xPAImB94G4eu4TJcI8HFtYe+Jt4bE96G0kb1dzwHsCvJ9K+iOgdfZKZmLxAXHuzrwJeUoevXQUcBNwywao/L+kKSVdScseXAWcCmw5drBvhNV8H9pJ0PrAJj+0tn0eZV/RK4AbKSI+I6HPdmhh+aZjUoV22bwFeP8Kub7Yc856Wx9u3PF655fEJwAn18T9dTAPuALYeox3XAc9oKTqw5fF9tkcK3hHRzyZ/YviuyRjbiBhYfRKHp/aadbbPsv2KptsREZOkC0liSetLOlPSbyVdJel9tXwNSadLuq7+XX28zZzSgTgiBlkng9c66jMvBj5o++nAtsC761QMBwBn2J4JnFGfj0sCcUQMLKn91o7thbbn18d3A78F1gV2A46phx0D7D7ediZHHBEDaQkmhp8haV7L8zm254xYp7Qh5c7gC4C1bS+EEqwlrTXetiYQR8TA6jD1sMj2rLZ1SSsDPwT2s/23bg59S2oiIgZWN1ITpR4tSwnCx9o+sRbfKmmdun8dypw245JAHBEDqxt31tXZIY8Efmv7iy27Tgb2qo/3An483nYmNRERg6l7N3RsR5ka4QpJl9ayj1DuyD1e0t7ATcDrxnuCBOKIGEhDtzhPlO1fM3rneccJn4AE4ogYYP1yZ10CcUQMrMw1ERHRsF6Z+L2dBOKIGFz9EYcTiCNicPVJHE4gjojBJMG0PkkSJxBHxODqjzicQBwRg6tP4nACcUQMrj7JTCQQR8Sg6nji98YlEEfEQFqC+Ygbl0AcEQMrgTgiomFJTURENKl702BOugTiiBhInU783gsSiCNicPVJJE4gjoiBlVucIyIa1h9hOIuHRsQg68bqoYCknSVdK+l6SQd0u5kJxBExsNTB/9rWIU0Hvga8DNgU2FPSpt1sZwJxRAykoTvr2m0d2Aa43vYfbD8EHAfs1s22Jke8hObPv3jRisvqj023Y5gZwKKmG9HD8vm012uf0ZMmWsH8+RefuuKymtHBoStImtfyfI7tOS3P1wVubnm+AHj2RNvXKoF4Cdles+k2DCdpnu1ZTbejV+XzaW8QPyPbO3epqpH6ze5S3UBSExER7SwA1m95vh5wSzdPkEAcETG2i4CZkjaStBwwGzi5mydIamIwzGl/yJSWz6e9fEajsL1Y0nuAU4HpwFG2r+rmOWR3NdURERFLKKmJiIiGJRBHRDQsgTgiomEJxANIKvcLDf2NiN6WQDxgJMn/uAK7di2bNrSvsYb1oNE+j3xOSy6f2cRk1MSAkvRu4CXAVcBNwJG2Fzfbqt7R+oUl6cXAasC1wO9t3zfsCy2GkbQV5dbf3wILbD+Qz2z8Mo54AEl6DbAH8Frgh8DFCcKP1RKEPwS8HrgBuB/4k6TP2f5rk+3rZZK2p8xGdiNwJ/B7SYfavrvBZvW1pCYGgKTtJW3WUrQ6cAiwMyW4/L963MYNNK9nSVodeAHwItt7AN8AlqX8ksjP7RFI2gI4EHid7ZcD3wKWB15V9+czG4cE4sGwJnBPDSwAf6L0WP7d9k62H5b0XuAtkqbsr6BRgsSGwEvr4wuAe6gza+Vn9mNJWhbYAtieMjUkwG+AW4FtIZ/ZeE3Zf5SDQNKWALZ/IOlJwO8kvQL4NXAKJThvDTwN2At481RNUQzLCW9OCbg3A58CdpB0l+1fSloAbFvnFHh4qgeWoc9N0jTbDwNHS1oNeJOkhbZPlXQ5sJOkxwN/m+qf2XgkEPe31wDbSPqw7UslfQY4EngDcDhlRYHPAHcAb+n2/fH9ZFhOeNdafAFlFq0LgK9L+hWwA/DKOgH4lNYShF8J7FpXqjjE9mGSHgS+LemHlF8VX09effwyaqIPDevdHQQ8C/iE7UvqaIn3AP9me76kFYBHp2pgGaEnfBSwHbAx5XPbmpIbFmW43/W2bx6luimn/sI6iPKL6mDK57WT7askfYByHeIE23Nqr/nR5lrbv5Ij7jPDhwjZPgg4FzhY0pa2vwZ8GThN0izbD0zhILxSSxBekXLh0rYftn0N5XN7IjDT9tW2z0wQ/gdJKwFPAfYGZgIrA8cA50razPYXgWOBN0vaPkF4/JKa6CPDendvoIx9vcX2IfU61Kckfdz2EZIeogwtmpLqL4F3SLqIkiN/HvA2YGgV3kNt3yjpJkrv+J++5KYySTtRersHAytQLv6+zfbvJL0cOFPSBsDPgEeA3zfW2AGQQNxHWoLwB4CXU3ojH5e0Yg3GHwe+Iuk9to9qsq1NqzcY/Ab4JWUUyXNtPyrpaMrCj6dI+hmwOyXgTPkr/i054U0o6a0DbP9F0qqUGzdmSFqLEnx/bPsB4AFJx071z26ikproA5LWqlfxqf8oNrO9I2XBxz8DJ9b83KeBnzC1e8JD82xMA+YBx1N6dFvXQ84GPk75nKZRLsxd30BTe8bQkMYahJ8EvIOyHNDQxTfXx2+lrGB8ju0Lhj7rBOGJy8W6Hlb/Q18b+D7wP8BJlOBxbP0rysD6hyS9FbjI9pVNtbdpw1I3WwJ/sP1XSc+mBJCP2J5bf3ZfZHvKfmENqV/wzwdup3xhzaTk0l8LXAx8z/ZCSasAKwGr2b62qfYOqvSIe5ztPwOHAm8CdrF9P/BjSl7zSzUIv4Vy99zfGmtow4YF4f+gBN6fSdqH0jN+C+WC5uGUIX6rj1bXFLNs3b5C+aK/2PaJwA8oC2a+TtK6tu+2fWuC8ORIjriHtfzke5TSG/mOpLdTrvavDHyjjn19NqVnfFMzLW1eSxDejXJh7hmUW5VfAaxg+yt1KNazgcNs/6GxxvaI+uV1r6Qbgc2B8ynpLmyfJGkxsAuwh6SvTtXRN0tDUhM9TtJsYD/KvfyzgVcDX7D9I0mbUn7V3Gn7T821sjfUC0nfANa1vU0t24Vyce73lBno/tJgE3tGy4W5F1Nm53sEmEX58jrF9omS1qCkKM5OT3hyJTXRYyStPaxoXeBc2wttfwk4AjiypiOus33lVA3CrXNHSFrG9m3AfwF3SToYwPYpwM8pn2PGuVY1CO9KSUk8xfbvKRcy5wO7SDoE+C5weoLw5EuPuIdIehpwNXAYcE29W2lXYEfKuNeb63E/AR6izB1xb1Pt7RWS3kG58eB24ARgLcrwqz/Y/kQ95nH5rP5B0gzKtYa31rHBz6CkJa6l3HH4JuC7tk9usJlTRnLEveVe4DzKbFavkfQsyjwImwGvlvRnyoWVe4D9p2pgab2Vto4WeSPl7q+LKFf8j6DcXfhxSR+z/Rngvqba26OWpXxWL603uKxMmWfjQ7aPlvQz24/kJpelI4G4h9i+WdKFwFaUGzZmU6Yc3LBuz6OM6fzPqXphTtLzgE0kXW57HuUi039QpmG8ADjCZdrPK4BPArdBxrq25ISfBvylDkn7KiX4Hm/755JeT5mJ7nvAYsjntrQkEPeIlp7H/sB3KD8TF1AunpxICcQLgP+2fUtT7WySpJ2B/wa+BKxai/9ISeU8YnunetxHKWmJuU20s9cM/YKoF+aOAX5df119wfaP6jE7AJ8APpjREUtfAnGPqL0VUW7SuB74IqVn/L46QuKpwG1T9SYESS8EvkqZVe6Cll2rUsZPf6NOUrML8Dpgz6Xfyt4y9OVeg/BzKKuR7E5JSewKHCTpC5Qv+I9Sbmk+tbEGT2G5WNeDatD9FXB4vW15ypO0H+X76sstZf9FmXt5MeUK/3Moy/a81/YVTbSzV9QJebYH5lKGpl0MrG57w7p/M+CVwDOBDwH32r4rOeFmZPhaD6rDhfYHptde3pTVMkRtY8qSUEPlLwM2oIxzXUxZbWNP4DVTPQhXywBXUILvo5Q5mO+X9C0Al0UCfkaZzOcJtu+q5QnCDUgg7l3nUYYRTWktgeFHwLNVlnEH+AVlWsZ5lN7wg7bvtH1HA83sKbVX+wdKkD1O0qdt30e5YeN5kr4BUL+wvmj78gabGyQQ9yyXicv3qP+Aotx+ey4wW9I2LpO7PyRpT0pe+Lxmm9cb6o0trsH4AeC9wLPqML57KdcddpH0bQDb9zTZ3iiSI46+IWldynjhHYBL+McsYbvbvrrJtjWt3t69qF6Y24lyUe7nlOk+n0yZve9M2/8t6XHALNtnN9bgeIwE4ugrKksebUUZ1vcn4Czb1zXbqmbVOao/TRlxM5cynO8Myuf0C2AOZTjkd4DTbH+qvi4X5npEAnFEH1NZTWN/yqRGm1Omrvys7Z9K2p5yU9D1wDcpwfgJti9sprUxmgTiiD5VZ987npJ2+BEl0B5C+Xe9cz3m+cDbKRfuPmf7kWZaG2NJII7oQzUdcRJlBY0jW8r/FXgfZd6S/eqFuxdSbmuesqu39LqMmojoT/dTcuQnAEhaFv4+JO0LwL8Ah9eysxOEe1sCcUR/ehywJWUiKOpER9Prvr8AlwGr1vRF9LgE4og+VO+EO5wyXeoWw3bPokydesBUH9bXLxKII/rXScBC4J119rRHJW1HSU18d6rO0tePcrEuoo/VpbVeT5mTeT5lTo5Dhqa3jP6QQBwxAGpAfhRY3vaC3KzRXxKIIyIalhxxRETDEogjIhqWQBwR0bAE4oiIhiUQR0Q0LIE4xkXSI5IulXSlpB9MZG09SUdLem19/K2xbsuVtL2k547jHDdKmjFC+dskXSHp8vpedlvSutucd8z3EwFlgcGI8bjf9hYAko4F3gl8cWinpOnjmXLR9r+3OWR74B7gN0ta93CS1qMsI7+V7b9KWpmWBUq7oYP3E5EecXTFr4Cn1N7qmZK+B1whabqkz0u6qPY43wFlZQhJX5V0taSfAWsNVSTpLEmz6uOdJc2XdJmkMyRtSAn476+98edLWlPSD+s5Lqq3+CLpCZJOk3RJXSxT/LO1gLspgR3b99i+oaUdh0n6Te0pb1PLHyfpqHquS4Z60PW9HtrSu953hPezk6Tz6nv6QQ38SDqkfhaXSzq0u//XRD9IjzgmRNIywMso66MBbANsbvsGSfsAf7W9taTlgXMlnUaZNeypwL8CawNXA0cNq3dNyqoSL6h1rWH7DklHAPfYPrQe9z3gS7Z/LWkD4FTg6cAngV/b/pSklwP7jND8y4BbgRsknQGcaPsnLfsfZ/u5kl5Q27c5pQf9S9tvk7QacKGkXwBvBjYCtrS9WNIaw97PDOBjwItt3ytpf+ADkr4KvAp4Wp07eLUOP/oYIAnEMV4rSrq0Pv4VcCTwXODCoV4lsBPwjKH8L/B4YCbwAmBuTV3cIumXI9S/LXDOUF227xilHS8GNpX+3uFdVdIq9Ryvrq/9maQ7h7/Q9iOSdga2BnYEviTpWbYPqofMrcedI2nVGiR3AnaV9KF6zArABrUdR9hePEp7twU2pXwZASxHWXn6b8ADwLfqr4OfjvI+Y4AlEMd4/T1HPKQGmHtbi4B9bZ867LhdgHb31quDY6Ck155j+/4R2tL29XU+hgspPdvTgW8DBw3tHn54bddrbF877Hzt2ivgdNt7/tOOkvbYkbK+3Hsoq1THFJIccUymU4F3qa4eIWkTlaXczwFm17zqOsCLRnjtecALJW1UXzv0U/9uYJWW406jBC/qcVvUh+cA/1bLXgasPvwEkp4oaauWoi2AP7Y836Me9zxKiuWv9T3tWwMvkrZsacc7a6qmtb1Dzge2k/SUun+l+nmsDDze9inAfrUNMcWkRxyT6VvAhsD8GrhuB3anzKO7A3AF8Dvg7OEvtH17zTGfKGkacBvwEuAnwAn1Itm+wHuBr0m6nPLf8zmUC3r/CcyVNL/Wf9MI7VsWOFTSEynpgdvra4fcKek3wKrA22rZpynL1V9e39ONwCvqe92klj9MyW9/ddj7eUtt0/K1+GOUL5YfS1qB0mt+/6ifZgyszL4WMQJJZwEfsj2v6bbE4EtqIiKiYekRR0Q0LD3iiIiGJRBHRDQsgTgiomEJxBERDUsgjoho2P8HcuuOwBgzx8YAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "#Pytorch doesn't have a built-in confusion matrix metric, so we'll use SciKit-Learn\n", "from sklearn.metrics import confusion_matrix\n", "import numpy as np\n", "\n", "# Set the model to evaluate mode\n", "model.eval()\n", "\n", "# Get predictions for the test data\n", "x = torch.Tensor(x_test).float()\n", "_, predicted = torch.max(model(x).data, 1)\n", "\n", "# Plot the confusion matrix\n", "cm = confusion_matrix(y_test, predicted.numpy())\n", "plt.imshow(cm, interpolation=\"nearest\", cmap=plt.cm.Blues)\n", "plt.colorbar()\n", "tick_marks = np.arange(len(penguin_classes))\n", "plt.xticks(tick_marks, penguin_classes, rotation=45)\n", "plt.yticks(tick_marks, penguin_classes)\n", "plt.xlabel(\"Predicted Species\")\n", "plt.ylabel(\"Actual Species\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "model saved as penguin_classifier.pt\n" ] } ], "source": [ "# Save the model weights\n", "# model_file = '/User/johnnydevriese/projects/models/penguin_classifier.pt'\n", "model_file = 'penguin_classifier.pt'\n", "torch.save(model.state_dict(), f=model_file)\n", "del model\n", "print('model saved as', model_file)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "New sample: [[50.4, 15.3, 20, 50]]\n", "Prediction: Gentoo\n" ] } ], "source": [ "# New penguin features\n", "x_new = [[50.4,15.3,20,50]]\n", "print ('New sample: {}'.format(x_new))\n", "\n", "# Create a new model class and load weights\n", "model = PenguinNet()\n", "model.load_state_dict(torch.load(model_file))\n", "\n", "# Set model to evaluation mode\n", "model.eval()\n", "\n", "# Get a prediction for the new data sample\n", "x = torch.Tensor(x_new).float()\n", "_, predicted = torch.max(model(x).data, 1)\n", "\n", "print('Prediction:',penguin_classes[predicted.item()])" ] } ], "metadata": { "interpreter": { "hash": "16c7f1dc46b458d69b8d4b83cad879badbf4dbe0bbfb50262ef6f7a4b6b16937" }, "kernelspec": { "display_name": "Python 3.9.7 64-bit ('base': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }