/*
 * Decompiled with CFR 0.152.
 */
package amten.ml.examples;

import amten.ml.NNParams;
import amten.ml.NeuralNetwork;
import amten.ml.matrix.Matrix;
import amten.ml.matrix.MatrixUtils;

/**
 * Example program that trains a {@link NeuralNetwork} regression model on the Car Prices CSV
 * dataset and prints the training time plus the root mean squared error (RMSE) measured on the
 * training split and on the cross-validation split.
 */
public class NNRegressionExample {

    /**
     * Loads {@code example_data/Car_Prices.csv}, splits it into a training and a cross-validation
     * set, trains a neural network with default {@link NNParams}, and prints the training time and
     * the RMSE on both sets to standard output.
     *
     * @throws Exception if reading the CSV file or training the network fails
     */
    public static void runCarPricesRegression() throws Exception {
        System.out.println("Running regression on Car Prices dataset...\n");
        int headerRows = 1;
        char separator = ',';
        Matrix data = MatrixUtils.readCSV("example_data/Car_Prices.csv", separator, headerRows);

        // NOTE(review): assumes split(...) returns {train, crossValidation}, with about
        // crossValidationPercent of the rows in the second matrix. The meaning of the trailing
        // 0.0f argument is not visible here; confirm against MatrixUtils.split.
        float crossValidationPercent = 33.0f;
        Matrix[] split = MatrixUtils.split(data, crossValidationPercent, 0.0f);
        Matrix dataTrain = split[0];
        Matrix dataCV = split[1];

        // Columns 0..13 are the model inputs and column 14 is the regression target
        // (presumably getColumns' bounds are inclusive; TODO confirm).
        Matrix xTrain = dataTrain.getColumns(0, 13);
        Matrix yTrain = dataTrain.getColumns(14, 14);
        Matrix xCV = dataCV.getColumns(0, 13);
        Matrix yCV = dataCV.getColumns(14, 14);

        NNParams params = new NNParams();
        long startTime = System.currentTimeMillis();
        NeuralNetwork nn = new NeuralNetwork(params);
        nn.train(xTrain, yTrain);
        double trainingSeconds = (double) (System.currentTimeMillis() - startTime) / 1000.0;
        System.out.println("\nTraining time: " + String.format("%.3g", trainingSeconds) + "s");

        double trainError = rootMeanSquaredError(nn.getPredictions(xTrain), yTrain);
        System.out.println("Training set root mean squared error: " + String.format("%.4g", trainError));

        double cvError = rootMeanSquaredError(nn.getPredictions(xCV), yCV);
        System.out.println("Crossvalidation set root mean squared error: " + String.format("%.4g", cvError));
    }

    /**
     * Computes the root mean squared error between column 0 of {@code predictions} and column 0
     * of {@code actual}, row by row.
     *
     * <p>When both matrices have zero rows the result is {@code NaN} (0.0 / 0.0), matching the
     * behavior of the original inline computation.
     *
     * @param predictions model output; only column 0 is read
     * @param actual expected target values; only column 0 is read
     * @return the RMSE over all rows
     * @throws IllegalArgumentException if the two matrices have different row counts
     */
    private static double rootMeanSquaredError(Matrix predictions, Matrix actual) {
        int rows = predictions.numRows();
        if (actual.numRows() != rows) {
            throw new IllegalArgumentException(
                    "Row count mismatch: predictions=" + rows + ", actual=" + actual.numRows());
        }
        double sumOfSquares = 0.0;
        for (int i = 0; i < rows; i++) {
            double diff = predictions.get(i, 0) - actual.get(i, 0);
            sumOfSquares += diff * diff;
        }
        return Math.sqrt(sumOfSquares / (double) rows);
    }

    /**
     * Entry point; runs {@link #runCarPricesRegression()}.
     *
     * @param args ignored
     * @throws Exception propagated from {@link #runCarPricesRegression()}
     */
    public static void main(String[] args) throws Exception {
        NNRegressionExample.runCarPricesRegression();
    }
}

