Delete models
Browse files- models/FullyConectedModels/Grid_SearchCV.ipynb +0 -1213
- models/FullyConectedModels/gridsearchcv/grid_model_1.csv +0 -49
- models/FullyConectedModels/model.py +0 -120
- models/FullyConectedModels/parseval.py +0 -83
- models/Parseval_Networks/README.md +0 -3
- models/Parseval_Networks/constraint.py +0 -81
- models/Parseval_Networks/convexity_constraint.py +0 -53
- models/Parseval_Networks/parsevalnet.py +0 -328
- models/README.md +0 -15
- models/_utility.py +0 -109
- models/wideresnet/wresnet.py +0 -329
models/FullyConectedModels/Grid_SearchCV.ipynb
DELETED
|
@@ -1,1213 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"nbformat": 4,
|
| 3 |
-
"nbformat_minor": 0,
|
| 4 |
-
"metadata": {
|
| 5 |
-
"accelerator": "GPU",
|
| 6 |
-
"colab": {
|
| 7 |
-
"name": "Grid_SearchCV.ipynb",
|
| 8 |
-
"provenance": [],
|
| 9 |
-
"collapsed_sections": [],
|
| 10 |
-
"toc_visible": true
|
| 11 |
-
},
|
| 12 |
-
"kernelspec": {
|
| 13 |
-
"display_name": "Python 3",
|
| 14 |
-
"language": "python",
|
| 15 |
-
"name": "python3"
|
| 16 |
-
},
|
| 17 |
-
"language_info": {
|
| 18 |
-
"codemirror_mode": {
|
| 19 |
-
"name": "ipython",
|
| 20 |
-
"version": 3
|
| 21 |
-
},
|
| 22 |
-
"file_extension": ".py",
|
| 23 |
-
"mimetype": "text/x-python",
|
| 24 |
-
"name": "python",
|
| 25 |
-
"nbconvert_exporter": "python",
|
| 26 |
-
"pygments_lexer": "ipython3",
|
| 27 |
-
"version": "3.7.4"
|
| 28 |
-
}
|
| 29 |
-
},
|
| 30 |
-
"cells": [
|
| 31 |
-
{
|
| 32 |
-
"cell_type": "markdown",
|
| 33 |
-
"metadata": {
|
| 34 |
-
"id": "o7mJMiThKvtT"
|
| 35 |
-
},
|
| 36 |
-
"source": [
|
| 37 |
-
"# <font color=\"purple\"><b>Grid Search CV Algorithm for Fully Connected Networks</b></font>"
|
| 38 |
-
]
|
| 39 |
-
},
|
| 40 |
-
{
|
| 41 |
-
"cell_type": "markdown",
|
| 42 |
-
"metadata": {
|
| 43 |
-
"id": "kUrakNvqKvtU"
|
| 44 |
-
},
|
| 45 |
-
"source": [
|
| 46 |
-
"The grid search CV algorithm is used to search for the best hyperparameters of this model over the following values:\n",
|
| 47 |
-
"<li><b> Learning Rate:</b> 0.1, 0.01</li>\n",
|
| 48 |
-
"<li><b> Regularization Penalty:</b> 0, 0.01, 0.001, 0.0001</li>\n",
|
| 49 |
-
"<li><b> Batch Size:</b> 64, 128</li>\n",
|
| 50 |
-
"<li><b> Epochs:</b> 50, 100, 150</li>"
|
| 51 |
-
]
|
| 52 |
-
},
|
| 53 |
-
{
|
| 54 |
-
"cell_type": "markdown",
|
| 55 |
-
"metadata": {
|
| 56 |
-
"id": "9rFQbEcDKvtV"
|
| 57 |
-
},
|
| 58 |
-
"source": [
|
| 59 |
-
"## <font color=\"blue\">Import Libraries</font>"
|
| 60 |
-
]
|
| 61 |
-
},
|
| 62 |
-
{
|
| 63 |
-
"cell_type": "code",
|
| 64 |
-
"metadata": {
|
| 65 |
-
"id": "6nhsKKJZ02AK"
|
| 66 |
-
},
|
| 67 |
-
"source": [
|
| 68 |
-
"import gzip\n",
|
| 69 |
-
"import pickle\n",
|
| 70 |
-
"import numpy as np\n",
|
| 71 |
-
"import pandas as pd\n",
|
| 72 |
-
"import numpy as np\n",
|
| 73 |
-
"from sklearn.preprocessing import LabelEncoder\n",
|
| 74 |
-
"from tensorflow.keras.utils import to_categorical\n",
|
| 75 |
-
"from tensorflow.keras import backend as K\n",
|
| 76 |
-
"from itertools import product\n",
|
| 77 |
-
"from sklearn.model_selection import train_test_split\n",
|
| 78 |
-
"from sklearn.model_selection import KFold\n",
|
| 79 |
-
"\n",
|
| 80 |
-
"from tensorflow.keras.regularizers import l2\n",
|
| 81 |
-
"from tensorflow.keras.optimizers import SGD\n",
|
| 82 |
-
"import tensorflow\n",
|
| 83 |
-
"import json\n",
|
| 84 |
-
"import cv2\n",
|
| 85 |
-
"import io\n",
|
| 86 |
-
"from sklearn.metrics import accuracy_score\n",
|
| 87 |
-
"from sklearn.metrics import precision_score\n",
|
| 88 |
-
"from sklearn.metrics import recall_score\n",
|
| 89 |
-
"try:\n",
|
| 90 |
-
" to_unicode = unicode\n",
|
| 91 |
-
"except NameError:\n",
|
| 92 |
-
" to_unicode = str\n",
|
| 93 |
-
"from sklearn.preprocessing import LabelEncoder\n",
|
| 94 |
-
"from tensorflow.keras.utils import to_categorical"
|
| 95 |
-
],
|
| 96 |
-
"execution_count": null,
|
| 97 |
-
"outputs": []
|
| 98 |
-
},
|
| 99 |
-
{
|
| 100 |
-
"cell_type": "code",
|
| 101 |
-
"metadata": {
|
| 102 |
-
"id": "TdoFKvR-m04D"
|
| 103 |
-
},
|
| 104 |
-
"source": [
|
| 105 |
-
"!pip install hickle\n",
|
| 106 |
-
"import hickle as hkl"
|
| 107 |
-
],
|
| 108 |
-
"execution_count": null,
|
| 109 |
-
"outputs": []
|
| 110 |
-
},
|
| 111 |
-
{
|
| 112 |
-
"cell_type": "code",
|
| 113 |
-
"metadata": {
|
| 114 |
-
"id": "XqmnAXulmtWm"
|
| 115 |
-
},
|
| 116 |
-
"source": [
|
| 117 |
-
"data = hkl.load(\"data.hkl\")\n",
|
| 118 |
-
"X_train, X_test, Y_train, y_test = data['xtrain'], data['xtest'], data['ytrain'], data['ytest']\n",
|
| 119 |
-
"x_train, x_val, y_train, y_val = train_test_split(X_train, Y_train, test_size=0.1)"
|
| 120 |
-
],
|
| 121 |
-
"execution_count": null,
|
| 122 |
-
"outputs": []
|
| 123 |
-
},
|
| 124 |
-
{
|
| 125 |
-
"cell_type": "code",
|
| 126 |
-
"metadata": {
|
| 127 |
-
"id": "u4zwcqNAnEoE"
|
| 128 |
-
},
|
| 129 |
-
"source": [
|
| 130 |
-
"from tensorflow.data import Dataset\n",
|
| 131 |
-
"import tensorflow.keras as keras\n",
|
| 132 |
-
"from tensorflow.keras.optimizers import Adam\n",
|
| 133 |
-
"from tensorflow.keras.layers import Conv2D,Input,MaxPooling2D, Dense, Dropout, MaxPool1D, Flatten, AveragePooling1D, BatchNormalization\n",
|
| 134 |
-
"from tensorflow.keras import Model\n",
|
| 135 |
-
"import numpy as np\n",
|
| 136 |
-
"import tensorflow as tf\n",
|
| 137 |
-
"from tensorflow.keras.models import Sequential\n"
|
| 138 |
-
],
|
| 139 |
-
"execution_count": null,
|
| 140 |
-
"outputs": []
|
| 141 |
-
},
|
| 142 |
-
{
|
| 143 |
-
"cell_type": "code",
|
| 144 |
-
"metadata": {
|
| 145 |
-
"id": "pW_D-_Dqm5mg"
|
| 146 |
-
},
|
| 147 |
-
"source": [
|
| 148 |
-
"def model_1(weight_decay):\n",
|
| 149 |
-
" model = Sequential()\n",
|
| 150 |
-
" model.add(Conv2D(32, kernel_size=(3, 3),activation='relu', input_shape=(32, 32, 1), kernel_regularizer=l2(weight_decay)))\n",
|
| 151 |
-
" model.add(Conv2D(64, kernel_size=(3, 3), activation='relu', kernel_regularizer=l2(weight_decay)))\n",
|
| 152 |
-
" model.add(MaxPooling2D(pool_size=(2, 2)))\n",
|
| 153 |
-
" model.add(BatchNormalization())\n",
|
| 154 |
-
" model.add(Flatten())\n",
|
| 155 |
-
" model.add(Dense(4, activation='softmax', kernel_regularizer=l2(weight_decay)))\n",
|
| 156 |
-
" return model\n",
|
| 157 |
-
"\n",
|
| 158 |
-
"\n",
|
| 159 |
-
"def model_2(weight_decay):\n",
|
| 160 |
-
" model = Sequential()\n",
|
| 161 |
-
" model.add(Conv2D(32, kernel_size=(3, 3),activation='relu', input_shape=(32, 32, 1), kernel_regularizer=l2(weight_decay)))\n",
|
| 162 |
-
" model.add(Conv2D(64, kernel_size=(3, 3), activation='relu', kernel_regularizer=l2(weight_decay)))\n",
|
| 163 |
-
" model.add(MaxPooling2D(pool_size=(2, 2)))\n",
|
| 164 |
-
" model.add(BatchNormalization())\n",
|
| 165 |
-
" model.add(Conv2D(128, kernel_size=(3, 3), activation='relu', kernel_regularizer=l2(weight_decay)))\n",
|
| 166 |
-
" model.add(MaxPooling2D(pool_size=(2, 2)))\n",
|
| 167 |
-
" model.add(BatchNormalization())\n",
|
| 168 |
-
" model.add(Flatten())\n",
|
| 169 |
-
" model.add(Dense(4, activation='softmax', kernel_regularizer=l2(weight_decay)))\n",
|
| 170 |
-
" return model\n",
|
| 171 |
-
"\n",
|
| 172 |
-
"\n",
|
| 173 |
-
"def model_3(weight_decay):\n",
|
| 174 |
-
" model = Sequential()\n",
|
| 175 |
-
" model.add(Conv2D(32, kernel_size=(3, 3),activation='relu', input_shape=(32, 32, 1),kernel_regularizer=l2(weight_decay)))\n",
|
| 176 |
-
" model.add(Conv2D(64, kernel_size=(3, 3), activation='relu',kernel_regularizer=l2(weight_decay)))\n",
|
| 177 |
-
" model.add(MaxPooling2D(pool_size=(2, 2)))\n",
|
| 178 |
-
" model.add(BatchNormalization())\n",
|
| 179 |
-
" model.add(Conv2D(128, kernel_size=(3, 3), activation='relu',kernel_regularizer=l2(weight_decay)))\n",
|
| 180 |
-
" model.add(MaxPooling2D(pool_size=(2, 2)))\n",
|
| 181 |
-
" model.add(BatchNormalization())\n",
|
| 182 |
-
" model.add(Conv2D(256, kernel_size=(3, 3), activation='relu',kernel_regularizer=l2(weight_decay)))\n",
|
| 183 |
-
" model.add(MaxPooling2D(pool_size=(2, 2)))\n",
|
| 184 |
-
" model.add(BatchNormalization())\n",
|
| 185 |
-
" model.add(Flatten())\n",
|
| 186 |
-
" model.add(Dense(4, activation='softmax',kernel_regularizer=l2(weight_decay)))\n",
|
| 187 |
-
" return model\n"
|
| 188 |
-
],
|
| 189 |
-
"execution_count": null,
|
| 190 |
-
"outputs": []
|
| 191 |
-
},
|
| 192 |
-
{
|
| 193 |
-
"cell_type": "markdown",
|
| 194 |
-
"metadata": {
|
| 195 |
-
"id": "p9X1E2wybnZj"
|
| 196 |
-
},
|
| 197 |
-
"source": [
|
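| 198 |
-
"<font color=\"blue\"> The algorithm below runs a 3-fold cross-validated grid search: for every hyperparameter combination it trains model_1 on each fold, evaluates it on the test set, and logs the loss, accuracy, average precision and average recall to a CSV file. </font>"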
|
| 199 |
-
]
|
| 200 |
-
},
|
| 201 |
-
{
|
| 202 |
-
"cell_type": "code",
|
| 203 |
-
"metadata": {
|
| 204 |
-
"id": "V8U0Vk9t0l3V"
|
| 205 |
-
},
|
| 206 |
-
"source": [
|
| 207 |
-
"from sklearn.metrics import confusion_matrix\n",
|
| 208 |
-
"def encoded_label(y_predict):\n",
|
| 209 |
-
" y_list = [] \n",
|
| 210 |
-
" for y_hat in y_predict:\n",
|
| 211 |
-
" y_hat = np.argmax(y_hat)\n",
|
| 212 |
-
" y_list.append(to_categorical(y_hat))\n",
|
| 213 |
-
" return y_list\n",
|
| 214 |
-
"\n",
|
| 215 |
-
"\n",
|
| 216 |
-
"def KFold_GridSearchCV(input_dim, X, Y, X_test, y_test, combinations, filename=\"log.csv\", acc_loss_json=\"hist.json\"):\n",
|
| 217 |
-
" \"\"\"Summary: Grid Search CV for 3 Folds Cross Validation\n",
|
| 218 |
-
" \"\"\"\n",
|
| 219 |
-
" res_df = pd.DataFrame(columns=['momentum','learning rate','batch size',\n",
|
| 220 |
-
" 'loss1', 'acc1','loss2', 'acc2','loss3', 'acc3', 'widing factor',\n",
|
| 221 |
-
" 'prec1', 'prec2', 'prec3', 'recall1', 'recall2', 'recall3'])\n",
|
| 222 |
-
" generator = tensorflow.keras.preprocessing.image.ImageDataGenerator(rotation_range=10,\n",
|
| 223 |
-
" width_shift_range=5./32,\n",
|
| 224 |
-
" height_shift_range=5./32,)\n",
|
| 225 |
-
" hist_dict_global = {}\n",
|
| 226 |
-
"\n",
|
| 227 |
-
" for i, combination in enumerate(combinations):\n",
|
| 228 |
-
" kf = KFold(n_splits=3, random_state=42, shuffle=False)\n",
|
| 229 |
-
" metrics_dict = {}\n",
|
| 230 |
-
" \n",
|
| 231 |
-
" for j, (train_index, test_index) in enumerate(kf.split(X)):\n",
|
| 232 |
-
" X_train, X_val = X[train_index], X[test_index]\n",
|
| 233 |
-
" y_train, y_val = Y[train_index], Y[test_index]\n",
|
| 234 |
-
" model = model_1(combination[2])\n",
|
| 235 |
-
" opt = tensorflow.keras.optimizers.SGD(learning_rate=combination[0])\n",
|
| 236 |
-
" model.compile(loss='categorical_crossentropy', optimizer=opt, metrics= ['accuracy'])\n",
|
| 237 |
-
" hist = model.fit(generator.flow(X_train, y_train, batch_size=combination[1]), steps_per_epoch=len(X_train) // combination[1], epochs=combination[3],\n",
|
| 238 |
-
" validation_data=(X_val, y_val),\n",
|
| 239 |
-
" validation_steps=len(X_val) // combination[1],)\n",
|
| 240 |
-
" loss, acc = model.evaluate(X_test, y_test)\n",
|
| 241 |
-
" #yhat_classes = encoded_label( model.predict(X_test))\n",
|
| 242 |
-
" predict = model.predict(X_test)\n",
|
| 243 |
-
" yhat_classes = np.argmax(predict, axis=1)\n",
|
| 244 |
-
" print(yhat_classes)\n",
|
| 245 |
-
" print(y_test)\n",
|
| 246 |
-
" cm = confusion_matrix(np.argmax(y_test, axis=1), yhat_classes)\n",
|
| 247 |
-
" recall = np.diag(cm) / np.sum(cm, axis = 1)\n",
|
| 248 |
-
" precision = np.diag(cm) / np.sum(cm, axis = 0)\n",
|
| 249 |
-
"\n",
|
| 250 |
-
" recall_avg = np.mean(recall)\n",
|
| 251 |
-
" precision_avg = np.mean(precision)\n",
|
| 252 |
-
" metrics_dict[j+1] = {\"loss\": loss, \"acc\": acc, \"epoch_stopped\": combination[3], \"precision_avg\":precision_avg,\n",
|
| 253 |
-
" \"avg_recall\":recall_avg}\n",
|
| 254 |
-
" graph_loss_acc = {\"id\": i, \"com\":j+1, \"val_acc\":hist.history[\"val_accuracy\"], \"train_acc\":hist.history[\"accuracy\"],\n",
|
| 255 |
-
" \"val_loss\":hist.history[\"val_loss\"], \"train_loss\":hist.history[\"loss\"], \"epoch_stopped\": combination[3], 'learning rate': combination[0],\n",
|
| 256 |
-
" 'batch size': combination[1], 'reg_penalty': combination[2]}\n",
|
| 257 |
-
"\n",
|
| 258 |
-
" # Write JSON file\n",
|
| 259 |
-
" with io.open(acc_loss_json, 'a+', encoding='utf8') as outfile:\n",
|
| 260 |
-
" str_ = json.dumps(graph_loss_acc)\n",
|
| 261 |
-
" outfile.write(to_unicode(str_))\n",
|
| 262 |
-
" row = {'momentum': combination[4],'learning rate': combination[0],\n",
|
| 263 |
-
" 'batch size': combination[1],\n",
|
| 264 |
-
" 'reg_penalty': combination[2],\n",
|
| 265 |
-
" 'epoch_stopped': metrics_dict[1][\"epoch_stopped\"],\n",
|
| 266 |
-
" 'widing factor' : 1,\n",
|
| 267 |
-
" 'loss1': metrics_dict[1][\"loss\"],\n",
|
| 268 |
-
" 'acc1': metrics_dict[1][\"acc\"],\n",
|
| 269 |
-
" 'loss2': metrics_dict[2][\"loss\"],\n",
|
| 270 |
-
" 'acc2': metrics_dict[2][\"acc\"],\n",
|
| 271 |
-
" 'loss3': metrics_dict[3][\"loss\"],\n",
|
| 272 |
-
" 'acc3': metrics_dict[3][\"acc\"],\n",
|
| 273 |
-
" 'prec1':metrics_dict[1][\"precision_avg\"],\n",
|
| 274 |
-
" 'prec2':metrics_dict[2][\"precision_avg\"],\n",
|
| 275 |
-
" 'prec3':metrics_dict[3][\"precision_avg\"],\n",
|
| 276 |
-
" 'recall1':metrics_dict[1][\"avg_recall\"],\n",
|
| 277 |
-
" 'recall2':metrics_dict[2][\"avg_recall\"],\n",
|
| 278 |
-
" 'recall3':metrics_dict[3][\"avg_recall\"]}\n",
|
| 279 |
-
" res_df = res_df.append(row , ignore_index=True)\n",
|
| 280 |
-
" res_df.to_csv(filename, sep=\";\")"
|
| 281 |
-
],
|
| 282 |
-
"execution_count": null,
|
| 283 |
-
"outputs": []
|
| 284 |
-
},
|
| 285 |
-
{
|
| 286 |
-
"cell_type": "code",
|
| 287 |
-
"metadata": {
|
| 288 |
-
"id": "tgO8xJ0Fe93I"
|
| 289 |
-
},
|
| 290 |
-
"source": [
|
| 291 |
-
"\n",
|
| 292 |
-
"if __name__ == \"__main__\":\n",
|
| 293 |
-
" learning_rate = [0.1,0.01]\n",
|
| 294 |
-
" batch_size = [64,128]\n",
|
| 295 |
-
" reg_penalty = [0,0.01, 0.001, 0.0001]\n",
|
| 296 |
-
" epochs = [50,100,150]\n",
|
| 297 |
-
" momentum = [0.9]\n",
|
| 298 |
-
" in_dim = (32,32,1)\n",
|
| 299 |
-
" grid_result = \"grid_model_1.csv\"\n",
|
| 300 |
-
" acc_loss_json = \"history.json\"\n",
|
| 301 |
-
" # create list of all different parameter combinations\n",
|
| 302 |
-
" param_grid = dict(learning_rate = learning_rate, batch_size = batch_size, \n",
|
| 303 |
-
" reg_penalty = reg_penalty, epochs = epochs, momentum=momentum)\n",
|
| 304 |
-
" combinations = list(product(*param_grid.values()))\n",
|
| 305 |
-
" KFold_GridSearchCV(in_dim,X_train,Y_train,X_test, y_test, combinations, grid_result, acc_loss_json)"
|
| 306 |
-
],
|
| 307 |
-
"execution_count": null,
|
| 308 |
-
"outputs": []
|
| 309 |
-
},
|
| 310 |
-
{
|
| 311 |
-
"cell_type": "code",
|
| 312 |
-
"metadata": {
|
| 313 |
-
"id": "iVp5lMuhpAMW"
|
| 314 |
-
},
|
| 315 |
-
"source": [
|
| 316 |
-
"data = pd.read_csv(\"grid_model_1.csv\", sep=\";\")"
|
| 317 |
-
],
|
| 318 |
-
"execution_count": null,
|
| 319 |
-
"outputs": []
|
| 320 |
-
},
|
| 321 |
-
{
|
| 322 |
-
"cell_type": "code",
|
| 323 |
-
"metadata": {
|
| 324 |
-
"colab": {
|
| 325 |
-
"base_uri": "https://localhost:8080/",
|
| 326 |
-
"height": 244
|
| 327 |
-
},
|
| 328 |
-
"id": "rru6HHyrsgbl",
|
| 329 |
-
"outputId": "01496643-5865-41a6-85c2-f69242b5912b"
|
| 330 |
-
},
|
| 331 |
-
"source": [
|
| 332 |
-
"data.head(5)"
|
| 333 |
-
],
|
| 334 |
-
"execution_count": null,
|
| 335 |
-
"outputs": [
|
| 336 |
-
{
|
| 337 |
-
"output_type": "execute_result",
|
| 338 |
-
"data": {
|
| 339 |
-
"text/html": [
|
| 340 |
-
"<div>\n",
|
| 341 |
-
"<style scoped>\n",
|
| 342 |
-
" .dataframe tbody tr th:only-of-type {\n",
|
| 343 |
-
" vertical-align: middle;\n",
|
| 344 |
-
" }\n",
|
| 345 |
-
"\n",
|
| 346 |
-
" .dataframe tbody tr th {\n",
|
| 347 |
-
" vertical-align: top;\n",
|
| 348 |
-
" }\n",
|
| 349 |
-
"\n",
|
| 350 |
-
" .dataframe thead th {\n",
|
| 351 |
-
" text-align: right;\n",
|
| 352 |
-
" }\n",
|
| 353 |
-
"</style>\n",
|
| 354 |
-
"<table border=\"1\" class=\"dataframe\">\n",
|
| 355 |
-
" <thead>\n",
|
| 356 |
-
" <tr style=\"text-align: right;\">\n",
|
| 357 |
-
" <th></th>\n",
|
| 358 |
-
" <th>momentum</th>\n",
|
| 359 |
-
" <th>learning rate</th>\n",
|
| 360 |
-
" <th>batch size</th>\n",
|
| 361 |
-
" <th>loss1</th>\n",
|
| 362 |
-
" <th>acc1</th>\n",
|
| 363 |
-
" <th>loss2</th>\n",
|
| 364 |
-
" <th>acc2</th>\n",
|
| 365 |
-
" <th>loss3</th>\n",
|
| 366 |
-
" <th>acc3</th>\n",
|
| 367 |
-
" <th>widing factor</th>\n",
|
| 368 |
-
" <th>prec1</th>\n",
|
| 369 |
-
" <th>prec2</th>\n",
|
| 370 |
-
" <th>prec3</th>\n",
|
| 371 |
-
" <th>recall1</th>\n",
|
| 372 |
-
" <th>recall2</th>\n",
|
| 373 |
-
" <th>recall3</th>\n",
|
| 374 |
-
" <th>epoch_stopped</th>\n",
|
| 375 |
-
" <th>reg_penalty</th>\n",
|
| 376 |
-
" </tr>\n",
|
| 377 |
-
" </thead>\n",
|
| 378 |
-
" <tbody>\n",
|
| 379 |
-
" <tr>\n",
|
| 380 |
-
" <th>0</th>\n",
|
| 381 |
-
" <td>0.9</td>\n",
|
| 382 |
-
" <td>0.1</td>\n",
|
| 383 |
-
" <td>64.0</td>\n",
|
| 384 |
-
" <td>1.078664</td>\n",
|
| 385 |
-
" <td>0.541012</td>\n",
|
| 386 |
-
" <td>1.087765</td>\n",
|
| 387 |
-
" <td>0.560209</td>\n",
|
| 388 |
-
" <td>1.140540</td>\n",
|
| 389 |
-
" <td>0.495637</td>\n",
|
| 390 |
-
" <td>1.0</td>\n",
|
| 391 |
-
" <td>0.545625</td>\n",
|
| 392 |
-
" <td>0.570303</td>\n",
|
| 393 |
-
" <td>0.540447</td>\n",
|
| 394 |
-
" <td>0.549296</td>\n",
|
| 395 |
-
" <td>0.561227</td>\n",
|
| 396 |
-
" <td>0.506039</td>\n",
|
| 397 |
-
" <td>50.0</td>\n",
|
| 398 |
-
" <td>0.00</td>\n",
|
| 399 |
-
" </tr>\n",
|
| 400 |
-
" <tr>\n",
|
| 401 |
-
" <th>1</th>\n",
|
| 402 |
-
" <td>0.9</td>\n",
|
| 403 |
-
" <td>0.1</td>\n",
|
| 404 |
-
" <td>64.0</td>\n",
|
| 405 |
-
" <td>1.090459</td>\n",
|
| 406 |
-
" <td>0.544503</td>\n",
|
| 407 |
-
" <td>1.052103</td>\n",
|
| 408 |
-
" <td>0.568935</td>\n",
|
| 409 |
-
" <td>0.991915</td>\n",
|
| 410 |
-
" <td>0.586387</td>\n",
|
| 411 |
-
" <td>1.0</td>\n",
|
| 412 |
-
" <td>0.567745</td>\n",
|
| 413 |
-
" <td>0.581882</td>\n",
|
| 414 |
-
" <td>0.587054</td>\n",
|
| 415 |
-
" <td>0.549632</td>\n",
|
| 416 |
-
" <td>0.571559</td>\n",
|
| 417 |
-
" <td>0.584749</td>\n",
|
| 418 |
-
" <td>100.0</td>\n",
|
| 419 |
-
" <td>0.00</td>\n",
|
| 420 |
-
" </tr>\n",
|
| 421 |
-
" <tr>\n",
|
| 422 |
-
" <th>2</th>\n",
|
| 423 |
-
" <td>0.9</td>\n",
|
| 424 |
-
" <td>0.1</td>\n",
|
| 425 |
-
" <td>64.0</td>\n",
|
| 426 |
-
" <td>1.203270</td>\n",
|
| 427 |
-
" <td>0.549738</td>\n",
|
| 428 |
-
" <td>1.024475</td>\n",
|
| 429 |
-
" <td>0.607330</td>\n",
|
| 430 |
-
" <td>1.099551</td>\n",
|
| 431 |
-
" <td>0.542757</td>\n",
|
| 432 |
-
" <td>1.0</td>\n",
|
| 433 |
-
" <td>0.597314</td>\n",
|
| 434 |
-
" <td>0.629707</td>\n",
|
| 435 |
-
" <td>0.583353</td>\n",
|
| 436 |
-
" <td>0.558664</td>\n",
|
| 437 |
-
" <td>0.612029</td>\n",
|
| 438 |
-
" <td>0.532126</td>\n",
|
| 439 |
-
" <td>150.0</td>\n",
|
| 440 |
-
" <td>0.00</td>\n",
|
| 441 |
-
" </tr>\n",
|
| 442 |
-
" <tr>\n",
|
| 443 |
-
" <th>3</th>\n",
|
| 444 |
-
" <td>0.9</td>\n",
|
| 445 |
-
" <td>0.1</td>\n",
|
| 446 |
-
" <td>64.0</td>\n",
|
| 447 |
-
" <td>1.331875</td>\n",
|
| 448 |
-
" <td>0.450262</td>\n",
|
| 449 |
-
" <td>1.357389</td>\n",
|
| 450 |
-
" <td>0.425829</td>\n",
|
| 451 |
-
" <td>1.136434</td>\n",
|
| 452 |
-
" <td>0.542757</td>\n",
|
| 453 |
-
" <td>1.0</td>\n",
|
| 454 |
-
" <td>0.520104</td>\n",
|
| 455 |
-
" <td>0.556441</td>\n",
|
| 456 |
-
" <td>0.541896</td>\n",
|
| 457 |
-
" <td>0.467984</td>\n",
|
| 458 |
-
" <td>0.441258</td>\n",
|
| 459 |
-
" <td>0.549730</td>\n",
|
| 460 |
-
" <td>50.0</td>\n",
|
| 461 |
-
" <td>0.01</td>\n",
|
| 462 |
-
" </tr>\n",
|
| 463 |
-
" <tr>\n",
|
| 464 |
-
" <th>4</th>\n",
|
| 465 |
-
" <td>0.9</td>\n",
|
| 466 |
-
" <td>0.1</td>\n",
|
| 467 |
-
" <td>64.0</td>\n",
|
| 468 |
-
" <td>1.286225</td>\n",
|
| 469 |
-
" <td>0.490401</td>\n",
|
| 470 |
-
" <td>1.218342</td>\n",
|
| 471 |
-
" <td>0.539267</td>\n",
|
| 472 |
-
" <td>1.207052</td>\n",
|
| 473 |
-
" <td>0.537522</td>\n",
|
| 474 |
-
" <td>1.0</td>\n",
|
| 475 |
-
" <td>0.552838</td>\n",
|
| 476 |
-
" <td>0.569478</td>\n",
|
| 477 |
-
" <td>0.586226</td>\n",
|
| 478 |
-
" <td>0.494587</td>\n",
|
| 479 |
-
" <td>0.541650</td>\n",
|
| 480 |
-
" <td>0.547532</td>\n",
|
| 481 |
-
" <td>100.0</td>\n",
|
| 482 |
-
" <td>0.01</td>\n",
|
| 483 |
-
" </tr>\n",
|
| 484 |
-
" </tbody>\n",
|
| 485 |
-
"</table>\n",
|
| 486 |
-
"</div>"
|
| 487 |
-
],
|
| 488 |
-
"text/plain": [
|
| 489 |
-
" momentum learning rate batch size ... recall3 epoch_stopped reg_penalty\n",
|
| 490 |
-
"0 0.9 0.1 64.0 ... 0.506039 50.0 0.00\n",
|
| 491 |
-
"1 0.9 0.1 64.0 ... 0.584749 100.0 0.00\n",
|
| 492 |
-
"2 0.9 0.1 64.0 ... 0.532126 150.0 0.00\n",
|
| 493 |
-
"3 0.9 0.1 64.0 ... 0.549730 50.0 0.01\n",
|
| 494 |
-
"4 0.9 0.1 64.0 ... 0.547532 100.0 0.01\n",
|
| 495 |
-
"\n",
|
| 496 |
-
"[5 rows x 18 columns]"
|
| 497 |
-
]
|
| 498 |
-
},
|
| 499 |
-
"metadata": {
|
| 500 |
-
"tags": []
|
| 501 |
-
},
|
| 502 |
-
"execution_count": 17
|
| 503 |
-
}
|
| 504 |
-
]
|
| 505 |
-
},
|
| 506 |
-
{
|
| 507 |
-
"cell_type": "code",
|
| 508 |
-
"metadata": {
|
| 509 |
-
"id": "LeVPJJQQt36n"
|
| 510 |
-
},
|
| 511 |
-
"source": [
|
| 512 |
-
"data[\"loss_mean\"] = (data[\"loss1\"]+data[\"loss2\"]+data[\"loss3\"])/3\n",
|
| 513 |
-
"data[\"acc_mean\"] = (data[\"acc1\"]+data[\"acc2\"]+data[\"acc3\"])/3"
|
| 514 |
-
],
|
| 515 |
-
"execution_count": null,
|
| 516 |
-
"outputs": []
|
| 517 |
-
},
|
| 518 |
-
{
|
| 519 |
-
"cell_type": "code",
|
| 520 |
-
"metadata": {
|
| 521 |
-
"id": "XPIKRbPMuUOr"
|
| 522 |
-
},
|
| 523 |
-
"source": [
|
| 524 |
-
"data['epoch'] = data['epoch_stopped']\n",
|
| 525 |
-
"data['weight_decay'] = data['reg_penalty']"
|
| 526 |
-
],
|
| 527 |
-
"execution_count": null,
|
| 528 |
-
"outputs": []
|
| 529 |
-
},
|
| 530 |
-
{
|
| 531 |
-
"cell_type": "code",
|
| 532 |
-
"metadata": {
|
| 533 |
-
"id": "5_2tMeO6x9Uq"
|
| 534 |
-
},
|
| 535 |
-
"source": [
|
| 536 |
-
"data['recall_mean'] = (data['recall1']+data['recall2']+data['recall3'])/3\n",
|
| 537 |
-
"data['prec_mean'] = (data['prec1']+data['prec2']+data['prec3'])/3"
|
| 538 |
-
],
|
| 539 |
-
"execution_count": null,
|
| 540 |
-
"outputs": []
|
| 541 |
-
},
|
| 542 |
-
{
|
| 543 |
-
"cell_type": "code",
|
| 544 |
-
"metadata": {
|
| 545 |
-
"id": "JA0eRctYuWl-"
|
| 546 |
-
},
|
| 547 |
-
"source": [
|
| 548 |
-
"column_list = [\"momentum\", \"learning rate\", \"epoch\",\"batch size\",\"weight_decay\",\"loss_mean\", \"acc_mean\",\"recall_mean\", \"prec_mean\"]"
|
| 549 |
-
],
|
| 550 |
-
"execution_count": null,
|
| 551 |
-
"outputs": []
|
| 552 |
-
},
|
| 553 |
-
{
|
| 554 |
-
"cell_type": "code",
|
| 555 |
-
"metadata": {
|
| 556 |
-
"colab": {
|
| 557 |
-
"base_uri": "https://localhost:8080/",
|
| 558 |
-
"height": 143
|
| 559 |
-
},
|
| 560 |
-
"id": "McsMaiLZuZaa",
|
| 561 |
-
"outputId": "d04090a8-b295-45ad-d64f-e5965741cd80"
|
| 562 |
-
},
|
| 563 |
-
"source": [
|
| 564 |
-
"data.sort_values(axis=0, by=\"loss_mean\", ascending=True)[column_list].head(3)"
|
| 565 |
-
],
|
| 566 |
-
"execution_count": null,
|
| 567 |
-
"outputs": [
|
| 568 |
-
{
|
| 569 |
-
"output_type": "execute_result",
|
| 570 |
-
"data": {
|
| 571 |
-
"text/html": [
|
| 572 |
-
"<div>\n",
|
| 573 |
-
"<style scoped>\n",
|
| 574 |
-
" .dataframe tbody tr th:only-of-type {\n",
|
| 575 |
-
" vertical-align: middle;\n",
|
| 576 |
-
" }\n",
|
| 577 |
-
"\n",
|
| 578 |
-
" .dataframe tbody tr th {\n",
|
| 579 |
-
" vertical-align: top;\n",
|
| 580 |
-
" }\n",
|
| 581 |
-
"\n",
|
| 582 |
-
" .dataframe thead th {\n",
|
| 583 |
-
" text-align: right;\n",
|
| 584 |
-
" }\n",
|
| 585 |
-
"</style>\n",
|
| 586 |
-
"<table border=\"1\" class=\"dataframe\">\n",
|
| 587 |
-
" <thead>\n",
|
| 588 |
-
" <tr style=\"text-align: right;\">\n",
|
| 589 |
-
" <th></th>\n",
|
| 590 |
-
" <th>momentum</th>\n",
|
| 591 |
-
" <th>learning rate</th>\n",
|
| 592 |
-
" <th>epoch</th>\n",
|
| 593 |
-
" <th>batch size</th>\n",
|
| 594 |
-
" <th>weight_decay</th>\n",
|
| 595 |
-
" <th>loss_mean</th>\n",
|
| 596 |
-
" <th>acc_mean</th>\n",
|
| 597 |
-
" <th>recall_mean</th>\n",
|
| 598 |
-
" <th>prec_mean</th>\n",
|
| 599 |
-
" </tr>\n",
|
| 600 |
-
" </thead>\n",
|
| 601 |
-
" <tbody>\n",
|
| 602 |
-
" <tr>\n",
|
| 603 |
-
" <th>24</th>\n",
|
| 604 |
-
" <td>0.9</td>\n",
|
| 605 |
-
" <td>0.01</td>\n",
|
| 606 |
-
" <td>50.0</td>\n",
|
| 607 |
-
" <td>64.0</td>\n",
|
| 608 |
-
" <td>0.0000</td>\n",
|
| 609 |
-
" <td>1.012071</td>\n",
|
| 610 |
-
" <td>0.612565</td>\n",
|
| 611 |
-
" <td>0.614360</td>\n",
|
| 612 |
-
" <td>0.615845</td>\n",
|
| 613 |
-
" </tr>\n",
|
| 614 |
-
" <tr>\n",
|
| 615 |
-
" <th>34</th>\n",
|
| 616 |
-
" <td>0.9</td>\n",
|
| 617 |
-
" <td>0.01</td>\n",
|
| 618 |
-
" <td>100.0</td>\n",
|
| 619 |
-
" <td>64.0</td>\n",
|
| 620 |
-
" <td>0.0001</td>\n",
|
| 621 |
-
" <td>1.026681</td>\n",
|
| 622 |
-
" <td>0.614311</td>\n",
|
| 623 |
-
" <td>0.612435</td>\n",
|
| 624 |
-
" <td>0.619908</td>\n",
|
| 625 |
-
" </tr>\n",
|
| 626 |
-
" <tr>\n",
|
| 627 |
-
" <th>8</th>\n",
|
| 628 |
-
" <td>0.9</td>\n",
|
| 629 |
-
" <td>0.10</td>\n",
|
| 630 |
-
" <td>150.0</td>\n",
|
| 631 |
-
" <td>64.0</td>\n",
|
| 632 |
-
" <td>0.0010</td>\n",
|
| 633 |
-
" <td>1.027309</td>\n",
|
| 634 |
-
" <td>0.603258</td>\n",
|
| 635 |
-
" <td>0.601262</td>\n",
|
| 636 |
-
" <td>0.606730</td>\n",
|
| 637 |
-
" </tr>\n",
|
| 638 |
-
" </tbody>\n",
|
| 639 |
-
"</table>\n",
|
| 640 |
-
"</div>"
|
| 641 |
-
],
|
| 642 |
-
"text/plain": [
|
| 643 |
-
" momentum learning rate epoch ... acc_mean recall_mean prec_mean\n",
|
| 644 |
-
"24 0.9 0.01 50.0 ... 0.612565 0.614360 0.615845\n",
|
| 645 |
-
"34 0.9 0.01 100.0 ... 0.614311 0.612435 0.619908\n",
|
| 646 |
-
"8 0.9 0.10 150.0 ... 0.603258 0.601262 0.606730\n",
|
| 647 |
-
"\n",
|
| 648 |
-
"[3 rows x 9 columns]"
|
| 649 |
-
]
|
| 650 |
-
},
|
| 651 |
-
"metadata": {
|
| 652 |
-
"tags": []
|
| 653 |
-
},
|
| 654 |
-
"execution_count": 26
|
| 655 |
-
}
|
| 656 |
-
]
|
| 657 |
-
},
|
| 658 |
-
{
|
| 659 |
-
"cell_type": "code",
|
| 660 |
-
"metadata": {
|
| 661 |
-
"id": "hCgRAh43udBu"
|
| 662 |
-
},
|
| 663 |
-
"source": [
|
| 664 |
-
"data[\"loss_na\"] = data.loc[:,[\"loss1\",\"loss2\", \"loss3\"]].isnull().sum(1)"
|
| 665 |
-
],
|
| 666 |
-
"execution_count": null,
|
| 667 |
-
"outputs": []
|
| 668 |
-
},
|
| 669 |
-
{
|
| 670 |
-
"cell_type": "code",
|
| 671 |
-
"metadata": {
|
| 672 |
-
"colab": {
|
| 673 |
-
"base_uri": "https://localhost:8080/",
|
| 674 |
-
"height": 181
|
| 675 |
-
},
|
| 676 |
-
"id": "-ptzEmSHudkM",
|
| 677 |
-
"outputId": "72a8ca57-603a-451f-8702-962b4be4f91a"
|
| 678 |
-
},
|
| 679 |
-
"source": [
|
| 680 |
-
"data.head(3)"
|
| 681 |
-
],
|
| 682 |
-
"execution_count": null,
|
| 683 |
-
"outputs": [
|
| 684 |
-
{
|
| 685 |
-
"output_type": "execute_result",
|
| 686 |
-
"data": {
|
| 687 |
-
"text/html": [
|
| 688 |
-
"<div>\n",
|
| 689 |
-
"<style scoped>\n",
|
| 690 |
-
" .dataframe tbody tr th:only-of-type {\n",
|
| 691 |
-
" vertical-align: middle;\n",
|
| 692 |
-
" }\n",
|
| 693 |
-
"\n",
|
| 694 |
-
" .dataframe tbody tr th {\n",
|
| 695 |
-
" vertical-align: top;\n",
|
| 696 |
-
" }\n",
|
| 697 |
-
"\n",
|
| 698 |
-
" .dataframe thead th {\n",
|
| 699 |
-
" text-align: right;\n",
|
| 700 |
-
" }\n",
|
| 701 |
-
"</style>\n",
|
| 702 |
-
"<table border=\"1\" class=\"dataframe\">\n",
|
| 703 |
-
" <thead>\n",
|
| 704 |
-
" <tr style=\"text-align: right;\">\n",
|
| 705 |
-
" <th></th>\n",
|
| 706 |
-
" <th>momentum</th>\n",
|
| 707 |
-
" <th>learning rate</th>\n",
|
| 708 |
-
" <th>batch size</th>\n",
|
| 709 |
-
" <th>loss1</th>\n",
|
| 710 |
-
" <th>acc1</th>\n",
|
| 711 |
-
" <th>loss2</th>\n",
|
| 712 |
-
" <th>acc2</th>\n",
|
| 713 |
-
" <th>loss3</th>\n",
|
| 714 |
-
" <th>acc3</th>\n",
|
| 715 |
-
" <th>widing factor</th>\n",
|
| 716 |
-
" <th>prec1</th>\n",
|
| 717 |
-
" <th>prec2</th>\n",
|
| 718 |
-
" <th>prec3</th>\n",
|
| 719 |
-
" <th>recall1</th>\n",
|
| 720 |
-
" <th>recall2</th>\n",
|
| 721 |
-
" <th>recall3</th>\n",
|
| 722 |
-
" <th>epoch_stopped</th>\n",
|
| 723 |
-
" <th>reg_penalty</th>\n",
|
| 724 |
-
" <th>loss_mean</th>\n",
|
| 725 |
-
" <th>acc_mean</th>\n",
|
| 726 |
-
" <th>epoch</th>\n",
|
| 727 |
-
" <th>weight_decay</th>\n",
|
| 728 |
-
" <th>loss_na</th>\n",
|
| 729 |
-
" </tr>\n",
|
| 730 |
-
" </thead>\n",
|
| 731 |
-
" <tbody>\n",
|
| 732 |
-
" <tr>\n",
|
| 733 |
-
" <th>0</th>\n",
|
| 734 |
-
" <td>0.9</td>\n",
|
| 735 |
-
" <td>0.1</td>\n",
|
| 736 |
-
" <td>64.0</td>\n",
|
| 737 |
-
" <td>1.078664</td>\n",
|
| 738 |
-
" <td>0.541012</td>\n",
|
| 739 |
-
" <td>1.087765</td>\n",
|
| 740 |
-
" <td>0.560209</td>\n",
|
| 741 |
-
" <td>1.140540</td>\n",
|
| 742 |
-
" <td>0.495637</td>\n",
|
| 743 |
-
" <td>1.0</td>\n",
|
| 744 |
-
" <td>0.545625</td>\n",
|
| 745 |
-
" <td>0.570303</td>\n",
|
| 746 |
-
" <td>0.540447</td>\n",
|
| 747 |
-
" <td>0.549296</td>\n",
|
| 748 |
-
" <td>0.561227</td>\n",
|
| 749 |
-
" <td>0.506039</td>\n",
|
| 750 |
-
" <td>50.0</td>\n",
|
| 751 |
-
" <td>0.0</td>\n",
|
| 752 |
-
" <td>1.102323</td>\n",
|
| 753 |
-
" <td>0.532286</td>\n",
|
| 754 |
-
" <td>50.0</td>\n",
|
| 755 |
-
" <td>0.0</td>\n",
|
| 756 |
-
" <td>0</td>\n",
|
| 757 |
-
" </tr>\n",
|
| 758 |
-
" <tr>\n",
|
| 759 |
-
" <th>1</th>\n",
|
| 760 |
-
" <td>0.9</td>\n",
|
| 761 |
-
" <td>0.1</td>\n",
|
| 762 |
-
" <td>64.0</td>\n",
|
| 763 |
-
" <td>1.090459</td>\n",
|
| 764 |
-
" <td>0.544503</td>\n",
|
| 765 |
-
" <td>1.052103</td>\n",
|
| 766 |
-
" <td>0.568935</td>\n",
|
| 767 |
-
" <td>0.991915</td>\n",
|
| 768 |
-
" <td>0.586387</td>\n",
|
| 769 |
-
" <td>1.0</td>\n",
|
| 770 |
-
" <td>0.567745</td>\n",
|
| 771 |
-
" <td>0.581882</td>\n",
|
| 772 |
-
" <td>0.587054</td>\n",
|
| 773 |
-
" <td>0.549632</td>\n",
|
| 774 |
-
" <td>0.571559</td>\n",
|
| 775 |
-
" <td>0.584749</td>\n",
|
| 776 |
-
" <td>100.0</td>\n",
|
| 777 |
-
" <td>0.0</td>\n",
|
| 778 |
-
" <td>1.044826</td>\n",
|
| 779 |
-
" <td>0.566609</td>\n",
|
| 780 |
-
" <td>100.0</td>\n",
|
| 781 |
-
" <td>0.0</td>\n",
|
| 782 |
-
" <td>0</td>\n",
|
| 783 |
-
" </tr>\n",
|
| 784 |
-
" <tr>\n",
|
| 785 |
-
" <th>2</th>\n",
|
| 786 |
-
" <td>0.9</td>\n",
|
| 787 |
-
" <td>0.1</td>\n",
|
| 788 |
-
" <td>64.0</td>\n",
|
| 789 |
-
" <td>1.203270</td>\n",
|
| 790 |
-
" <td>0.549738</td>\n",
|
| 791 |
-
" <td>1.024475</td>\n",
|
| 792 |
-
" <td>0.607330</td>\n",
|
| 793 |
-
" <td>1.099551</td>\n",
|
| 794 |
-
" <td>0.542757</td>\n",
|
| 795 |
-
" <td>1.0</td>\n",
|
| 796 |
-
" <td>0.597314</td>\n",
|
| 797 |
-
" <td>0.629707</td>\n",
|
| 798 |
-
" <td>0.583353</td>\n",
|
| 799 |
-
" <td>0.558664</td>\n",
|
| 800 |
-
" <td>0.612029</td>\n",
|
| 801 |
-
" <td>0.532126</td>\n",
|
| 802 |
-
" <td>150.0</td>\n",
|
| 803 |
-
" <td>0.0</td>\n",
|
| 804 |
-
" <td>1.109098</td>\n",
|
| 805 |
-
" <td>0.566608</td>\n",
|
| 806 |
-
" <td>150.0</td>\n",
|
| 807 |
-
" <td>0.0</td>\n",
|
| 808 |
-
" <td>0</td>\n",
|
| 809 |
-
" </tr>\n",
|
| 810 |
-
" </tbody>\n",
|
| 811 |
-
"</table>\n",
|
| 812 |
-
"</div>"
|
| 813 |
-
],
|
| 814 |
-
"text/plain": [
|
| 815 |
-
" momentum learning rate batch size ... epoch weight_decay loss_na\n",
|
| 816 |
-
"0 0.9 0.1 64.0 ... 50.0 0.0 0\n",
|
| 817 |
-
"1 0.9 0.1 64.0 ... 100.0 0.0 0\n",
|
| 818 |
-
"2 0.9 0.1 64.0 ... 150.0 0.0 0\n",
|
| 819 |
-
"\n",
|
| 820 |
-
"[3 rows x 23 columns]"
|
| 821 |
-
]
|
| 822 |
-
},
|
| 823 |
-
"metadata": {
|
| 824 |
-
"tags": []
|
| 825 |
-
},
|
| 826 |
-
"execution_count": 23
|
| 827 |
-
}
|
| 828 |
-
]
|
| 829 |
-
},
|
| 830 |
-
{
|
| 831 |
-
"cell_type": "code",
|
| 832 |
-
"metadata": {
|
| 833 |
-
"id": "YBBXsz_rzQl-"
|
| 834 |
-
},
|
| 835 |
-
"source": [
|
| 836 |
-
"generator = tensorflow.keras.preprocessing.image.ImageDataGenerator(rotation_range=10,\n",
|
| 837 |
-
" width_shift_range=5./32,\n",
|
| 838 |
-
" height_shift_range=5./32,)"
|
| 839 |
-
],
|
| 840 |
-
"execution_count": null,
|
| 841 |
-
"outputs": []
|
| 842 |
-
},
|
| 843 |
-
{
|
| 844 |
-
"cell_type": "code",
|
| 845 |
-
"metadata": {
|
| 846 |
-
"id": "ShIJn_53mawD"
|
| 847 |
-
},
|
| 848 |
-
"source": [
|
| 849 |
-
"kf = KFold(n_splits=3, random_state=42, shuffle=False)\n",
|
| 850 |
-
"result = []\n",
|
| 851 |
-
"for j, (train_index, test_index) in enumerate(kf.split(X_train)):\n",
|
| 852 |
-
" x_train, x_val = X_train[train_index], X_train[test_index]\n",
|
| 853 |
-
" y_train, y_val = Y_train[train_index], Y_train[test_index]\n",
|
| 854 |
-
" model = model_2(0)\n",
|
| 855 |
-
" opt = tensorflow.keras.optimizers.SGD(learning_rate=0.01)\n",
|
| 856 |
-
" model.compile(loss='categorical_crossentropy', optimizer=opt, metrics= ['accuracy'])\n",
|
| 857 |
-
" hist = model.fit(generator.flow(x_train, y_train, batch_size=64), steps_per_epoch=len(x_train) //64 , epochs=50,\n",
|
| 858 |
-
" validation_data=(x_val, y_val),\n",
|
| 859 |
-
" validation_steps=len(x_val) //64 ,)\n",
|
| 860 |
-
"\n",
|
| 861 |
-
" test = model.evaluate(X_test, y_test)\n",
|
| 862 |
-
" result.append(test)"
|
| 863 |
-
],
|
| 864 |
-
"execution_count": null,
|
| 865 |
-
"outputs": []
|
| 866 |
-
},
|
| 867 |
-
{
|
| 868 |
-
"cell_type": "code",
|
| 869 |
-
"metadata": {
|
| 870 |
-
"colab": {
|
| 871 |
-
"base_uri": "https://localhost:8080/"
|
| 872 |
-
},
|
| 873 |
-
"id": "Rj6hnUVRzjX-",
|
| 874 |
-
"outputId": "41b04e1e-d756-437a-afff-ca9f3fcff46c"
|
| 875 |
-
},
|
| 876 |
-
"source": [
|
| 877 |
-
"mean_acc = (result[0][1]+result[1][1]+result[2][1])/3;mean_acc"
|
| 878 |
-
],
|
| 879 |
-
"execution_count": null,
|
| 880 |
-
"outputs": [
|
| 881 |
-
{
|
| 882 |
-
"output_type": "execute_result",
|
| 883 |
-
"data": {
|
| 884 |
-
"text/plain": [
|
| 885 |
-
"0.6305991808573405"
|
| 886 |
-
]
|
| 887 |
-
},
|
| 888 |
-
"metadata": {
|
| 889 |
-
"tags": []
|
| 890 |
-
},
|
| 891 |
-
"execution_count": 47
|
| 892 |
-
}
|
| 893 |
-
]
|
| 894 |
-
},
|
| 895 |
-
{
|
| 896 |
-
"cell_type": "code",
|
| 897 |
-
"metadata": {
|
| 898 |
-
"colab": {
|
| 899 |
-
"base_uri": "https://localhost:8080/"
|
| 900 |
-
},
|
| 901 |
-
"id": "v_AccCit1De3",
|
| 902 |
-
"outputId": "a6299f58-4389-4f39-f318-e2371e94f4e5"
|
| 903 |
-
},
|
| 904 |
-
"source": [
|
| 905 |
-
"mean_loss = (result[0][0]+result[1][0]+result[2][0])/3;mean_loss"
|
| 906 |
-
],
|
| 907 |
-
"execution_count": null,
|
| 908 |
-
"outputs": [
|
| 909 |
-
{
|
| 910 |
-
"output_type": "execute_result",
|
| 911 |
-
"data": {
|
| 912 |
-
"text/plain": [
|
| 913 |
-
"0.9891296029090881"
|
| 914 |
-
]
|
| 915 |
-
},
|
| 916 |
-
"metadata": {
|
| 917 |
-
"tags": []
|
| 918 |
-
},
|
| 919 |
-
"execution_count": 48
|
| 920 |
-
}
|
| 921 |
-
]
|
| 922 |
-
},
|
| 923 |
-
{
|
| 924 |
-
"cell_type": "code",
|
| 925 |
-
"metadata": {
|
| 926 |
-
"id": "8gdDigrPzYM5"
|
| 927 |
-
},
|
| 928 |
-
"source": [
|
| 929 |
-
"kf = KFold(n_splits=3, random_state=42, shuffle=False)\n",
|
| 930 |
-
"result_2 = []\n",
|
| 931 |
-
"for j, (train_index, test_index) in enumerate(kf.split(X_train)):\n",
|
| 932 |
-
" x_train, x_val = X_train[train_index], X_train[test_index]\n",
|
| 933 |
-
" y_train, y_val = Y_train[train_index], Y_train[test_index]\n",
|
| 934 |
-
" model = model_3(0.0001)\n",
|
| 935 |
-
" opt = tensorflow.keras.optimizers.SGD(learning_rate=0.01)\n",
|
| 936 |
-
" model.compile(loss='categorical_crossentropy', optimizer=opt, metrics= ['accuracy'])\n",
|
| 937 |
-
" hist = model.fit(generator.flow(x_train, y_train, batch_size=64), steps_per_epoch=len(x_train) //64 , epochs=100,\n",
|
| 938 |
-
" validation_data=(x_val, y_val),\n",
|
| 939 |
-
" validation_steps=len(x_val) //64 ,)\n",
|
| 940 |
-
"\n",
|
| 941 |
-
" test = model.evaluate(X_test, y_test)\n",
|
| 942 |
-
" result_2.append(test)"
|
| 943 |
-
],
|
| 944 |
-
"execution_count": null,
|
| 945 |
-
"outputs": []
|
| 946 |
-
},
|
| 947 |
-
{
|
| 948 |
-
"cell_type": "code",
|
| 949 |
-
"metadata": {
|
| 950 |
-
"colab": {
|
| 951 |
-
"base_uri": "https://localhost:8080/"
|
| 952 |
-
},
|
| 953 |
-
"id": "FYVQINoA2F4l",
|
| 954 |
-
"outputId": "ac29f686-9b48-4193-bf01-4a17ab589d91"
|
| 955 |
-
},
|
| 956 |
-
"source": [
|
| 957 |
-
"mean_acc = (result_2[0][1]+result_2[1][1]+result_2[2][1])/3;mean_acc"
|
| 958 |
-
],
|
| 959 |
-
"execution_count": null,
|
| 960 |
-
"outputs": [
|
| 961 |
-
{
|
| 962 |
-
"output_type": "execute_result",
|
| 963 |
-
"data": {
|
| 964 |
-
"text/plain": [
|
| 965 |
-
"0.7108784119288126"
|
| 966 |
-
]
|
| 967 |
-
},
|
| 968 |
-
"metadata": {
|
| 969 |
-
"tags": []
|
| 970 |
-
},
|
| 971 |
-
"execution_count": 56
|
| 972 |
-
}
|
| 973 |
-
]
|
| 974 |
-
},
|
| 975 |
-
{
|
| 976 |
-
"cell_type": "code",
|
| 977 |
-
"metadata": {
|
| 978 |
-
"colab": {
|
| 979 |
-
"base_uri": "https://localhost:8080/"
|
| 980 |
-
},
|
| 981 |
-
"id": "guueg5Wo2IE8",
|
| 982 |
-
"outputId": "986e139c-1224-43a7-b8fc-59d57e72d3a8"
|
| 983 |
-
},
|
| 984 |
-
"source": [
|
| 985 |
-
"mean_loss = (result_2[0][0]+result_2[1][0]+result_2[2][0])/3;mean_loss"
|
| 986 |
-
],
|
| 987 |
-
"execution_count": null,
|
| 988 |
-
"outputs": [
|
| 989 |
-
{
|
| 990 |
-
"output_type": "execute_result",
|
| 991 |
-
"data": {
|
| 992 |
-
"text/plain": [
|
| 993 |
-
"0.9208946625391642"
|
| 994 |
-
]
|
| 995 |
-
},
|
| 996 |
-
"metadata": {
|
| 997 |
-
"tags": []
|
| 998 |
-
},
|
| 999 |
-
"execution_count": 57
|
| 1000 |
-
}
|
| 1001 |
-
]
|
| 1002 |
-
},
|
| 1003 |
-
{
|
| 1004 |
-
"cell_type": "code",
|
| 1005 |
-
"metadata": {
|
| 1006 |
-
"colab": {
|
| 1007 |
-
"base_uri": "https://localhost:8080/"
|
| 1008 |
-
},
|
| 1009 |
-
"id": "luS42Hil0H5R",
|
| 1010 |
-
"outputId": "91c28154-8014-42a9-fbf2-8935cb4c521a"
|
| 1011 |
-
},
|
| 1012 |
-
"source": [
|
| 1013 |
-
"mean_acc = (result_2[0][1]+result_2[1][1]+result_2[2][1])/3;mean_acc"
|
| 1014 |
-
],
|
| 1015 |
-
"execution_count": null,
|
| 1016 |
-
"outputs": [
|
| 1017 |
-
{
|
| 1018 |
-
"output_type": "execute_result",
|
| 1019 |
-
"data": {
|
| 1020 |
-
"text/plain": [
|
| 1021 |
-
"0.6783013343811035"
|
| 1022 |
-
]
|
| 1023 |
-
},
|
| 1024 |
-
"metadata": {
|
| 1025 |
-
"tags": []
|
| 1026 |
-
},
|
| 1027 |
-
"execution_count": 50
|
| 1028 |
-
}
|
| 1029 |
-
]
|
| 1030 |
-
},
|
| 1031 |
-
{
|
| 1032 |
-
"cell_type": "code",
|
| 1033 |
-
"metadata": {
|
| 1034 |
-
"colab": {
|
| 1035 |
-
"base_uri": "https://localhost:8080/"
|
| 1036 |
-
},
|
| 1037 |
-
"id": "F9Zg3pI61W_J",
|
| 1038 |
-
"outputId": "fe5f7861-58e9-47f6-8134-8872e494e268"
|
| 1039 |
-
},
|
| 1040 |
-
"source": [
|
| 1041 |
-
"mean_loss = (result_2[0][0]+result_2[1][0]+result_2[2][0])/3;mean_loss"
|
| 1042 |
-
],
|
| 1043 |
-
"execution_count": null,
|
| 1044 |
-
"outputs": [
|
| 1045 |
-
{
|
| 1046 |
-
"output_type": "execute_result",
|
| 1047 |
-
"data": {
|
| 1048 |
-
"text/plain": [
|
| 1049 |
-
"0.8747362097104391"
|
| 1050 |
-
]
|
| 1051 |
-
},
|
| 1052 |
-
"metadata": {
|
| 1053 |
-
"tags": []
|
| 1054 |
-
},
|
| 1055 |
-
"execution_count": 51
|
| 1056 |
-
}
|
| 1057 |
-
]
|
| 1058 |
-
},
|
| 1059 |
-
{
|
| 1060 |
-
"cell_type": "code",
|
| 1061 |
-
"metadata": {
|
| 1062 |
-
"colab": {
|
| 1063 |
-
"base_uri": "https://localhost:8080/"
|
| 1064 |
-
},
|
| 1065 |
-
"id": "OWx9Jej11f_V",
|
| 1066 |
-
"outputId": "a5b72e6c-de64-4bee-da66-d48a5e6e1870"
|
| 1067 |
-
},
|
| 1068 |
-
"source": [
|
| 1069 |
-
"model = model_1(0)\n",
|
| 1070 |
-
"opt = tensorflow.keras.optimizers.SGD(learning_rate=0.01)\n",
|
| 1071 |
-
"model.compile(loss='categorical_crossentropy', optimizer=opt, metrics= ['accuracy'])\n",
|
| 1072 |
-
"model.summary()"
|
| 1073 |
-
],
|
| 1074 |
-
"execution_count": null,
|
| 1075 |
-
"outputs": [
|
| 1076 |
-
{
|
| 1077 |
-
"output_type": "stream",
|
| 1078 |
-
"text": [
|
| 1079 |
-
"Model: \"sequential_162\"\n",
|
| 1080 |
-
"_________________________________________________________________\n",
|
| 1081 |
-
"Layer (type) Output Shape Param # \n",
|
| 1082 |
-
"=================================================================\n",
|
| 1083 |
-
"conv2d_342 (Conv2D) (None, 30, 30, 32) 320 \n",
|
| 1084 |
-
"_________________________________________________________________\n",
|
| 1085 |
-
"conv2d_343 (Conv2D) (None, 28, 28, 64) 18496 \n",
|
| 1086 |
-
"_________________________________________________________________\n",
|
| 1087 |
-
"max_pooling2d_183 (MaxPoolin (None, 14, 14, 64) 0 \n",
|
| 1088 |
-
"_________________________________________________________________\n",
|
| 1089 |
-
"batch_normalization_183 (Bat (None, 14, 14, 64) 256 \n",
|
| 1090 |
-
"_________________________________________________________________\n",
|
| 1091 |
-
"flatten_162 (Flatten) (None, 12544) 0 \n",
|
| 1092 |
-
"_________________________________________________________________\n",
|
| 1093 |
-
"dense_162 (Dense) (None, 4) 50180 \n",
|
| 1094 |
-
"=================================================================\n",
|
| 1095 |
-
"Total params: 69,252\n",
|
| 1096 |
-
"Trainable params: 69,124\n",
|
| 1097 |
-
"Non-trainable params: 128\n",
|
| 1098 |
-
"_________________________________________________________________\n"
|
| 1099 |
-
],
|
| 1100 |
-
"name": "stdout"
|
| 1101 |
-
}
|
| 1102 |
-
]
|
| 1103 |
-
},
|
| 1104 |
-
{
|
| 1105 |
-
"cell_type": "code",
|
| 1106 |
-
"metadata": {
|
| 1107 |
-
"colab": {
|
| 1108 |
-
"base_uri": "https://localhost:8080/"
|
| 1109 |
-
},
|
| 1110 |
-
"id": "l8mg2Bup1lPJ",
|
| 1111 |
-
"outputId": "6a2348f9-f6d8-47d3-a5d7-9169d3906d76"
|
| 1112 |
-
},
|
| 1113 |
-
"source": [
|
| 1114 |
-
"model = model_2(0)\n",
|
| 1115 |
-
"opt = tensorflow.keras.optimizers.SGD(learning_rate=0.01)\n",
|
| 1116 |
-
"model.compile(loss='categorical_crossentropy', optimizer=opt, metrics= ['accuracy'])\n",
|
| 1117 |
-
"model.summary()"
|
| 1118 |
-
],
|
| 1119 |
-
"execution_count": null,
|
| 1120 |
-
"outputs": [
|
| 1121 |
-
{
|
| 1122 |
-
"output_type": "stream",
|
| 1123 |
-
"text": [
|
| 1124 |
-
"Model: \"sequential_154\"\n",
|
| 1125 |
-
"_________________________________________________________________\n",
|
| 1126 |
-
"Layer (type) Output Shape Param # \n",
|
| 1127 |
-
"=================================================================\n",
|
| 1128 |
-
"conv2d_320 (Conv2D) (None, 30, 30, 32) 320 \n",
|
| 1129 |
-
"_________________________________________________________________\n",
|
| 1130 |
-
"conv2d_321 (Conv2D) (None, 28, 28, 64) 18496 \n",
|
| 1131 |
-
"_________________________________________________________________\n",
|
| 1132 |
-
"max_pooling2d_166 (MaxPoolin (None, 14, 14, 64) 0 \n",
|
| 1133 |
-
"_________________________________________________________________\n",
|
| 1134 |
-
"batch_normalization_166 (Bat (None, 14, 14, 64) 256 \n",
|
| 1135 |
-
"_________________________________________________________________\n",
|
| 1136 |
-
"conv2d_322 (Conv2D) (None, 12, 12, 128) 73856 \n",
|
| 1137 |
-
"_________________________________________________________________\n",
|
| 1138 |
-
"max_pooling2d_167 (MaxPoolin (None, 6, 6, 128) 0 \n",
|
| 1139 |
-
"_________________________________________________________________\n",
|
| 1140 |
-
"batch_normalization_167 (Bat (None, 6, 6, 128) 512 \n",
|
| 1141 |
-
"_________________________________________________________________\n",
|
| 1142 |
-
"flatten_154 (Flatten) (None, 4608) 0 \n",
|
| 1143 |
-
"_________________________________________________________________\n",
|
| 1144 |
-
"dense_154 (Dense) (None, 4) 18436 \n",
|
| 1145 |
-
"=================================================================\n",
|
| 1146 |
-
"Total params: 111,876\n",
|
| 1147 |
-
"Trainable params: 111,492\n",
|
| 1148 |
-
"Non-trainable params: 384\n",
|
| 1149 |
-
"_________________________________________________________________\n"
|
| 1150 |
-
],
|
| 1151 |
-
"name": "stdout"
|
| 1152 |
-
}
|
| 1153 |
-
]
|
| 1154 |
-
},
|
| 1155 |
-
{
|
| 1156 |
-
"cell_type": "code",
|
| 1157 |
-
"metadata": {
|
| 1158 |
-
"colab": {
|
| 1159 |
-
"base_uri": "https://localhost:8080/"
|
| 1160 |
-
},
|
| 1161 |
-
"id": "vZImk4zB1mxk",
|
| 1162 |
-
"outputId": "805f844e-1312-423a-b309-35944f11a210"
|
| 1163 |
-
},
|
| 1164 |
-
"source": [
|
| 1165 |
-
"model = model_3(0)\n",
|
| 1166 |
-
"opt = tensorflow.keras.optimizers.SGD(learning_rate=0.01)\n",
|
| 1167 |
-
"model.compile(loss='categorical_crossentropy', optimizer=opt, metrics= ['accuracy'])\n",
|
| 1168 |
-
"model.summary()"
|
| 1169 |
-
],
|
| 1170 |
-
"execution_count": null,
|
| 1171 |
-
"outputs": [
|
| 1172 |
-
{
|
| 1173 |
-
"output_type": "stream",
|
| 1174 |
-
"text": [
|
| 1175 |
-
"Model: \"sequential_155\"\n",
|
| 1176 |
-
"_________________________________________________________________\n",
|
| 1177 |
-
"Layer (type) Output Shape Param # \n",
|
| 1178 |
-
"=================================================================\n",
|
| 1179 |
-
"conv2d_323 (Conv2D) (None, 30, 30, 32) 320 \n",
|
| 1180 |
-
"_________________________________________________________________\n",
|
| 1181 |
-
"conv2d_324 (Conv2D) (None, 28, 28, 64) 18496 \n",
|
| 1182 |
-
"_________________________________________________________________\n",
|
| 1183 |
-
"max_pooling2d_168 (MaxPoolin (None, 14, 14, 64) 0 \n",
|
| 1184 |
-
"_________________________________________________________________\n",
|
| 1185 |
-
"batch_normalization_168 (Bat (None, 14, 14, 64) 256 \n",
|
| 1186 |
-
"_________________________________________________________________\n",
|
| 1187 |
-
"conv2d_325 (Conv2D) (None, 12, 12, 128) 73856 \n",
|
| 1188 |
-
"_________________________________________________________________\n",
|
| 1189 |
-
"max_pooling2d_169 (MaxPoolin (None, 6, 6, 128) 0 \n",
|
| 1190 |
-
"_________________________________________________________________\n",
|
| 1191 |
-
"batch_normalization_169 (Bat (None, 6, 6, 128) 512 \n",
|
| 1192 |
-
"_________________________________________________________________\n",
|
| 1193 |
-
"conv2d_326 (Conv2D) (None, 4, 4, 256) 295168 \n",
|
| 1194 |
-
"_________________________________________________________________\n",
|
| 1195 |
-
"max_pooling2d_170 (MaxPoolin (None, 2, 2, 256) 0 \n",
|
| 1196 |
-
"_________________________________________________________________\n",
|
| 1197 |
-
"batch_normalization_170 (Bat (None, 2, 2, 256) 1024 \n",
|
| 1198 |
-
"_________________________________________________________________\n",
|
| 1199 |
-
"flatten_155 (Flatten) (None, 1024) 0 \n",
|
| 1200 |
-
"_________________________________________________________________\n",
|
| 1201 |
-
"dense_155 (Dense) (None, 4) 4100 \n",
|
| 1202 |
-
"=================================================================\n",
|
| 1203 |
-
"Total params: 393,732\n",
|
| 1204 |
-
"Trainable params: 392,836\n",
|
| 1205 |
-
"Non-trainable params: 896\n",
|
| 1206 |
-
"_________________________________________________________________\n"
|
| 1207 |
-
],
|
| 1208 |
-
"name": "stdout"
|
| 1209 |
-
}
|
| 1210 |
-
]
|
| 1211 |
-
}
|
| 1212 |
-
]
|
| 1213 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/FullyConectedModels/gridsearchcv/grid_model_1.csv
DELETED
|
@@ -1,49 +0,0 @@
|
|
| 1 |
-
;momentum;learning rate;batch size;loss1;acc1;loss2;acc2;loss3;acc3;widing factor;prec1;prec2;prec3;recall1;recall2;recall3;epoch_stopped;reg_penalty
|
| 2 |
-
0;0.9;0.1;64.0;1.0786635875701904;0.5410122275352478;1.087765097618103;0.5602094531059265;1.1405400037765503;0.4956369996070862;1.0;0.5456249725117649;0.5703025971923495;0.5404468631355052;0.549295511252033;0.561227419923072;0.5060386473429952;50.0;0.0
|
| 3 |
-
1;0.9;0.1;64.0;1.0904589891433716;0.5445026159286499;1.0521031618118286;0.5689354538917542;0.9919149279594421;0.5863874554634094;1.0;0.5677450027120325;0.5818822696361521;0.5870538777082881;0.5496315278923974;0.5715594791681748;0.5847485847485847;100.0;0.0
|
| 4 |
-
2;0.9;0.1;64.0;1.203269600868225;0.5497382283210754;1.0244745016098022;0.6073298454284668;1.0995512008666992;0.5427573919296265;1.0;0.5973140968206758;0.6297070150268066;0.5833525841963643;0.5586641619250315;0.6120292750727534;0.5321260864739126;150.0;0.0
|
| 5 |
-
3;0.9;0.1;64.0;1.3318754434585571;0.45026177167892456;1.3573890924453735;0.4258289635181427;1.1364340782165527;0.5427573919296265;1.0;0.520103685420754;0.5564409752732054;0.5418956991063357;0.4679843103756147;0.44125801734497383;0.5497298595124682;50.0;0.01
|
| 6 |
-
4;0.9;0.1;64.0;1.2862250804901123;0.49040138721466064;1.2183420658111572;0.5392670035362244;1.2070523500442505;0.5375218391418457;1.0;0.5528382317603805;0.5694776622208874;0.586225988700565;0.4945869348043261;0.5416503786069004;0.5475315747054877;100.0;0.01
|
| 7 |
-
5;0.9;0.1;64.0;1.1847248077392578;0.5445026159286499;1.4426966905593872;0.43106457591056824;1.2306617498397827;0.47993019223213196;1.0;0.5312498290025648;0.5339512821095386;0.5311977318393986;0.544515870602827;0.4474462735332301;0.49295993861211257;150.0;0.01
|
| 8 |
-
6;0.9;0.1;64.0;1.224439263343811;0.5095986127853394;1.1083790063858032;0.5584642291069031;1.042079210281372;0.584642231464386;1.0;0.5344176263987584;0.5580401772808269;0.5725199160342243;0.5072463768115942;0.564412640499597;0.5839504698200351;50.0;0.001
|
| 9 |
-
7;0.9;0.1;64.0;1.1233032941818237;0.5671902298927307;1.279549241065979;0.49389180541038513;1.0785572528839111;0.5811518430709839;1.0;0.5835013620894915;0.5600592746287815;0.5944705565842157;0.5647160810204288;0.4860936165283991;0.5825973543364849;100.0;0.001
|
| 10 |
-
8;0.9;0.1;64.0;1.0618174076080322;0.5898778438568115;1.0041369199752808;0.6178010702133179;1.0159738063812256;0.6020942330360413;1.0;0.5958448974789733;0.6197077143812308;0.6046378435659604;0.5848010684967206;0.6172076715554976;0.6017768463420637;150.0;0.001
|
| 11 |
-
9;0.9;0.1;64.0;1.1275826692581177;0.5095986127853394;1.037006139755249;0.547993004322052;1.1576757431030273;0.4712041914463043;1.0;0.5163513023182834;0.5527921101786539;0.4780335054814982;0.5184676434676434;0.5468667805624328;0.47072553050813926;50.0;0.0001
|
| 12 |
-
10;0.9;0.1;64.0;1.0660938024520874;0.5811518430709839;1.0170561075210571;0.5828970074653625;1.0196107625961304;0.5811518430709839;1.0;0.5900164662084766;0.5999640724915478;0.5950983436203666;0.5901320901320901;0.5829641373119634;0.5774551535421101;100.0;0.0001
|
| 13 |
-
11;0.9;0.1;64.0;1.0177899599075317;0.5881326198577881;1.0439162254333496;0.5654450058937073;1.1021149158477783;0.5759162306785583;1.0;0.6036141210870313;0.5902555674010898;0.6140637358081388;0.5845627802149542;0.561161664422534;0.5720825068651156;150.0;0.0001
|
| 14 |
-
12;0.9;0.1;128.0;1.0818946361541748;0.5253053903579712;1.0099669694900513;0.5968586206436157;1.0918453931808472;0.5410122275352478;1.0;0.521080315961989;0.6039058290508166;0.5263435769673295;0.5298210243862418;0.6008387747518182;0.5431760268716791;50.0;0.0
|
| 15 |
-
13;0.9;0.1;128.0;1.164839744567871;0.45724257826805115;0.9867958426475525;0.5933682322502136;1.0679926872253418;0.554973840713501;1.0;0.5123214368077382;0.5988966761697891;0.5643833075235515;0.4695654586958935;0.597261434217956;0.5518648018648019;100.0;0.0
|
| 16 |
-
14;0.9;0.1;128.0;1.113759160041809;0.5567190051078796;1.046524167060852;0.584642231464386;0.9935498833656311;0.5724258422851562;1.0;0.6122945895111208;0.5897342892147153;0.5820350700747354;0.553781966825445;0.58606007519051;0.5708771904424078;150.0;0.0
|
| 17 |
-
15;0.9;0.1;128.0;1.1370242834091187;0.5619546175003052;1.1299653053283691;0.5462478399276733;1.1480666399002075;0.5322862267494202;1.0;0.5756164609818162;0.5500235261699395;0.5222666160817157;0.5605360822752127;0.5444175389827564;0.53630849826502;50.0;0.01
|
| 18 |
-
16;0.9;0.1;128.0;1.1349637508392334;0.5636998414993286;1.1408828496932983;0.5410122275352478;1.3376970291137695;0.47993019223213196;1.0;0.5815890943000785;0.5523155608073429;0.5205896541198826;0.5641309173917869;0.5482174830000918;0.4919084538649756;100.0;0.01
|
| 19 |
-
17;0.9;0.1;128.0;1.1322417259216309;0.5881326198577881;1.1545387506484985;0.5602094531059265;1.2242604494094849;0.5654450058937073;1.0;0.599198600670369;0.5865846668946237;0.5603340273492765;0.5861801242236024;0.5607828162175988;0.5646859179467875;150.0;0.01
|
| 20 |
-
18;0.9;0.1;128.0;1.1220964193344116;0.5514833927154541;1.1653169393539429;0.5270506143569946;1.1562250852584839;0.5235602259635925;1.0;0.5715813891898507;0.5249991236833342;0.5223967978570353;0.5467980087545306;0.5330587287109027;0.5233751755490886;50.0;0.001
|
| 21 |
-
19;0.9;0.1;128.0;1.153067708015442;0.5881326198577881;1.1445388793945312;0.5602094531059265;1.1189345121383667;0.5357766151428223;1.0;0.6072767352919373;0.5604236873448876;0.5574203282247268;0.5862139068660808;0.5572374485417964;0.5373183579705318;100.0;0.001
|
| 22 |
-
20;0.9;0.1;128.0;1.2125698328018188;0.5410122275352478;1.1219922304153442;0.5602094531059265;1.0713446140289307;0.5828970074653625;1.0;0.5605981222253124;0.5931479015961776;0.5828110117010443;0.5311391507043681;0.5714448594883378;0.5890498390498391;150.0;0.001
|
| 23 |
-
21;0.9;0.1;128.0;1.1447523832321167;0.5567190051078796;1.1324081420898438;0.518324613571167;1.1090422868728638;0.49738219380378723;1.0;0.596030303030303;0.5490280922716101;0.5121340146227983;0.5638238573021181;0.5146260744086831;0.4972123287340679;50.0;0.0001
|
| 24 |
-
22;0.9;0.1;128.0;0.9825822710990906;0.584642231464386;1.108230471611023;0.5218150019645691;1.1356145143508911;0.5200698375701904;1.0;0.5948660954866036;0.5665316277437249;0.5656796449319814;0.5828724415680937;0.5286633656198874;0.5303995521386826;100.0;0.0001
|
| 25 |
-
23;0.9;0.1;128.0;1.0473191738128662;0.5951134562492371;1.0678406953811646;0.518324613571167;1.0169429779052734;0.5881326198577881;1.0;0.6219409183078248;0.5693268693732438;0.593293789725479;0.5926543263499786;0.5205706129619173;0.5907293189901885;150.0;0.0001
|
| 26 |
-
24;0.9;0.01;64.0;1.0598105192184448;0.5968586206436157;1.0257790088653564;0.6038394570350647;0.9506248235702515;0.6369982361793518;1.0;0.6051437947470779;0.602723696484381;0.6396683984470244;0.5915582002538524;0.609172831998919;0.6423491966970227;50.0;0.0
|
| 27 |
-
25;0.9;0.01;64.0;1.1066389083862305;0.6108202338218689;1.0695348978042603;0.6195462346076965;1.1887136697769165;0.5741710066795349;1.0;0.6467146374525526;0.6245149204071099;0.5971208662939392;0.6051062464105943;0.6201624462494028;0.5657778212126039;100.0;0.0
|
| 28 |
-
26;0.9;0.01;64.0;1.076154351234436;0.6404886841773987;1.1792618036270142;0.6073298454284668;1.4671285152435303;0.547993004322052;1.0;0.6529888091081401;0.6461774006977742;0.6066030792539256;0.6368685662163923;0.6003965840922363;0.537592841940668;150.0;0.0
|
| 29 |
-
27;0.9;0.01;64.0;1.2690242528915405;0.5968586206436157;1.2563925981521606;0.5968586206436157;1.2163962125778198;0.5968586206436157;1.0;0.5916521579476596;0.6248306369232546;0.5985189460592925;0.5966666184057489;0.6007983562331388;0.5971992982862548;50.0;0.01
|
| 30 |
-
28;0.9;0.01;64.0;1.116417646408081;0.6160558462142944;1.1469776630401611;0.6178010702133179;1.1717371940612793;0.584642231464386;1.0;0.6261048220941269;0.6335991775298133;0.610479317595543;0.6109440076831381;0.6120992534036012;0.5811808963982876;100.0;0.01
|
| 31 |
-
29;0.9;0.01;64.0;1.214686393737793;0.5811518430709839;1.2714550495147705;0.5584642291069031;1.2080868482589722;0.6073298454284668;1.0;0.6001486481081634;0.5946394895073237;0.6365187860446218;0.5804340586949283;0.5500978490108924;0.6034834730486904;150.0;0.01
|
| 32 |
-
30;0.9;0.01;64.0;1.1514617204666138;0.5881326198577881;1.0881938934326172;0.5881326198577881;1.053907871246338;0.6020942330360413;1.0;0.5834891320869433;0.6065228770652975;0.6088152818711143;0.5858362651840913;0.5805794447098794;0.6013455143889928;50.0;0.001
|
| 33 |
-
31;0.9;0.01;64.0;1.22362220287323;0.5968586206436157;1.0609272718429565;0.6300174593925476;1.1912355422973633;0.5968586206436157;1.0;0.621725687037072;0.6327141334647053;0.6186631707154095;0.5975027388070866;0.6298701298701299;0.5883030013464796;100.0;0.001
|
| 34 |
-
32;0.9;0.01;64.0;1.1824991703033447;0.5986038446426392;1.226974368095398;0.6090750694274902;1.2909026145935059;0.6038394570350647;1.0;0.6073992988417198;0.6445184510325356;0.619910049538771;0.5955710955710956;0.6017605582822975;0.5952151713021279;150.0;0.001
|
| 35 |
-
33;0.9;0.01;64.0;1.1622414588928223;0.5287958383560181;1.0493906736373901;0.5794066190719604;1.2138421535491943;0.5497382283210754;1.0;0.5809122116993566;0.6019885476001634;0.5910039233578294;0.5309117211291124;0.5840572471007254;0.5480895915678524;50.0;0.0001
|
| 36 |
-
34;0.9;0.01;64.0;1.0573968887329102;0.6073298454284668;1.0098820924758911;0.62129145860672;1.0127638578414917;0.614310622215271;1.0;0.6070338521099506;0.62878369303158;0.6239062815043207;0.603968495272843;0.6201449516666908;0.6131923631923631;100.0;0.0001
|
| 37 |
-
35;0.9;0.01;64.0;1.0984822511672974;0.657940685749054;1.0174990892410278;0.6561954617500305;1.1236447095870972;0.5986038446426392;1.0;0.6531590453460197;0.6671137560770346;0.5932537237439376;0.6610702099832534;0.6503381883816667;0.5962919930311235;150.0;0.0001
|
| 38 |
-
36;0.9;0.01;128.0;1.0621216297149658;0.584642231464386;1.0866228342056274;0.5497382283210754;1.073479413986206;0.5671902298927307;1.0;0.5977450132718186;0.5629235472659129;0.5687495323396486;0.5795900958944438;0.545400251921991;0.565356141443098;50.0;0.0
|
| 39 |
-
37;0.9;0.01;128.0;1.0547319650650024;0.6125654578208923;1.1377087831497192;0.5776614546775818;1.127797245979309;0.5636998414993286;1.0;0.6235207292967124;0.5720610223175946;0.5682047407187032;0.6082697495740974;0.5770853542592673;0.5593319723754506;100.0;0.0
|
| 40 |
-
38;0.9;0.01;128.0;0.9883608222007751;0.6265270709991455;1.316406488418579;0.5567190051078796;1.1691138744354248;0.5828970074653625;1.0;0.6431009815575375;0.5975566000144896;0.6062137814624874;0.624762918241179;0.5490433479563914;0.577417148069322;150.0;0.0
|
| 41 |
-
39;0.9;0.01;128.0;1.443315863609314;0.5253053903579712;1.3692501783370972;0.5828970074653625;1.342460036277771;0.5933682322502136;1.0;0.5716634678107281;0.597909057945944;0.5906145139282292;0.5164768806073153;0.57604231517275;0.5995459854155506;50.0;0.01
|
| 42 |
-
40;0.9;0.01;128.0;1.2021657228469849;0.6265270709991455;1.3414463996887207;0.5619546175003052;1.2604358196258545;0.5828970074653625;1.0;0.6285688058844148;0.6186485688650869;0.6176072056565729;0.6251604675517719;0.5610753980319197;0.5805716023107328;100.0;0.01
|
| 43 |
-
41;0.9;0.01;128.0;1.1669517755508423;0.6003490686416626;1.195756435394287;0.6020942330360413;1.1677254438400269;0.614310622215271;1.0;0.6203690104168971;0.6203570670428672;0.6432026808004327;0.5943404421665291;0.596520025867852;0.6156590993547515;150.0;0.01
|
| 44 |
-
42;0.9;0.01;128.0;1.1474894285202026;0.5497382283210754;1.1794898509979248;0.5357766151428223;1.0964933633804321;0.5759162306785583;1.0;0.5853768536604357;0.5617280741536463;0.5739665220002184;0.5477795151708195;0.5463829648612257;0.5779051866008388;50.0;0.001
|
| 45 |
-
43;0.9;0.01;128.0;1.253674864768982;0.5253053903579712;1.160112977027893;0.5776614546775818;1.1979146003723145;0.5602094531059265;1.0;0.5522893772893773;0.5946065751075158;0.609541492289036;0.5178848928848929;0.5693081073515855;0.5646515320428364;100.0;0.001
|
| 46 |
-
44;0.9;0.01;128.0;1.0433732271194458;0.62129145860672;1.1583247184753418;0.6055846214294434;1.2541570663452148;0.5584642291069031;1.0;0.6268407792270415;0.6419238559031413;0.5858269376442662;0.6190530484008745;0.5973609723609723;0.5483924288272114;150.0;0.001
|
| 47 |
-
45;0.9;0.01;128.0;1.1987037658691406;0.5375218391418457;1.1068445444107056;0.5445026159286499;1.1949487924575806;0.518324613571167;1.0;0.5710625395113904;0.5581101609004185;0.5326604175732083;0.5442896475505171;0.5378775813558422;0.5108786141394838;50.0;0.0001
|
| 48 |
-
46;0.9;0.01;128.0;1.1030073165893555;0.5863874554634094;1.0547511577606201;0.6195462346076965;1.066347599029541;0.5881326198577881;1.0;0.6079505660141936;0.6188059698281521;0.5895490996780354;0.5875187614318049;0.6249137336093857;0.5872448807231415;100.0;0.0001
|
| 49 |
-
47;0.9;0.01;128.0;1.132765293121338;0.5898778438568115;1.2393956184387207;0.5340313911437988;1.0497409105300903;0.5881326198577881;1.0;0.5911296865320181;0.566333885194723;0.6075930260346379;0.5854749115618681;0.5218344457474892;0.5897791821704865;150.0;0.0001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/FullyConectedModels/model.py
DELETED
|
@@ -1,120 +0,0 @@
|
|
| 1 |
-
from tensorflow.data import Dataset
|
| 2 |
-
import tensorflow.keras as keras
|
| 3 |
-
from tensorflow.keras.optimizers import Adam
|
| 4 |
-
from tensorflow.keras.layers import (
|
| 5 |
-
Conv2D,
|
| 6 |
-
Input,
|
| 7 |
-
MaxPooling2D,
|
| 8 |
-
Dense,
|
| 9 |
-
Dropout,
|
| 10 |
-
MaxPool1D,
|
| 11 |
-
Flatten,
|
| 12 |
-
AveragePooling1D,
|
| 13 |
-
BatchNormalization,
|
| 14 |
-
)
|
| 15 |
-
from tensorflow.keras import Model
|
| 16 |
-
import numpy as np
|
| 17 |
-
import tensorflow as tf
|
| 18 |
-
from tensorflow.keras.models import Sequential
|
| 19 |
-
from tensorflow.keras.models import Model
|
| 20 |
-
from tensorflow.keras.layers import Input, Add, Activation, Dropout, Flatten, Dense
|
| 21 |
-
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, AveragePooling2D
|
| 22 |
-
from tensorflow.keras.layers import BatchNormalization
|
| 23 |
-
from tensorflow.keras.regularizers import l2
|
| 24 |
-
from tensorflow.keras import backend as K
|
| 25 |
-
from tensorflow.keras.optimizers import SGD
|
| 26 |
-
import warnings
|
| 27 |
-
|
| 28 |
-
warnings.filterwarnings("ignore")
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def basemodel(weight_decay, input_shape=(32, 32, 1), nb_classes=4):
    """Build the baseline CNN classifier (two convolutional layers).

    Layer order: Conv2D(32) -> Conv2D(64) -> MaxPooling2D -> BatchNormalization
    -> Flatten -> Dense(nb_classes, softmax). Every Conv2D and Dense kernel
    carries an L2 penalty of strength ``weight_decay``.

    Args:
        weight_decay (float): L2 regularization factor handed to ``l2``.
        input_shape (tuple, optional): Shape of a single input sample, without
            the batch axis. Defaults to ``(32, 32, 1)``, the value that used to
            be hard-coded.
        nb_classes (int, optional): Number of softmax outputs. Defaults to 4,
            the value that used to be hard-coded.

    Returns:
        Model: The uncompiled Keras functional model.
    """
    model_input = Input(shape=input_shape)

    # 2 hidden (convolutional) layers
    x = Conv2D(
        32,
        kernel_size=(3, 3),
        kernel_regularizer=l2(weight_decay),
        activation="relu",
    )(model_input)
    for filters in (64,):
        x = Conv2D(
            filters,
            kernel_size=(3, 3),
            kernel_regularizer=l2(weight_decay),
            activation="relu",
        )(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = BatchNormalization()(x)

    x = Flatten()(x)
    outputs = Dense(
        nb_classes, kernel_regularizer=l2(weight_decay), activation="softmax"
    )(x)
    return Model(inputs=model_input, outputs=outputs)
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def model_2(weight_decay, input_shape=(32, 32, 1), nb_classes=4):
    """Build the three-convolution CNN classifier.

    Layer order: Conv2D(32) -> [Conv2D(64) -> MaxPooling2D -> BatchNormalization]
    -> [Conv2D(128) -> MaxPooling2D -> BatchNormalization] -> Flatten
    -> Dense(nb_classes, softmax). Every Conv2D and Dense kernel carries an L2
    penalty of strength ``weight_decay``.

    Args:
        weight_decay (float): L2 regularization factor handed to ``l2``.
        input_shape (tuple, optional): Shape of a single input sample, without
            the batch axis. Defaults to ``(32, 32, 1)``, the value that used to
            be hard-coded.
        nb_classes (int, optional): Number of softmax outputs. Defaults to 4,
            the value that used to be hard-coded.

    Returns:
        Model: The uncompiled Keras functional model.
    """
    model_input = Input(shape=input_shape)

    x = Conv2D(
        32,
        kernel_size=(3, 3),
        kernel_regularizer=l2(weight_decay),
        activation="relu",
    )(model_input)
    # Each later stage is conv -> 2x2 max-pool -> batch-norm; layers are
    # created in the same order as the previous hand-unrolled version.
    for filters in (64, 128):
        x = Conv2D(
            filters,
            kernel_size=(3, 3),
            kernel_regularizer=l2(weight_decay),
            activation="relu",
        )(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = BatchNormalization()(x)

    x = Flatten()(x)
    outputs = Dense(
        nb_classes, kernel_regularizer=l2(weight_decay), activation="softmax"
    )(x)
    return Model(inputs=model_input, outputs=outputs)
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
def model_3(weight_decay, input_shape=(32, 32, 1), nb_classes=4):
    """Build the four-convolution CNN classifier.

    Layer order: Conv2D(32) followed by three conv -> MaxPooling2D ->
    BatchNormalization stages with 64, 128 and 256 filters, then Flatten and
    Dense(nb_classes, softmax). Every Conv2D and Dense kernel carries an L2
    penalty of strength ``weight_decay``.

    Args:
        weight_decay (float): L2 regularization factor handed to ``l2``.
        input_shape (tuple, optional): Shape of a single input sample, without
            the batch axis. Defaults to ``(32, 32, 1)``, the value that used to
            be hard-coded.
        nb_classes (int, optional): Number of softmax outputs. Defaults to 4,
            the value that used to be hard-coded.

    Returns:
        Model: The uncompiled Keras functional model.
    """
    model_input = Input(shape=input_shape)

    # 4 hidden (convolutional) layers
    x = Conv2D(
        32,
        kernel_size=(3, 3),
        kernel_regularizer=l2(weight_decay),
        activation="relu",
    )(model_input)
    # Each later stage is conv -> 2x2 max-pool -> batch-norm; layers are
    # created in the same order as the previous hand-unrolled version.
    for filters in (64, 128, 256):
        x = Conv2D(
            filters,
            kernel_size=(3, 3),
            kernel_regularizer=l2(weight_decay),
            activation="relu",
        )(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = BatchNormalization()(x)

    x = Flatten()(x)
    outputs = Dense(
        nb_classes, kernel_regularizer=l2(weight_decay), activation="softmax"
    )(x)
    return Model(inputs=model_input, outputs=outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/FullyConectedModels/parseval.py
DELETED
|
@@ -1,83 +0,0 @@
|
|
| 1 |
-
from tensorflow.data import Dataset
|
| 2 |
-
import tensorflow.keras as keras
|
| 3 |
-
from tensorflow.keras.optimizers import Adam
|
| 4 |
-
from tensorflow.keras.layers import (
|
| 5 |
-
Conv2D,
|
| 6 |
-
Input,
|
| 7 |
-
MaxPooling2D,
|
| 8 |
-
Dense,
|
| 9 |
-
Dropout,
|
| 10 |
-
MaxPool1D,
|
| 11 |
-
Flatten,
|
| 12 |
-
AveragePooling1D,
|
| 13 |
-
BatchNormalization,
|
| 14 |
-
)
|
| 15 |
-
from tensorflow.keras import Model
|
| 16 |
-
import numpy as np
|
| 17 |
-
import tensorflow as tf
|
| 18 |
-
from tensorflow.keras.models import Sequential
|
| 19 |
-
from tensorflow.keras.models import Model
|
| 20 |
-
from tensorflow.keras.layers import Input, Add, Activation, Dropout, Flatten, Dense
|
| 21 |
-
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, AveragePooling2D
|
| 22 |
-
from tensorflow.keras.layers import BatchNormalization
|
| 23 |
-
from tensorflow.keras.regularizers import l2
|
| 24 |
-
from tensorflow.keras import backend as K
|
| 25 |
-
from tensorflow.keras.optimizers import SGD
|
| 26 |
-
import warnings
|
| 27 |
-
from constraint import tight_frame
|
| 28 |
-
|
| 29 |
-
warnings.filterwarnings("ignore")
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def model_parseval(weight_decay, input_shape=(32, 32, 1), nb_classes=4, scale=0.001):
    """Build the four-convolution CNN with Parseval (tight-frame) constraints.

    Same layer layout as the unconstrained four-convolution model: Conv2D(32)
    followed by three conv -> MaxPooling2D -> BatchNormalization stages with
    64, 128 and 256 filters, then Flatten and Dense(nb_classes, softmax).
    Every convolution uses an orthogonal initializer, an L2 penalty and a
    ``tight_frame`` kernel constraint. The final Dense layer is L2-regularized
    only; it has no tight-frame constraint.

    Args:
        weight_decay (float): L2 regularization factor handed to ``l2``.
        input_shape (tuple, optional): Shape of a single input sample, without
            the batch axis. Defaults to ``(32, 32, 1)``, the value that used to
            be hard-coded.
        nb_classes (int, optional): Number of softmax outputs. Defaults to 4,
            the value that used to be hard-coded.
        scale (float, optional): Retraction step passed to ``tight_frame`` for
            every convolution. Defaults to 0.001, the value that used to be
            hard-coded.

    Returns:
        Model: The uncompiled Keras functional model.
    """

    def _parseval_conv(filters):
        # One constrained 3x3 convolution; a fresh regularizer and constraint
        # object is created per layer, as before.
        return Conv2D(
            filters,
            kernel_size=(3, 3),
            activation="relu",
            kernel_regularizer=l2(weight_decay),
            kernel_constraint=tight_frame(scale),
            kernel_initializer="Orthogonal",
        )

    model_input = Input(shape=input_shape)

    # The first convolution no longer repeats ``input_shape=``: the Input
    # tensor it is applied to already fixes the shape.
    x = _parseval_conv(32)(model_input)
    for filters in (64, 128, 256):
        x = _parseval_conv(filters)(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = BatchNormalization()(x)

    x = Flatten()(x)
    outputs = Dense(
        nb_classes, activation="softmax", kernel_regularizer=l2(weight_decay)
    )(x)
    return Model(inputs=model_input, outputs=outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/Parseval_Networks/README.md
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
## ParsevalNetworks
|
| 2 |
-
* Orthogonality Constraint
|
| 3 |
-
* Convexity Constraint
|
|
|
|
|
|
|
|
|
|
|
|
models/Parseval_Networks/constraint.py
DELETED
|
@@ -1,81 +0,0 @@
|
|
| 1 |
-
from tensorflow.python.keras.constraints import Constraint
|
| 2 |
-
from tensorflow.python.ops import math_ops, array_ops
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
class TightFrame(Constraint):
    """
    Parseval (tight) frame constraint, as introduced in https://arxiv.org/abs/1704.08847

    Constrains the weight matrix to be a tight frame, so that the Lipschitz
    constant of the layer is <= 1. This increases the robustness of the network
    to adversarial noise.

    Warning: This constraint simply performs the update step on the weight matrix
    (or the unfolded weight matrix for convolutional layers). Thus, it does not
    handle the necessary scalings for convolutional layers.

    Args:
        scale (float): Retraction parameter (length of retraction step).
        num_passes (int): Number of retraction steps.

    Returns:
        Weight matrix after applying the constraint.
    """

    def __init__(self, scale, num_passes=1):
        """Store the retraction settings.

        Args:
            scale (float): Retraction parameter (length of one retraction step).
            num_passes (int, optional): Number of retraction steps applied per
                call. Defaults to 1.

        Raises:
            ValueError: If ``num_passes`` is smaller than 1.
        """
        self.scale = scale

        if num_passes < 1:
            raise ValueError(
                "Number of passes cannot be non-positive! (got {})".format(num_passes)
            )
        self.num_passes = num_passes

    def __call__(self, w):
        """Apply ``num_passes`` retraction steps to a weight tensor.

        Each step computes ``W <- (1 + scale) * W - scale * W @ (W^T @ W)``.

        Args:
            w: Weight tensor of a conv or linear layer. A 4-D tensor is
                flattened to 2-D by collapsing its first three axes and keeping
                the last axis as columns; any other rank is used as is.

        Returns:
            The updated weights, with the same shape as ``w``.
        """
        is_4d_kernel = len(w.shape) == 4

        # Unfold a 4-D kernel into a matrix so that matmul applies.
        if is_4d_kernel:
            w_matrix = array_ops.reshape(w, (-1, w.shape[3]))
        else:
            w_matrix = w

        last = w_matrix
        for _ in range(self.num_passes):
            gram = math_ops.matmul(last, last, transpose_a=True)
            # Every pass retracts the result of the previous pass. The earlier
            # code rebuilt the update from the original weights each time, so
            # extra passes did not iterate the retraction.
            last = (1 + self.scale) * last - self.scale * math_ops.matmul(last, gram)

        # Fold the matrix back into the original 4-D kernel shape.
        if is_4d_kernel:
            return array_ops.reshape(last, w.shape)
        return last

    def get_config(self):
        """Return the constructor arguments as a dict, for serialization."""
        return {"scale": self.scale, "num_passes": self.num_passes}
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
# Alias
|
| 81 |
-
tight_frame = TightFrame
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/Parseval_Networks/convexity_constraint.py
DELETED
|
@@ -1,53 +0,0 @@
|
|
| 1 |
-
from tensorflow.python.ops import math_ops
|
| 2 |
-
from tensorflow.python.ops import variables
|
| 3 |
-
from tensorflow.python.framework import dtypes
|
| 4 |
-
import numpy as _np
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
def convex_add(input_layer, layer_3, initial_convex_par=0.5, trainable=False):
    """
    Return a convex combination of two tensors:

        lambda * input_layer + (1 - lambda) * layer_3

    The weight is stored as an unconstrained scalar variable ``p`` with
    ``lambda = sigmoid(p)``, which keeps lambda inside the unit interval
    without needing a constraint during optimization.

    Args:
        input_layer (tf.Tensor): First input to the convex combination.
        layer_3 (tf.Tensor): Second input to the convex combination.
        initial_convex_par (float): Initial value of lambda. Must be
            in [0, 1].
        trainable (bool): Whether the underlying variable should be
            trainable or not.

    Returns:
        tf.Tensor: Result of the convex combination.

    Raises:
        ValueError: If ``initial_convex_par`` is not in [0, 1].
    """
    # Reject anything outside the closed unit interval up front.
    if initial_convex_par < 0:
        raise ValueError("Convex parameter must be >=0")
    if not (initial_convex_par <= 1):
        raise ValueError("Convex parameter must be <=1")

    # Map lambda to the pre-sigmoid value. The end points have no finite
    # inverse, so they are clamped: sigmoid(+-16) is approximately a 32bit
    # roundoff error away from 1 / 0.
    if initial_convex_par == 0:
        start_logit = -16
    elif initial_convex_par == 1:
        start_logit = 16
    else:
        # Inverse of the sigmoid.
        start_logit = -_np.log(1 / initial_convex_par - 1)

    logit_var = variables.Variable(
        initial_value=start_logit, dtype=dtypes.float32, trainable=trainable
    )
    lam = math_ops.sigmoid(logit_var)

    return input_layer * lam + (1 - lam) * layer_3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/Parseval_Networks/parsevalnet.py
DELETED
|
@@ -1,328 +0,0 @@
|
|
| 1 |
-
from tensorflow.keras.models import Model
|
| 2 |
-
from tensorflow.keras.layers import Input, Add, Activation, Dropout, Flatten, Dense
|
| 3 |
-
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, AveragePooling2D
|
| 4 |
-
from tensorflow.keras.layers import BatchNormalization
|
| 5 |
-
from tensorflow.keras.regularizers import l2
|
| 6 |
-
from tensorflow.keras import backend as K
|
| 7 |
-
from tensorflow.keras.optimizers import SGD
|
| 8 |
-
import warnings
|
| 9 |
-
from constraint import tight_frame
|
| 10 |
-
from convexity_constraint import convex_add
|
| 11 |
-
|
| 12 |
-
warnings.filterwarnings("ignore")
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
class ParsevalNetwork(Model):
|
| 16 |
-
def __init__(
|
| 17 |
-
self,
|
| 18 |
-
input_dim,
|
| 19 |
-
weight_decay,
|
| 20 |
-
momentum,
|
| 21 |
-
nb_classes=4,
|
| 22 |
-
N=2,
|
| 23 |
-
k=1,
|
| 24 |
-
dropout=0.0,
|
| 25 |
-
verbose=1,
|
| 26 |
-
):
|
| 27 |
-
"""[Assign the initial parameters of the wide residual network]
|
| 28 |
-
|
| 29 |
-
Args:
|
| 30 |
-
weight_decay ([float]): [description]
|
| 31 |
-
input_dim ([tuple]): [input dimension]
|
| 32 |
-
nb_classes (int, optional): [output class]. Defaults to 4.
|
| 33 |
-
N (int, optional): [the number of blocks]. Defaults to 2.
|
| 34 |
-
k (int, optional): [network width]. Defaults to 1.
|
| 35 |
-
dropout (float, optional): [dropout value to prevent overfitting]. Defaults to 0.0.
|
| 36 |
-
verbose (int, optional): [description]. Defaults to 1.
|
| 37 |
-
|
| 38 |
-
Returns:
|
| 39 |
-
[Model]: [parsevalnetwork]
|
| 40 |
-
"""
|
| 41 |
-
self.weight_decay = weight_decay
|
| 42 |
-
self.input_dim = input_dim
|
| 43 |
-
self.nb_classes = nb_classes
|
| 44 |
-
self.N = N
|
| 45 |
-
self.k = k
|
| 46 |
-
self.dropout = dropout
|
| 47 |
-
self.verbose = verbose
|
| 48 |
-
|
| 49 |
-
def initial_conv(self, input):
|
| 50 |
-
"""[summary]
|
| 51 |
-
|
| 52 |
-
Args:
|
| 53 |
-
input ([type]): [description]
|
| 54 |
-
|
| 55 |
-
Returns:
|
| 56 |
-
[type]: [description]
|
| 57 |
-
"""
|
| 58 |
-
x = Convolution2D(
|
| 59 |
-
16,
|
| 60 |
-
(3, 3),
|
| 61 |
-
padding="same",
|
| 62 |
-
kernel_initializer="orthogonal",
|
| 63 |
-
kernel_regularizer=l2(self.weight_decay),
|
| 64 |
-
kernel_constraint=tight_frame(0.001),
|
| 65 |
-
use_bias=False,
|
| 66 |
-
)(input)
|
| 67 |
-
|
| 68 |
-
channel_axis = 1 if K.image_data_format() == "channels_first" else -1
|
| 69 |
-
|
| 70 |
-
x = BatchNormalization(
|
| 71 |
-
axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_initializer="uniform"
|
| 72 |
-
)(x)
|
| 73 |
-
x = Activation("relu")(x)
|
| 74 |
-
return x
|
| 75 |
-
|
| 76 |
-
def expand_conv(self, init, base, k, strides=(1, 1)):
|
| 77 |
-
"""[summary]
|
| 78 |
-
|
| 79 |
-
Args:
|
| 80 |
-
init ([type]): [description]
|
| 81 |
-
base ([type]): [description]
|
| 82 |
-
k ([type]): [description]
|
| 83 |
-
strides (tuple, optional): [description]. Defaults to (1, 1).
|
| 84 |
-
|
| 85 |
-
Returns:
|
| 86 |
-
[type]: [description]
|
| 87 |
-
"""
|
| 88 |
-
x = Convolution2D(
|
| 89 |
-
base * k,
|
| 90 |
-
(3, 3),
|
| 91 |
-
padding="same",
|
| 92 |
-
strides=strides,
|
| 93 |
-
kernel_initializer="Orthogonal",
|
| 94 |
-
kernel_regularizer=l2(self.weight_decay),
|
| 95 |
-
kernel_constraint=tight_frame(0.001),
|
| 96 |
-
use_bias=False,
|
| 97 |
-
)(init)
|
| 98 |
-
|
| 99 |
-
channel_axis = 1 if K.image_data_format() == "channels_first" else -1
|
| 100 |
-
|
| 101 |
-
x = BatchNormalization(
|
| 102 |
-
axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_initializer="uniform"
|
| 103 |
-
)(x)
|
| 104 |
-
x = Activation("relu")(x)
|
| 105 |
-
|
| 106 |
-
x = Convolution2D(
|
| 107 |
-
base * k,
|
| 108 |
-
(3, 3),
|
| 109 |
-
padding="same",
|
| 110 |
-
kernel_initializer="Orthogonal",
|
| 111 |
-
kernel_regularizer=l2(self.weight_decay),
|
| 112 |
-
kernel_constraint=tight_frame(0.001),
|
| 113 |
-
use_bias=False,
|
| 114 |
-
)(x)
|
| 115 |
-
|
| 116 |
-
skip = Convolution2D(
|
| 117 |
-
base * k,
|
| 118 |
-
(1, 1),
|
| 119 |
-
padding="same",
|
| 120 |
-
strides=strides,
|
| 121 |
-
kernel_initializer="Orthogonal",
|
| 122 |
-
kernel_regularizer=l2(self.weight_decay),
|
| 123 |
-
kernel_constraint=tight_frame(0.001),
|
| 124 |
-
use_bias=False,
|
| 125 |
-
)(init)
|
| 126 |
-
|
| 127 |
-
m = Add()([x, skip])
|
| 128 |
-
|
| 129 |
-
return m
|
| 130 |
-
|
| 131 |
-
def conv1_block(self, input, k=1, dropout=0.0):
    """Pre-activation residual unit with ``16 * k`` filters.

    Applies batch norm -> ReLU -> 3x3 conv twice, with an optional Dropout
    layer between the two rounds, and merges the result with the unit's
    input through ``convex_add`` (initial parameter 0.5, trainable).

    Args:
        input: Input tensor of the unit.
        k (int, optional): Multiplier applied to the 16 base filters.
            Defaults to 1.
        dropout (float, optional): Dropout rate; the Dropout layer is only
            inserted when this is greater than 0. Defaults to 0.0.

    Returns:
        The tensor returned by ``convex_add``.
    """
    shortcut = input
    n_filters = 16 * k
    bn_axis = 1 if K.image_data_format() == "channels_first" else -1

    def bn_relu_conv(tensor):
        # One pre-activation round: normalise, rectify, then convolve.
        tensor = BatchNormalization(
            axis=bn_axis, momentum=0.1, epsilon=1e-5, gamma_initializer="uniform"
        )(tensor)
        tensor = Activation("relu")(tensor)
        return Convolution2D(
            n_filters,
            (3, 3),
            padding="same",
            kernel_initializer="Orthogonal",
            kernel_regularizer=l2(self.weight_decay),
            kernel_constraint=tight_frame(0.001),
            use_bias=False,
        )(tensor)

    x = bn_relu_conv(input)
    if dropout > 0.0:
        x = Dropout(dropout)(x)
    x = bn_relu_conv(x)

    return convex_add(shortcut, x, initial_convex_par=0.5, trainable=True)
|
| 178 |
-
|
| 179 |
-
def conv2_block(self, input, k=1, dropout=0.0):
    """Pre-activation residual unit with ``32 * k`` filters.

    Applies batch norm -> ReLU -> 3x3 conv twice, with an optional Dropout
    layer between the two rounds, and merges the result with the unit's
    input through ``convex_add`` (initial parameter 0.5, trainable).

    Args:
        input: Input tensor of the unit.
        k (int, optional): Multiplier applied to the 32 base filters.
            Defaults to 1.
        dropout (float, optional): Dropout rate; the Dropout layer is only
            inserted when this is greater than 0. Defaults to 0.0.

    Returns:
        The tensor returned by ``convex_add``.
    """
    shortcut = input
    n_filters = 32 * k
    bn_axis = 1 if K.image_data_format() == "channels_first" else -1

    def bn_relu_conv(tensor):
        # One pre-activation round: normalise, rectify, then convolve.
        tensor = BatchNormalization(
            axis=bn_axis, momentum=0.1, epsilon=1e-5, gamma_initializer="uniform"
        )(tensor)
        tensor = Activation("relu")(tensor)
        return Convolution2D(
            n_filters,
            (3, 3),
            padding="same",
            kernel_initializer="Orthogonal",
            kernel_regularizer=l2(self.weight_decay),
            kernel_constraint=tight_frame(0.001),
            use_bias=False,
        )(tensor)

    x = bn_relu_conv(input)
    if dropout > 0.0:
        x = Dropout(dropout)(x)
    x = bn_relu_conv(x)

    return convex_add(shortcut, x, initial_convex_par=0.5, trainable=True)
|
| 226 |
-
|
| 227 |
-
def conv3_block(self, input, k=1, dropout=0.0):
    """Pre-activation residual unit with ``64 * k`` filters.

    Applies batch norm -> ReLU -> 3x3 conv twice, with an optional Dropout
    layer between the two rounds, and merges the result with the unit's
    input through ``convex_add`` (initial parameter 0.5, trainable).

    Args:
        input: Input tensor of the unit.
        k (int, optional): Multiplier applied to the 64 base filters.
            Defaults to 1.
        dropout (float, optional): Dropout rate; the Dropout layer is only
            inserted when this is greater than 0. Defaults to 0.0.

    Returns:
        The tensor returned by ``convex_add``.
    """
    shortcut = input
    n_filters = 64 * k
    bn_axis = 1 if K.image_data_format() == "channels_first" else -1

    def bn_relu_conv(tensor):
        # One pre-activation round: normalise, rectify, then convolve.
        tensor = BatchNormalization(
            axis=bn_axis, momentum=0.1, epsilon=1e-5, gamma_initializer="uniform"
        )(tensor)
        tensor = Activation("relu")(tensor)
        return Convolution2D(
            n_filters,
            (3, 3),
            padding="same",
            kernel_initializer="Orthogonal",
            kernel_constraint=tight_frame(0.001),
            kernel_regularizer=l2(self.weight_decay),
            use_bias=False,
        )(tensor)

    x = bn_relu_conv(input)
    if dropout > 0.0:
        x = Dropout(dropout)(x)
    x = bn_relu_conv(x)

    return convex_add(shortcut, x, initial_convex_par=0.5, trainable=True)
|
| 264 |
-
|
| 265 |
-
def create_wide_residual_network(self):
    """Assemble the Parseval wide residual network as a Keras ``Model``.

    Layout: initial convolution, then three stages with 16, 32 and 64 base
    filters (the last two use stride (2, 2) in their expanding unit). Each
    stage is one ``expand_conv`` unit followed by ``self.N - 1`` residual
    blocks and a closing batch norm + ReLU. The head is 8x8 average
    pooling, flatten and a softmax ``Dense`` layer with ``self.nb_classes``
    units.

    Returns:
        Model: The assembled (uncompiled) network.
    """
    bn_axis = 1 if K.image_data_format() == "channels_first" else -1

    def bn_relu(tensor):
        tensor = BatchNormalization(
            axis=bn_axis, momentum=0.1, epsilon=1e-5, gamma_initializer="uniform"
        )(tensor)
        return Activation("relu")(tensor)

    ip = Input(shape=self.input_dim)

    x = self.initial_conv(ip)
    nb_conv = 4  # running count that is only used in the verbose message

    # (base filters, strides of the expanding unit, residual block builder)
    stages = (
        (16, (1, 1), self.conv1_block),
        (32, (2, 2), self.conv2_block),
        (64, (2, 2), self.conv3_block),
    )
    for base, strides, residual_block in stages:
        x = self.expand_conv(x, base, self.k, strides=strides)
        nb_conv += 2

        for _ in range(self.N - 1):
            x = residual_block(x, self.k, self.dropout)
            nb_conv += 2

        x = bn_relu(x)

    # NOTE(review): the fixed 8x8 pool presumably targets 32x32 inputs
    # (two stride-2 stages) - confirm before using other input sizes.
    x = AveragePooling2D((8, 8))(x)
    x = Flatten()(x)
    x = Dense(
        self.nb_classes,
        kernel_regularizer=l2(self.weight_decay),
        activation="softmax",
    )(x)

    model = Model(ip, x)

    if self.verbose:
        print("Parseval Network-%d-%d created." % (nb_conv, self.k))
    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/README.md
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
## Models
|
| 2 |
-
|
| 3 |
-
````
|
| 4 |
-
├── Parseval_network
|
| 5 |
-
│ ├── __init__.py
|
| 6 |
-
│ └── Parseval_resnet.py
|
| 7 |
-
├── Parseval_Networks_OC
|
| 8 |
-
│ ├── constraint.py
|
| 9 |
-
│ ├── parsnet_oc.py
|
| 10 |
-
│ └── README.md
|
| 11 |
-
├── README.md
|
| 12 |
-
├── _utility.py
|
| 13 |
-
└── wideresnet
|
| 14 |
-
└── wresnet.py
|
| 15 |
-
````
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/_utility.py
DELETED
|
@@ -1,109 +0,0 @@
|
|
| 1 |
-
from tensorflow.keras.callbacks import LearningRateScheduler
|
| 2 |
-
|
| 3 |
-
# Define configuration parameters
|
| 4 |
-
import math
|
| 5 |
-
import cleverhans
|
| 6 |
-
from cleverhans.tf2.attacks.fast_gradient_method import fast_gradient_method
|
| 7 |
-
import tensorflow as tf
|
| 8 |
-
|
| 9 |
-
import numpy as np
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def step_decay(epoch):
    """Piecewise-constant learning-rate schedule starting at 0.1.

    The rate is multiplied by 0.1 at epochs 10, 20, 30 and 40 and stays
    constant afterwards.

    Args:
        epoch (int): epoch number

    Returns:
        lrate(float): new learning rate
    """
    initial_lrate = 0.1
    factor = 0.1
    # Exclusive upper epoch bounds of decay stages 0..3; everything at or
    # beyond the last bound belongs to stage 4.
    for stage, upper_bound in enumerate((10, 20, 30, 40)):
        if epoch < upper_bound:
            break
    else:
        stage = 4
    return initial_lrate * math.pow(factor, stage)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def step_decay_conv(epoch):
    """Step decay of the learning rate for convolutional networks.

    Starts at 0.01 and multiplies the rate by 0.1 at epochs 10, 20, 30 and
    40; the rate stays constant afterwards.

    Args:
        epoch (int): epoch number

    Returns:
        lrate(float): new learning rate
    """
    initial_lrate = 0.01
    factor = 0.1
    # Number of decay boundaries the epoch is no longer below.
    decays_applied = sum(not (epoch < boundary) for boundary in (10, 20, 30, 40))
    return initial_lrate * math.pow(factor, decays_applied)
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
def print_test(model, X_adv, X_test, y_test, epsilon):
    """Evaluate ``model`` on adversarial inputs and report the perturbation SNR.

    Prints the loss/accuracy returned by ``model.evaluate(X_adv, y_test)``
    together with ``epsilon``, then prints the signal-to-noise ratio in dB,
    computed as ``20 * log10(||X_test|| / ||X_test - X_adv||)``.

    Args:
        model: Object exposing ``evaluate(X, y)`` that returns two values
            (unpacked here as loss and accuracy).
        X_adv: Adversarial inputs.
        X_test: Clean inputs that ``X_adv`` was derived from.
        y_test: Labels passed through to ``model.evaluate``.
        epsilon: Perturbation size; only used in the printed message.

    Returns:
        tuple: ``(loss, acc)`` as returned by ``model.evaluate``.
    """
    test_loss, test_acc = model.evaluate(X_adv, y_test)
    print(
        "epsilon: {} and test evaluation : {}, {}".format(epsilon, test_loss, test_acc)
    )

    signal_norm = np.linalg.norm(X_test)
    noise_norm = np.linalg.norm(X_test - X_adv)
    snr_db = 20 * np.log10(signal_norm / noise_norm)
    print("SNR: {}".format(snr_db))

    return test_loss, test_acc
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
def get_adversarial_examples(pretrained_model, X_true, y_true, epsilon):
    """Craft one untargeted FGSM adversarial example per input image.

    Each image is attacked on its own with cleverhans'
    ``fast_gradient_method`` using the infinity norm and a perturbation
    budget of ``epsilon``.

    Args:
        pretrained_model: Keras model to attack.
        X_true: Indexable collection of images; each one is reshaped to
            ``(1, 32, 32)`` before the attack.
        y_true: Indexable collection of label vectors; the ``argmax`` of
            each one is used as the true class.
        epsilon: Perturbation budget passed to the attack.

    Returns:
        numpy.ndarray: The adversarial images, each reshaped to
        ``(32, 32, 1)``, stacked along the first axis.
    """
    # NOTE(review): the attack is documented as needing logits, but this
    # wraps the output of the model's last layer; if that layer applies a
    # softmax these are probabilities - confirm against the attacked model.
    logits_model = tf.keras.Model(
        pretrained_model.input, pretrained_model.layers[-1].output
    )

    def attack_one(idx):
        # Add a leading batch axis of size 1 so the model accepts the image.
        image = tf.convert_to_tensor(X_true[idx].reshape((1, 32, 32)))
        label = np.reshape(np.argmax(y_true[idx]), (1,)).astype("int64")
        adversarial = fast_gradient_method(
            logits_model,
            image,
            epsilon,
            np.inf,
            y=label,
            targeted=False,
        )
        return np.array(adversarial).reshape(32, 32, 1)

    return np.array([attack_one(idx) for idx in range(len(X_true))])
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# Keras LearningRateScheduler callbacks wrapping the two step-decay schedules
# defined in this module.
lrate_conv = LearningRateScheduler(step_decay_conv)  # schedule starting at 0.01
lrate = LearningRateScheduler(step_decay)  # schedule starting at 0.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/wideresnet/wresnet.py
DELETED
|
@@ -1,329 +0,0 @@
|
|
| 1 |
-
from tensorflow.keras.models import Model
|
| 2 |
-
from tensorflow.keras.layers import Input, Add, Activation, Dropout, Flatten, Dense
|
| 3 |
-
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, AveragePooling2D
|
| 4 |
-
from tensorflow.keras.layers import BatchNormalization
|
| 5 |
-
from tensorflow.keras.regularizers import l2
|
| 6 |
-
from tensorflow.keras import backend as K
|
| 7 |
-
from tensorflow.keras.optimizers import SGD
|
| 8 |
-
import warnings
|
| 9 |
-
|
| 10 |
-
warnings.filterwarnings("ignore")
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class WideResidualNetwork(object):
    """Builder for a wide residual network made of Keras functional layers.

    Construct it with the hyper-parameters, then call
    ``create_wide_residual_network()`` to obtain the Keras ``Model``.
    """

    def __init__(
        self,
        input_dim,
        weight_decay,
        momentum,
        nb_classes=100,
        N=2,
        k=1,
        dropout=0.0,
        verbose=1,
    ):
        """Store the hyper-parameters of the wide residual network.

        Args:
            input_dim (tuple): Shape passed to the Keras ``Input`` layer.
            weight_decay (float): L2 regularisation factor applied to every
                convolution kernel and to the final ``Dense`` kernel.
            momentum: Accepted for signature compatibility; this class does
                not read it.
            nb_classes (int, optional): Number of units of the softmax
                output layer. Defaults to 100.
            N (int, optional): Units per stage: one expanding unit plus
                ``N - 1`` residual blocks. Defaults to 2.
            k (int, optional): Width multiplier applied to the 16/32/64
                base filter counts. Defaults to 1.
            dropout (float, optional): Dropout rate used inside the residual
                blocks; no Dropout layer is added when it is 0. Defaults
                to 0.0.
            verbose (int, optional): When truthy, a summary line is printed
                after the model is built. Defaults to 1.
        """
        self.weight_decay = weight_decay
        self.input_dim = input_dim
        self.nb_classes = nb_classes
        self.N = N
        self.k = k
        self.dropout = dropout
        self.verbose = verbose

    @staticmethod
    def _channel_axis():
        """Return the channel axis for the active Keras image data format."""
        return 1 if K.image_data_format() == "channels_first" else -1

    def _conv(self, filters, kernel_size, strides=(1, 1)):
        """Create a bias-free, he_normal, L2-regularised "same" convolution layer."""
        return Convolution2D(
            filters,
            kernel_size,
            padding="same",
            strides=strides,
            kernel_initializer="he_normal",
            kernel_regularizer=l2(self.weight_decay),
            use_bias=False,
        )

    def _bn_relu(self, x):
        """Apply batch normalisation followed by a ReLU activation to ``x``."""
        # NOTE(review): Keras' ``momentum`` weights the running average
        # (library default 0.99); 0.1 is kept because changing it alters
        # training - confirm it is intentional.
        x = BatchNormalization(
            axis=self._channel_axis(),
            momentum=0.1,
            epsilon=1e-5,
            gamma_initializer="uniform",
        )(x)
        return Activation("relu")(x)

    def _residual_block(self, input, filters, dropout):
        """Identity-shortcut unit: two (BN -> ReLU -> 3x3 conv) rounds added to ``input``."""
        x = self._bn_relu(input)
        x = self._conv(filters, (3, 3))(x)

        if dropout > 0.0:
            x = Dropout(dropout)(x)

        x = self._bn_relu(x)
        x = self._conv(filters, (3, 3))(x)

        return Add()([input, x])

    def initial_conv(self, input):
        """Stem of the network: 16-filter 3x3 convolution, batch norm, ReLU.

        Args:
            input: Input tensor of the network.

        Returns:
            The activated feature tensor.
        """
        x = self._conv(16, (3, 3))(input)
        return self._bn_relu(x)

    def expand_conv(self, init, base, k, strides=(1, 1)):
        """Residual unit that changes the channel count via a projection shortcut.

        Main path: 3x3 conv (using ``strides``) -> batch norm -> ReLU ->
        3x3 conv. Shortcut path: 1x1 conv (using ``strides``) on ``init``.
        Both paths have ``base * k`` filters and are summed.

        Args:
            init: Input tensor of the unit.
            base: Base number of filters.
            k: Multiplier applied to ``base`` to obtain the filter count.
            strides (tuple, optional): Strides of the first 3x3 convolution
                and of the shortcut convolution. Defaults to (1, 1).

        Returns:
            The output tensor of the ``Add`` layer.
        """
        filters = base * k

        x = self._conv(filters, (3, 3), strides=strides)(init)
        x = self._bn_relu(x)
        x = self._conv(filters, (3, 3))(x)

        skip = self._conv(filters, (1, 1), strides=strides)(init)

        return Add()([x, skip])

    def conv1_block(self, input, k=1, dropout=0.0):
        """Identity-shortcut residual unit with ``16 * k`` filters.

        Args:
            input: Input tensor of the unit.
            k (int, optional): Width multiplier. Defaults to 1.
            dropout (float, optional): Dropout rate between the two
                convolutions; disabled when 0. Defaults to 0.0.

        Returns:
            The output tensor of the ``Add`` layer.
        """
        return self._residual_block(input, 16 * k, dropout)

    def conv2_block(self, input, k=1, dropout=0.0):
        """Identity-shortcut residual unit with ``32 * k`` filters.

        Args:
            input: Input tensor of the unit.
            k (int, optional): Width multiplier. Defaults to 1.
            dropout (float, optional): Dropout rate between the two
                convolutions; disabled when 0. Defaults to 0.0.

        Returns:
            The output tensor of the ``Add`` layer.
        """
        return self._residual_block(input, 32 * k, dropout)

    def conv3_block(self, input, k=1, dropout=0.0):
        """Identity-shortcut residual unit with ``64 * k`` filters.

        Args:
            input: Input tensor of the unit.
            k (int, optional): Width multiplier. Defaults to 1.
            dropout (float, optional): Dropout rate between the two
                convolutions; disabled when 0. Defaults to 0.0.

        Returns:
            The output tensor of the ``Add`` layer.
        """
        return self._residual_block(input, 64 * k, dropout)

    def create_wide_residual_network(self):
        """Assemble the wide residual network as a Keras ``Model``.

        Layout: stem convolution, then three stages with 16, 32 and 64 base
        filters (the last two use stride (2, 2) in their expanding unit).
        Each stage is one ``expand_conv`` unit followed by ``self.N - 1``
        residual blocks and a closing batch norm + ReLU. The head is 8x8
        average pooling, flatten and a softmax ``Dense`` layer.

        Returns:
            Model: The assembled (uncompiled) network.
        """
        ip = Input(shape=self.input_dim)

        x = self.initial_conv(ip)
        nb_conv = 4  # running count that is only used in the verbose message

        # (base filters, strides of the expanding unit, residual block builder)
        stages = (
            (16, (1, 1), self.conv1_block),
            (32, (2, 2), self.conv2_block),
            (64, (2, 2), self.conv3_block),
        )
        for base, strides, residual_block in stages:
            x = self.expand_conv(x, base, self.k, strides=strides)
            nb_conv += 2

            for _ in range(self.N - 1):
                x = residual_block(x, self.k, self.dropout)
                nb_conv += 2

            x = self._bn_relu(x)

        # NOTE(review): the fixed 8x8 pool presumably targets 32x32 inputs
        # (two stride-2 stages) - confirm before using other input sizes.
        x = AveragePooling2D((8, 8))(x)
        x = Flatten()(x)

        x = Dense(
            self.nb_classes,
            kernel_regularizer=l2(self.weight_decay),
            activation="softmax",
        )(x)

        model = Model(ip, x)

        if self.verbose:
            print("Wide Residual Network-%d-%d created." % (nb_conv, self.k))
        return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|