/*
 * Decompiled with CFR 0.152.
 */
package amten.ml.examples;

import amten.ml.NNParams;
import amten.ml.NeuralNetwork;
import amten.ml.matrix.Matrix;
import amten.ml.matrix.MatrixUtils;

/**
 * Example programs that train {@link NeuralNetwork} classifiers on sample Kaggle datasets
 * and print training time plus accuracy on a training split and a cross-validation split.
 */
public class NNClassificationExample {

    /**
     * Trains and evaluates a classifier on {@code example_data/Kaggle_Digits_1000.csv}.
     *
     * @param useConvolution if {@code true}, use two convolutional hidden layers, 10 iterations
     *                       and learning rate 0.01; otherwise one hidden layer of 100 units and
     *                       200 iterations
     * @throws Exception if reading the data file or training fails
     */
    public static void runKaggleDigitsClassification(boolean useConvolution) throws Exception {
        if (useConvolution) {
            System.out.println("Running classification on Kaggle Digits dataset, with convolution...\n");
        } else {
            System.out.println("Running classification on Kaggle Digits dataset...\n");
        }
        NNParams params = new NNParams();
        params.numClasses = 10;
        if (useConvolution) {
            // Argument meaning is defined by NNParams.NNLayerParams (not visible here).
            params.hiddenLayerParams = new NNParams.NNLayerParams[] {
                    new NNParams.NNLayerParams(20, 5, 5, 2, 2),
                    new NNParams.NNLayerParams(100, 5, 5, 2, 2)
            };
            params.maxIterations = 10;
            params.learningRate = 0.01;
        } else {
            params.hiddenLayerParams = new NNParams.NNLayerParams[] {
                    new NNParams.NNLayerParams(100)
            };
            params.maxIterations = 200;
            // NOTE(review): 0.0 is presumably a sentinel telling NeuralNetwork to pick its own
            // learning rate rather than a literal zero step size — confirm in NNParams.
            params.learningRate = 0.0;
        }
        trainAndEvaluate("example_data/Kaggle_Digits_1000.csv", params);
    }

    /**
     * Trains and evaluates a binary classifier on
     * {@code example_data/Kaggle_Titanic_Cleaned.csv}.
     *
     * @throws Exception if reading the data file or training fails
     */
    public static void runKaggleTitanicClassification() throws Exception {
        System.out.println("Running classification on Kaggle Titanic dataset...\n");
        NNParams params = new NNParams();
        // NOTE(review): presumably one entry per input feature column, giving the number of
        // categories for that column (1 = plain numeric) — confirm in NNParams.
        params.numCategories = new int[] {3, 2, 1, 1, 1, 1, 3};
        params.numClasses = 2;
        trainAndEvaluate("example_data/Kaggle_Titanic_Cleaned.csv", params);
    }

    /**
     * Shared pipeline for the examples. It loads a CSV with one header row, splits it into a
     * training part and a cross-validation part, and trains a network with {@code params}. It
     * then prints the training time and the accuracy on both parts.
     *
     * @param dataFile path of the comma-separated data file
     * @param params   fully configured network parameters
     * @throws Exception if reading the data file or training fails
     */
    private static void trainAndEvaluate(String dataFile, NNParams params) throws Exception {
        int headerRows = 1;
        char separator = ',';
        Matrix data = MatrixUtils.readCSV(dataFile, separator, headerRows);
        // NOTE(review): presumably holds out 33% of the rows for cross-validation; the role of
        // the third argument (0.0f) is not visible here.
        float crossValidationPercent = 33.0f;
        Matrix[] split = MatrixUtils.split(data, crossValidationPercent, 0.0f);
        Matrix dataTrain = split[0];
        Matrix dataCV = split[1];
        // Column 0 is the class label; columns 1..-1 (presumably "through the last column")
        // are the input features.
        Matrix xTrain = dataTrain.getColumns(1, -1);
        Matrix yTrain = dataTrain.getColumns(0, 0);
        Matrix xCV = dataCV.getColumns(1, -1);
        Matrix yCV = dataCV.getColumns(0, 0);

        long startTime = System.currentTimeMillis();
        NeuralNetwork nn = new NeuralNetwork(params);
        nn.train(xTrain, yTrain);
        double trainingSeconds = (double) (System.currentTimeMillis() - startTime) / 1000.0;
        System.out.println("\nTraining time: " + String.format("%.3g", trainingSeconds) + "s");

        printAccuracy("Training set accuracy: ", nn.getPredictedClasses(xTrain), yTrain);
        printAccuracy("Crossvalidation set accuracy: ", nn.getPredictedClasses(xCV), yCV);
    }

    /**
     * Prints the percentage of predictions that match the labels in column 0 of
     * {@code labels}.
     *
     * @param caption          text printed before the percentage
     * @param predictedClasses predicted class index per row
     * @param labels           matrix whose column 0 holds the true class per row
     */
    private static void printAccuracy(String caption, int[] predictedClasses, Matrix labels) {
        int correct = 0;
        for (int i = 0; i < predictedClasses.length; i++) {
            if ((double) predictedClasses[i] == labels.get(i, 0)) {
                ++correct;
            }
        }
        double percent = (double) correct / (double) predictedClasses.length * 100.0;
        System.out.println(caption + String.format("%.3g", percent) + "%");
    }

    /** Runs the dense Digits, convolutional Digits and Titanic examples in sequence. */
    public static void main(String[] args) throws Exception {
        NNClassificationExample.runKaggleDigitsClassification(false);
        System.out.println("\n\n\n");
        NNClassificationExample.runKaggleDigitsClassification(true);
        System.out.println("\n\n\n");
        NNClassificationExample.runKaggleTitanicClassification();
    }
}

