/*
 * Decompiled with CFR 0.152.
 */
package amten.ml.test;

import amten.ml.NNParams;
import amten.ml.NeuralNetwork;
import amten.ml.matrix.Matrix;
import amten.ml.matrix.MatrixUtils;
import java.util.ArrayList;

/**
 * Example driver that trains a {@link NeuralNetwork} on the bundled Kaggle digits sample
 * and prints the training time plus the accuracy on the training and cross-validation sets.
 */
public class NeuralNetworkTest {
    /**
     * Loads {@code example_data/Kaggle_Digits_1000.csv}, splits it into a training and a
     * cross-validation partition, trains a network and prints timing and accuracy figures
     * to standard output.
     *
     * @throws Exception if reading the data, training or prediction fails
     */
    public static void runKaggleDigitsClassification() throws Exception {
        int headerRows = 1;
        char separator = ',';
        Matrix data = MatrixUtils.readCSV("example_data/Kaggle_Digits_1000.csv", separator, headerRows);

        float crossValidationPercent = 33.0f;
        Matrix[] split = MatrixUtils.split(data, crossValidationPercent, 0.0f);
        Matrix dataTrain = split[0];
        Matrix dataCV = split[1];

        // Column 0 is used as the label and the remaining columns as the inputs.
        // NOTE(review): assumes an end index of -1 means "through the last column" - confirm
        // against Matrix.getColumns.
        Matrix xTrain = dataTrain.getColumns(1, -1);
        Matrix yTrain = dataTrain.getColumns(0, 0);
        Matrix xCV = dataCV.getColumns(1, -1);
        Matrix yCV = dataCV.getColumns(0, 0);

        NNParams params = new NNParams();
        params.numClasses = 10;
        // NOTE(review): the meaning of the five layer arguments is not visible from this file;
        // see NNParams.NNLayerParams before changing them.
        params.hiddenLayerParams = new NNParams.NNLayerParams[]{
            new NNParams.NNLayerParams(20, 5, 5, 2, 2),
            new NNParams.NNLayerParams(100, 5, 5, 2, 2)
        };
        params.learningRate = 0.01;
        params.maxIterations = 10;

        long startTime = System.currentTimeMillis();
        NeuralNetwork nn = new NeuralNetwork(params);
        nn.train(xTrain, yTrain);
        System.out.println("Training time: " + (double) (System.currentTimeMillis() - startTime) / 1000.0 + "s");

        System.out.println("Training set accuracy: " + accuracyPercent(nn, params, xTrain, yTrain) + "%");
        System.out.println("Crossvalidation set accuracy: " + accuracyPercent(nn, params, xCV, yCV) + "%");
    }

    /**
     * Computes the percentage of rows in {@code x} whose predicted class equals the value in
     * column 0 of the matching row of {@code y}. Prediction is done batch by batch, using
     * {@code params.batchSize} to split the data.
     *
     * @param nn     trained network used for prediction
     * @param params parameters supplying the batch size
     * @param x      input rows to classify
     * @param y      expected class per row, read from column 0
     * @return accuracy in percent; the result of a floating-point division by
     *         {@code x.numRows()}, so it is not a finite number when {@code x} has no rows
     * @throws Exception if splitting or prediction fails
     */
    private static double accuracyPercent(NeuralNetwork nn, NNParams params, Matrix x, Matrix y) throws Exception {
        // Declared as ArrayList (not List) because the parameter types of MatrixUtils.split
        // are not visible here and the original call site passed ArrayList instances.
        ArrayList<Matrix> batchesX = new ArrayList<Matrix>();
        ArrayList<Matrix> batchesY = new ArrayList<Matrix>();
        MatrixUtils.split(x, y, params.batchSize, batchesX, batchesY);

        int correct = 0;
        for (int batch = 0; batch < batchesX.size(); ++batch) {
            int[] predictedClasses = nn.getPredictedClasses(batchesX.get(batch));
            Matrix expected = batchesY.get(batch);
            for (int i = 0; i < predictedClasses.length; ++i) {
                if ((double) predictedClasses[i] == expected.get(i, 0)) {
                    ++correct;
                }
            }
        }
        return (double) correct / (double) x.numRows() * 100.0;
    }

    /**
     * Command-line entry point; runs the Kaggle digits classification example.
     *
     * @param args ignored
     * @throws Exception if the example run fails
     */
    public static void main(String[] args) throws Exception {
        NeuralNetworkTest.runKaggleDigitsClassification();
    }
}

